lnchannel: move fee update logic to lnhtlc (and hopefully fix it)

This commit is contained in:
SomberNight 2019-07-27 01:05:37 +02:00 committed by ThomasV
parent 3d7f7dfc82
commit 087994e39a
5 changed files with 150 additions and 84 deletions

View file

@ -91,11 +91,6 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x): def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()} return {str(k): bh2u(v) for k, v in x.items()}
def deserialize_feeupdate(x):
return FeeUpdate(rate=x['rate'], ctn={LOCAL:x['ctn'][str(int(LOCAL))], REMOTE:x['ctn'][str(int(REMOTE))]})
def serialize_feeupdate(x):
return {'rate':x.rate, 'ctn': {int(LOCAL):x.ctn[LOCAL], int(REMOTE):x.ctn[REMOTE]}}
class Channel(Logger): class Channel(Logger):
# note: try to avoid naming ctns/ctxs/etc as "current" and "pending". # note: try to avoid naming ctns/ctxs/etc as "current" and "pending".
@ -110,7 +105,7 @@ class Channel(Logger):
except: except:
return super().diagnostic_name() return super().diagnostic_name()
def __init__(self, state, *, sweep_address=None, name=None, lnworker=None): def __init__(self, state, *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
self.name = name self.name = name
Logger.__init__(self) Logger.__init__(self)
self.lnworker = lnworker self.lnworker = lnworker
@ -137,7 +132,6 @@ class Channel(Logger):
self.short_channel_id_predicted = self.short_channel_id self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {})) self.onion_keys = str_bytes_dict_from_save(state.get('onion_keys', {}))
self.force_closed = state.get('force_closed') self.force_closed = state.get('force_closed')
self.fee_updates = [deserialize_feeupdate(x) if type(x) is not FeeUpdate else x for x in state.get('fee_updates')] # populated with initial fee
# FIXME this is a tx serialised in the custom electrum partial tx format. # FIXME this is a tx serialised in the custom electrum partial tx format.
# we should not persist txns in this format. we should persist htlcs, and be able to derive # we should not persist txns in this format. we should persist htlcs, and be able to derive
@ -148,7 +142,8 @@ class Channel(Logger):
log = state.get('log') log = state.get('log')
self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn, self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn,
remote_ctn=self.config[REMOTE].ctn, remote_ctn=self.config[REMOTE].ctn,
log=log) log=log,
initial_feerate=initial_feerate)
self._is_funding_txo_spent = None # "don't know" self._is_funding_txo_spent = None # "don't know"
@ -158,42 +153,17 @@ class Channel(Logger):
self.remote_commitment = None self.remote_commitment = None
self.sweep_info = {} self.sweep_info = {}
def pending_fee(self): def get_feerate(self, subject, ctn):
""" return FeeUpdate that has not been commited on any side""" return self.hm.get_feerate(subject, ctn)
for f in self.fee_updates:
if f.ctn[LOCAL] is None and f.ctn[REMOTE] is None:
return f
def locally_pending_fee(self): def get_oldest_unrevoked_feerate(self, subject):
""" return FeeUpdate that been commited remotely and is still pending locally (should used if we are initiator)""" return self.hm.get_feerate_in_oldest_unrevoked_ctx(subject)
for f in self.fee_updates:
if f.ctn[LOCAL] is None and f.ctn[REMOTE] is not None:
return f
def remotely_pending_fee(self): def get_latest_feerate(self, subject):
""" return FeeUpdate that been commited locally and is still pending remotely (should be used if we are not initiator)""" return self.hm.get_feerate_in_latest_ctx(subject)
for f in self.fee_updates:
if f.ctn[LOCAL] is not None and f.ctn[REMOTE] is None:
return f
def get_feerate(self, subject, target_ctn):
next_ctn = self.config[subject].ctn + 1
assert target_ctn <= next_ctn
result = self.fee_updates[0]
for f in self.fee_updates[1:]:
ctn = f.ctn[subject]
if ctn is None:
ctn = next_ctn
if ctn > result.ctn[subject] and ctn <= target_ctn:
best_ctn = ctn
result = f
return result.rate
def get_current_feerate(self, subject):
return self.get_feerate(subject, self.config[subject].ctn)
def get_next_feerate(self, subject): def get_next_feerate(self, subject):
return self.get_feerate(subject, self.config[subject].ctn + 1) return self.hm.get_feerate_in_next_ctx(subject)
def get_payments(self): def get_payments(self):
out = {} out = {}
@ -398,10 +368,6 @@ class Channel(Logger):
current_htlc_signatures=htlc_sigs_string, current_htlc_signatures=htlc_sigs_string,
got_sig_for_next=True) got_sig_for_next=True)
# if a fee update was acked, then we add it locally
f = self.locally_pending_fee()
if f: f.ctn[LOCAL] = next_local_ctn
self.set_local_commitment(pending_local_commitment) self.set_local_commitment(pending_local_commitment)
def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int: def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int:
@ -476,8 +442,6 @@ class Channel(Logger):
prev_remote_commitment = self.pending_commitment(REMOTE) prev_remote_commitment = self.pending_commitment(REMOTE)
self.config[REMOTE].revocation_store.add_next_entry(revocation.per_commitment_secret) self.config[REMOTE].revocation_store.add_next_entry(revocation.per_commitment_secret)
##### start applying fee/htlc changes ##### start applying fee/htlc changes
f = self.pending_fee()
if f: f.ctn[REMOTE] = next_ctn
next_point = self.config[REMOTE].next_per_commitment_point next_point = self.config[REMOTE].next_per_commitment_point
self.hm.recv_rev() self.hm.recv_rev()
self.config[REMOTE]=self.config[REMOTE]._replace( self.config[REMOTE]=self.config[REMOTE]._replace(
@ -540,7 +504,7 @@ class Channel(Logger):
- calc_onchain_fees( - calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc? # TODO should we include a potential new htlc, when we are called from receive_htlc?
len(self.included_htlcs(subject, SENT) + self.included_htlcs(subject, RECEIVED)), len(self.included_htlcs(subject, SENT) + self.included_htlcs(subject, RECEIVED)),
self.get_current_feerate(subject), self.get_latest_feerate(subject),
self.constraints.is_initiator, self.constraints.is_initiator,
)[subject] )[subject]
@ -627,13 +591,14 @@ class Channel(Logger):
def pending_local_fee(self): def pending_local_fee(self):
return self.constraints.capacity - sum(x[2] for x in self.pending_commitment(LOCAL).outputs()) return self.constraints.capacity - sum(x[2] for x in self.pending_commitment(LOCAL).outputs())
def update_fee(self, feerate, initiator): def update_fee(self, feerate: int, from_us: bool):
f = self.pending_fee() # feerate uses sat/kw
if f: if self.constraints.is_initiator != from_us:
f.rate = feerate raise Exception(f"Cannot update_fee: wrong initiator. us: {from_us}")
return if from_us:
f = FeeUpdate(rate=feerate, ctn={LOCAL:None, REMOTE:None}) self.hm.send_update_fee(feerate)
self.fee_updates.append(f) else:
self.hm.recv_update_fee(feerate)
def to_save(self): def to_save(self):
to_save = { to_save = {
@ -645,7 +610,6 @@ class Channel(Logger):
"funding_outpoint": self.funding_outpoint, "funding_outpoint": self.funding_outpoint,
"node_id": self.node_id, "node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked), "remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
"fee_updates": self.fee_updates,
"log": self.hm.to_save(), "log": self.hm.to_save(),
"onion_keys": str_bytes_dict_to_save(self.onion_keys), "onion_keys": str_bytes_dict_to_save(self.onion_keys),
"force_closed": self.force_closed, "force_closed": self.force_closed,
@ -659,8 +623,6 @@ class Channel(Logger):
for k, v in to_save_ref.items(): for k, v in to_save_ref.items():
if isinstance(v, tuple): if isinstance(v, tuple):
serialized_channel[k] = namedtuples_to_dict(v) serialized_channel[k] = namedtuples_to_dict(v)
elif k == 'fee_updates':
serialized_channel[k] = [serialize_feeupdate(x) for x in v]
else: else:
serialized_channel[k] = v serialized_channel[k] = v
dumped = ChannelJsonEncoder().encode(serialized_channel) dumped = ChannelJsonEncoder().encode(serialized_channel)

View file

@ -7,13 +7,13 @@ from .util import bh2u
class HTLCManager: class HTLCManager:
def __init__(self, *, local_ctn=0, remote_ctn=0, log=None): def __init__(self, *, local_ctn=0, remote_ctn=0, log=None, initial_feerate=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub # self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn} self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
# ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs) # ctx_pending[sub] is True iff sub has received commitment_signed but did not send revoke_and_ack (sub has multiple unrevoked ctxs)
self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted? self.ctx_pending = {LOCAL:False, REMOTE: False} # FIXME does this need to be persisted?
if log is None: if log is None:
initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}} initial = {'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}, 'fee_updates': []}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else: else:
assert type(log) is dict assert type(log) is dict
@ -25,6 +25,14 @@ class HTLCManager:
log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()} log[sub]['locked_in'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()} log[sub]['settles'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['settles'].items()}
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
# "side who initiated fee update" -> action -> list of FeeUpdates
log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
# maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None:
assert type(initial_feerate) is int
for sub in (LOCAL, REMOTE):
if not log[sub]['fee_updates']:
log[sub]['fee_updates'].append(FeeUpdate(initial_feerate, ctns={LOCAL:0, REMOTE:0}))
self.log = log self.log = log
def ctn_latest(self, sub: HTLCOwner) -> int: def ctn_latest(self, sub: HTLCOwner) -> int:
@ -39,8 +47,12 @@ class HTLCManager:
for htlc_id, htlc in log[sub]['adds'].items(): for htlc_id, htlc in log[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:] d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
log[sub]['adds'] = d log[sub]['adds'] = d
# fee_updates
log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
return log return log
##### Actions on channel:
def channel_open_finished(self): def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0} self.ctn = {LOCAL: 0, REMOTE: 0}
self.ctx_pending = {LOCAL:False, REMOTE: False} self.ctx_pending = {LOCAL:False, REMOTE: False}
@ -68,6 +80,25 @@ class HTLCManager:
def recv_fail(self, htlc_id: int) -> None: def recv_fail(self, htlc_id: int) -> None:
self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None} self.log[LOCAL]['fails'][htlc_id] = {LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None}
def send_update_fee(self, feerate: int) -> None:
fee_update = FeeUpdate(rate=feerate,
ctns={LOCAL: None, REMOTE: self.ctn_latest(REMOTE) + 1})
self._new_feeupdate(fee_update, subject=LOCAL)
def recv_update_fee(self, feerate: int) -> None:
fee_update = FeeUpdate(rate=feerate,
ctns={LOCAL: self.ctn_latest(LOCAL) + 1, REMOTE: None})
self._new_feeupdate(fee_update, subject=REMOTE)
def _new_feeupdate(self, fee_update: FeeUpdate, subject: HTLCOwner) -> None:
# overwrite last fee update if not yet committed to by anyone; otherwise append
last_fee_update = self.log[subject]['fee_updates'][-1]
if (last_fee_update.ctns[LOCAL] is None or last_fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL)) \
and (last_fee_update.ctns[REMOTE] is None or last_fee_update.ctns[REMOTE] > self.ctn_latest(REMOTE)):
self.log[subject]['fee_updates'][-1] = fee_update
else:
self.log[subject]['fee_updates'].append(fee_update)
def send_ctx(self) -> None: def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE]) assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE])
self.ctx_pending[REMOTE] = True self.ctx_pending[REMOTE] = True
@ -79,6 +110,7 @@ class HTLCManager:
def send_rev(self) -> None: def send_rev(self) -> None:
self.ctn[LOCAL] += 1 self.ctn[LOCAL] += 1
self.ctx_pending[LOCAL] = False self.ctx_pending[LOCAL] = False
# htlcs
for ctns in self.log[REMOTE]['locked_in'].values(): for ctns in self.log[REMOTE]['locked_in'].values():
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
@ -86,10 +118,15 @@ class HTLCManager:
for ctns in self.log[LOCAL][log_action].values(): for ctns in self.log[LOCAL][log_action].values():
if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL): if ctns[REMOTE] is None and ctns[LOCAL] <= self.ctn_latest(LOCAL):
ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
# fee updates
for fee_update in self.log[REMOTE]['fee_updates']:
if fee_update.ctns[REMOTE] is None and fee_update.ctns[LOCAL] <= self.ctn_latest(LOCAL):
fee_update.ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
def recv_rev(self) -> None: def recv_rev(self) -> None:
self.ctn[REMOTE] += 1 self.ctn[REMOTE] += 1
self.ctx_pending[REMOTE] = False self.ctx_pending[REMOTE] = False
# htlcs
for ctns in self.log[LOCAL]['locked_in'].values(): for ctns in self.log[LOCAL]['locked_in'].values():
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
@ -97,6 +134,12 @@ class HTLCManager:
for ctns in self.log[REMOTE][log_action].values(): for ctns in self.log[REMOTE][log_action].values():
if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE): if ctns[LOCAL] is None and ctns[REMOTE] <= self.ctn_latest(REMOTE):
ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
# fee updates
for fee_update in self.log[LOCAL]['fee_updates']:
if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE):
fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
##### Queries re HTLCs:
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,
ctn: int = None) -> Sequence[UpdateAddHtlc]: ctn: int = None) -> Sequence[UpdateAddHtlc]:
@ -186,3 +229,40 @@ class HTLCManager:
return [self.log[LOCAL]['adds'][htlc_id] return [self.log[LOCAL]['adds'][htlc_id]
for htlc_id, ctns in self.log[LOCAL]['settles'].items() for htlc_id, ctns in self.log[LOCAL]['settles'].items()
if ctns[LOCAL] == ctn] if ctns[LOCAL] == ctn]
##### Queries re Fees:
def get_feerate(self, subject: HTLCOwner, ctn: int) -> int:
"""Return feerate used in subject's commitment txn at ctn."""
ctn = max(0, ctn) # FIXME rm this
# only one party can update fees; use length of logs to figure out which:
assert not (len(self.log[LOCAL]['fee_updates']) > 1 and len(self.log[REMOTE]['fee_updates']) > 1)
fee_log = self.log[LOCAL]['fee_updates'] # type: Sequence[FeeUpdate]
if len(self.log[REMOTE]['fee_updates']) > 1:
fee_log = self.log[REMOTE]['fee_updates']
# binary search
left = 0
right = len(fee_log)
while True:
i = (left + right) // 2
ctn_at_i = fee_log[i].ctns[subject]
if right - left <= 1:
break
if ctn_at_i is None: # Nones can only be on the right end
right = i
continue
if ctn_at_i <= ctn: # among equals, we want the rightmost
left = i
else:
right = i
assert ctn_at_i <= ctn
return fee_log[i].rate
def get_feerate_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> int:
return self.get_feerate(subject=subject, ctn=self.ctn[subject])
def get_feerate_in_latest_ctx(self, subject: HTLCOwner) -> int:
return self.get_feerate(subject=subject, ctn=self.ctn_latest(subject))
def get_feerate_in_next_ctx(self, subject: HTLCOwner) -> int:
return self.get_feerate(subject=subject, ctn=self.ctn_latest(subject) + 1)

View file

@ -545,12 +545,12 @@ class Peer(Logger):
"remote_config": remote_config, "remote_config": remote_config,
"local_config": local_config, "local_config": local_config,
"constraints": ChannelConstraints(capacity=funding_sat, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth), "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=True, funding_txn_minimum_depth=funding_txn_minimum_depth),
"fee_updates": [FeeUpdate(rate=feerate, ctn={LOCAL:0, REMOTE:0})],
"remote_commitment_to_be_revoked": None, "remote_commitment_to_be_revoked": None,
} }
chan = Channel(chan_dict, chan = Channel(chan_dict,
sweep_address=self.lnworker.sweep_address, sweep_address=self.lnworker.sweep_address,
lnworker=self.lnworker) lnworker=self.lnworker,
initial_feerate=feerate)
sig_64, _ = chan.sign_next_commitment() sig_64, _ = chan.sign_next_commitment()
self.send_message("funding_created", self.send_message("funding_created",
temporary_channel_id=temp_channel_id, temporary_channel_id=temp_channel_id,
@ -634,11 +634,11 @@ class Peer(Logger):
"local_config": local_config, "local_config": local_config,
"constraints": ChannelConstraints(capacity=funding_sat, is_initiator=False, funding_txn_minimum_depth=min_depth), "constraints": ChannelConstraints(capacity=funding_sat, is_initiator=False, funding_txn_minimum_depth=min_depth),
"remote_commitment_to_be_revoked": None, "remote_commitment_to_be_revoked": None,
"fee_updates": [FeeUpdate(feerate, ctn={LOCAL:0, REMOTE:0})],
} }
chan = Channel(chan_dict, chan = Channel(chan_dict,
sweep_address=self.lnworker.sweep_address, sweep_address=self.lnworker.sweep_address,
lnworker=self.lnworker) lnworker=self.lnworker,
initial_feerate=feerate)
remote_sig = funding_created['signature'] remote_sig = funding_created['signature']
chan.receive_new_commitment(remote_sig, []) chan.receive_new_commitment(remote_sig, [])
sig_64, _ = chan.sign_next_commitment() sig_64, _ = chan.sign_next_commitment()
@ -1029,7 +1029,7 @@ class Peer(Logger):
# if there are no changes, we will not (and must not) send a new commitment # if there are no changes, we will not (and must not) send a new commitment
next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE) next_htlcs, latest_htlcs = chan.hm.get_htlcs_in_next_ctx(REMOTE), chan.hm.get_htlcs_in_latest_ctx(REMOTE)
if (next_htlcs == latest_htlcs if (next_htlcs == latest_htlcs
and chan.get_next_feerate(REMOTE) == chan.get_current_feerate(REMOTE)) \ and chan.get_next_feerate(REMOTE) == chan.get_latest_feerate(REMOTE)) \
or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]: or ctn_to_sign == self.sent_commitment_for_ctn_last[chan]:
return return
self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}') self.logger.info(f'send_commitment. old number htlcs: {len(latest_htlcs)}, new number htlcs: {len(next_htlcs)}')
@ -1091,7 +1091,7 @@ class Peer(Logger):
chan = self.channels[channel_id] chan = self.channels[channel_id]
# make sure there were changes to the ctx, otherwise the remote peer is misbehaving # make sure there were changes to the ctx, otherwise the remote peer is misbehaving
if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL) if (chan.hm.get_htlcs_in_next_ctx(LOCAL) == chan.hm.get_htlcs_in_latest_ctx(LOCAL)
and chan.get_next_feerate(LOCAL) == chan.get_current_feerate(LOCAL)): and chan.get_next_feerate(LOCAL) == chan.get_latest_feerate(LOCAL)):
raise RemoteMisbehaving('received commitment_signed without pending changes') raise RemoteMisbehaving('received commitment_signed without pending changes')
# make sure ctn is new # make sure ctn is new
ctn_to_recv = chan.get_current_ctn(LOCAL) + 1 ctn_to_recv = chan.get_current_ctn(LOCAL) + 1

View file

@ -5,7 +5,7 @@
from enum import IntFlag, IntEnum from enum import IntFlag, IntEnum
import json import json
from collections import namedtuple from collections import namedtuple
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict
import re import re
from .util import bfh, bh2u, inv_dict from .util import bfh, bh2u, inv_dict
@ -78,7 +78,21 @@ class RemoteConfig(NamedTuple):
current_per_commitment_point: Optional[bytes] current_per_commitment_point: Optional[bytes]
FeeUpdate = namedtuple("FeeUpdate", ["rate", "ctn"]) class FeeUpdate(NamedTuple):
rate: int # in sat/kw
ctns: Dict['HTLCOwner', Optional[int]]
@classmethod
def from_dict(cls, d: dict) -> 'FeeUpdate':
return FeeUpdate(rate=d['rate'],
ctns={LOCAL: d['ctns'][str(int(LOCAL))],
REMOTE: d['ctns'][str(int(REMOTE))]})
def to_dict(self) -> dict:
return {'rate': self.rate,
'ctns': {int(LOCAL): self.ctns[LOCAL],
int(REMOTE): self.ctns[REMOTE]}}
ChannelConstraints = namedtuple("ChannelConstraints", ["capacity", "is_initiator", "funding_txn_minimum_depth"]) ChannelConstraints = namedtuple("ChannelConstraints", ["capacity", "is_initiator", "funding_txn_minimum_depth"])

View file

@ -37,7 +37,7 @@ from electrum.logging import console_stderr_handler
one_bitcoin_in_msat = bitcoin.COIN * 1000 one_bitcoin_in_msat = bitcoin.COIN * 1000
def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, r_csv): def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator, local_amount, remote_amount, privkeys, other_pubkeys, seed, cur, nex, other_node_id, l_dust, r_dust, l_csv, r_csv):
assert local_amount > 0 assert local_amount > 0
assert remote_amount > 0 assert remote_amount > 0
channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index) channel_id, _ = lnpeer.channel_id_from_funding_tx(funding_txid, funding_index)
@ -94,7 +94,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
is_initiator=is_initiator, is_initiator=is_initiator,
funding_txn_minimum_depth=3, funding_txn_minimum_depth=3,
), ),
"fee_updates": [FeeUpdate(rate=local_feerate, ctn={LOCAL:0, REMOTE:0})],
"node_id":other_node_id, "node_id":other_node_id,
"remote_commitment_to_be_revoked": None, "remote_commitment_to_be_revoked": None,
'onion_keys': {}, 'onion_keys': {},
@ -126,11 +125,16 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big")) alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big")) bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
alice, bob = \ alice, bob = (
lnchannel.Channel( lnchannel.Channel(
create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \ create_channel_state(funding_txid, funding_index, funding_sat, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4),
name="alice",
initial_feerate=feerate),
lnchannel.Channel( lnchannel.Channel(
create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob") create_channel_state(funding_txid, funding_index, funding_sat, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5),
name="bob",
initial_feerate=feerate)
)
alice.set_state('OPEN') alice.set_state('OPEN')
bob.set_state('OPEN') bob.set_state('OPEN')
@ -540,22 +544,25 @@ class TestChannel(unittest.TestCase):
bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs) bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs)
self.assertNotEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertNotEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, bob_channel.get_latest_feerate(LOCAL))
rev, _ = bob_channel.revoke_current_commitment() rev, _ = bob_channel.revoke_current_commitment()
self.assertEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
alice_channel.receive_revocation(rev) alice_channel.receive_revocation(rev)
bob_sig, bob_htlc_sigs = bob_channel.sign_next_commitment() bob_sig, bob_htlc_sigs = bob_channel.sign_next_commitment()
alice_channel.receive_new_commitment(bob_sig, bob_htlc_sigs) alice_channel.receive_new_commitment(bob_sig, bob_htlc_sigs)
self.assertNotEqual(fee, alice_channel.get_current_feerate(LOCAL)) self.assertNotEqual(fee, alice_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, alice_channel.get_latest_feerate(LOCAL))
rev, _ = alice_channel.revoke_current_commitment() rev, _ = alice_channel.revoke_current_commitment()
self.assertEqual(fee, alice_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, alice_channel.get_oldest_unrevoked_feerate(LOCAL))
bob_channel.receive_revocation(rev) bob_channel.receive_revocation(rev)
self.assertEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, bob_channel.get_latest_feerate(LOCAL))
def test_UpdateFeeReceiverCommits(self): def test_UpdateFeeReceiverCommits(self):
@ -571,20 +578,23 @@ class TestChannel(unittest.TestCase):
alice_sig, alice_htlc_sigs = alice_channel.sign_next_commitment() alice_sig, alice_htlc_sigs = alice_channel.sign_next_commitment()
bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs) bob_channel.receive_new_commitment(alice_sig, alice_htlc_sigs)
self.assertNotEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertNotEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, bob_channel.get_latest_feerate(LOCAL))
bob_revocation, _ = bob_channel.revoke_current_commitment() bob_revocation, _ = bob_channel.revoke_current_commitment()
self.assertEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
bob_sig, bob_htlc_sigs = bob_channel.sign_next_commitment() bob_sig, bob_htlc_sigs = bob_channel.sign_next_commitment()
alice_channel.receive_revocation(bob_revocation) alice_channel.receive_revocation(bob_revocation)
alice_channel.receive_new_commitment(bob_sig, bob_htlc_sigs) alice_channel.receive_new_commitment(bob_sig, bob_htlc_sigs)
self.assertNotEqual(fee, alice_channel.get_current_feerate(LOCAL)) self.assertNotEqual(fee, alice_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, alice_channel.get_latest_feerate(LOCAL))
alice_revocation, _ = alice_channel.revoke_current_commitment() alice_revocation, _ = alice_channel.revoke_current_commitment()
self.assertEqual(fee, alice_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, alice_channel.get_oldest_unrevoked_feerate(LOCAL))
bob_channel.receive_revocation(alice_revocation) bob_channel.receive_revocation(alice_revocation)
self.assertEqual(fee, bob_channel.get_current_feerate(LOCAL)) self.assertEqual(fee, bob_channel.get_oldest_unrevoked_feerate(LOCAL))
self.assertEqual(fee, bob_channel.get_latest_feerate(LOCAL))
@unittest.skip("broken probably because we havn't implemented detecting when we come out of a situation where we violate reserve") @unittest.skip("broken probably because we havn't implemented detecting when we come out of a situation where we violate reserve")
def test_AddHTLCNegativeBalance(self): def test_AddHTLCNegativeBalance(self):
@ -802,7 +812,7 @@ class TestDust(unittest.TestCase):
paymentPreimage = b"\x01" * 32 paymentPreimage = b"\x01" * 32
paymentHash = bitcoin.sha256(paymentPreimage) paymentHash = bitcoin.sha256(paymentPreimage)
fee_per_kw = alice_channel.get_current_feerate(LOCAL) fee_per_kw = alice_channel.get_next_feerate(LOCAL)
self.assertEqual(fee_per_kw, 6000) self.assertEqual(fee_per_kw, 6000)
htlcAmt = 500 + lnutil.HTLC_TIMEOUT_WEIGHT * (fee_per_kw // 1000) htlcAmt = 500 + lnutil.HTLC_TIMEOUT_WEIGHT * (fee_per_kw // 1000)
self.assertEqual(htlcAmt, 4478) self.assertEqual(htlcAmt, 4478)