mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 17:31:36 +00:00
lnutil.UpdateAddHtlc: use attrs instead of old-style namedtuple
This commit is contained in:
parent
444610452e
commit
ea0981ebeb
4 changed files with 23 additions and 19 deletions
|
@ -32,6 +32,7 @@ import time
|
|||
import threading
|
||||
|
||||
from aiorpcx import NetAddress
|
||||
import attr
|
||||
|
||||
from . import ecc
|
||||
from . import constants
|
||||
|
@ -434,7 +435,7 @@ class Channel(Logger):
|
|||
assert isinstance(htlc, UpdateAddHtlc)
|
||||
self._check_can_pay(htlc.amount_msat)
|
||||
if htlc.htlc_id is None:
|
||||
htlc = htlc._replace(htlc_id=self.hm.get_next_htlc_id(LOCAL))
|
||||
htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(LOCAL))
|
||||
with self.db_lock:
|
||||
self.hm.send_htlc(htlc)
|
||||
self.logger.info("add_htlc")
|
||||
|
@ -452,7 +453,7 @@ class Channel(Logger):
|
|||
htlc = UpdateAddHtlc(**htlc)
|
||||
assert isinstance(htlc, UpdateAddHtlc)
|
||||
if htlc.htlc_id is None: # used in unit tests
|
||||
htlc = htlc._replace(htlc_id=self.hm.get_next_htlc_id(REMOTE))
|
||||
htlc = attr.evolve(htlc, htlc_id=self.hm.get_next_htlc_id(REMOTE))
|
||||
if 0 <= self.available_to_spend(REMOTE) < htlc.amount_msat:
|
||||
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
|
||||
f' Available at remote: {self.available_to_spend(REMOTE)},' +\
|
||||
|
|
|
@ -878,21 +878,21 @@ def format_short_channel_id(short_channel_id: Optional[bytes]):
|
|||
+ 'x' + str(int.from_bytes(short_channel_id[6:], 'big'))
|
||||
|
||||
|
||||
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id', 'timestamp'])):
|
||||
# note: typing.NamedTuple cannot be used because we are overriding __new__
|
||||
@attr.s(frozen=True)
|
||||
class UpdateAddHtlc:
|
||||
amount_msat = attr.ib(type=int, kw_only=True)
|
||||
payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes)
|
||||
cltv_expiry = attr.ib(type=int, kw_only=True)
|
||||
timestamp = attr.ib(type=int, kw_only=True)
|
||||
htlc_id = attr.ib(type=int, kw_only=True, default=None)
|
||||
|
||||
__slots__ = ()
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# if you pass a hex-string as payment_hash, it is decoded to bytes.
|
||||
# Bytes can't be saved to disk, so we save hex-strings.
|
||||
if len(args) > 0:
|
||||
args = list(args)
|
||||
if type(args[1]) is str:
|
||||
args[1] = bfh(args[1])
|
||||
return super().__new__(cls, *args)
|
||||
if type(kwargs['payment_hash']) is str:
|
||||
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
|
||||
if len(args) < 4 and 'htlc_id' not in kwargs:
|
||||
kwargs['htlc_id'] = None
|
||||
return super().__new__(cls, **kwargs)
|
||||
@classmethod
|
||||
def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
|
||||
return cls(amount_msat=amount_msat,
|
||||
payment_hash=payment_hash,
|
||||
cltv_expiry=cltv_expiry,
|
||||
htlc_id=htlc_id,
|
||||
timestamp=timestamp)
|
||||
|
||||
def to_tuple(self):
|
||||
return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp)
|
||||
|
|
|
@ -277,6 +277,9 @@ class MyEncoder(json.JSONEncoder):
|
|||
def default(self, obj):
|
||||
# note: this does not get called for namedtuples :( https://bugs.python.org/issue30343
|
||||
from .transaction import Transaction, TxOutput
|
||||
from .lnutil import UpdateAddHtlc
|
||||
if isinstance(obj, UpdateAddHtlc):
|
||||
return obj.to_tuple()
|
||||
if isinstance(obj, Transaction):
|
||||
return obj.serialize()
|
||||
if isinstance(obj, TxOutput):
|
||||
|
|
|
@ -1079,7 +1079,7 @@ class WalletDB(JsonDB):
|
|||
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
|
||||
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
|
||||
elif key == 'adds':
|
||||
v = dict((k, UpdateAddHtlc(*x)) for k, x in v.items())
|
||||
v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items())
|
||||
elif key == 'fee_updates':
|
||||
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
|
||||
elif key == 'tx_fees':
|
||||
|
|
Loading…
Add table
Reference in a new issue