move all ctn book-keeping to lnhtlc (from lnchannel)

This commit is contained in:
SomberNight 2019-08-05 16:49:57 +02:00 committed by ThomasV
parent 44761972cb
commit 107f271e58
6 changed files with 39 additions and 38 deletions

View file

@ -137,9 +137,7 @@ class Channel(Logger):
self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {})) self.data_loss_protect_remote_pcp = str_bytes_dict_from_save(state.get('data_loss_protect_remote_pcp', {}))
log = state.get('log') log = state.get('log')
self.hm = HTLCManager(local_ctn=self.config[LOCAL].ctn, self.hm = HTLCManager(log=log,
remote_ctn=self.config[REMOTE].ctn,
log=log,
initial_feerate=initial_feerate) initial_feerate=initial_feerate)
@ -174,8 +172,9 @@ class Channel(Logger):
return out return out
def open_with_first_pcp(self, remote_pcp, remote_sig): def open_with_first_pcp(self, remote_pcp, remote_sig):
self.config[REMOTE] = self.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_pcp, next_per_commitment_point=None) self.config[REMOTE] = self.config[REMOTE]._replace(current_per_commitment_point=remote_pcp,
self.config[LOCAL] = self.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) next_per_commitment_point=None)
self.config[LOCAL] = self.config[LOCAL]._replace(current_commitment_signature=remote_sig)
self.hm.channel_open_finished() self.hm.channel_open_finished()
self.set_state('OPENING') self.set_state('OPENING')
@ -380,15 +379,12 @@ class Channel(Logger):
def revoke_current_commitment(self): def revoke_current_commitment(self):
self.logger.info("revoke_current_commitment") self.logger.info("revoke_current_commitment")
new_ctn = self.config[LOCAL].ctn + 1 new_ctn = self.get_latest_ctn(LOCAL)
new_ctx = self.get_latest_commitment(LOCAL) new_ctx = self.get_latest_commitment(LOCAL)
if not self.signature_fits(new_ctx): if not self.signature_fits(new_ctx):
# this should never fail; as receive_new_commitment already did this test # this should never fail; as receive_new_commitment already did this test
raise Exception("refusing to revoke as remote sig does not fit") raise Exception("refusing to revoke as remote sig does not fit")
self.hm.send_rev() self.hm.send_rev()
self.config[LOCAL]=self.config[LOCAL]._replace(
ctn=new_ctn,
)
received = self.hm.received_in_ctn(new_ctn) received = self.hm.received_in_ctn(new_ctn)
sent = self.hm.sent_in_ctn(new_ctn) sent = self.hm.sent_in_ctn(new_ctn)
if self.lnworker: if self.lnworker:
@ -414,7 +410,6 @@ class Channel(Logger):
##### start applying fee/htlc changes ##### start applying fee/htlc changes
self.hm.recv_rev() self.hm.recv_rev()
self.config[REMOTE]=self.config[REMOTE]._replace( self.config[REMOTE]=self.config[REMOTE]._replace(
ctn=self.config[REMOTE].ctn + 1,
current_per_commitment_point=self.config[REMOTE].next_per_commitment_point, current_per_commitment_point=self.config[REMOTE].next_per_commitment_point,
next_per_commitment_point=revocation.next_per_commitment_point, next_per_commitment_point=revocation.next_per_commitment_point,
) )
@ -448,13 +443,13 @@ class Channel(Logger):
return initial return initial
def balance_minus_outgoing_htlcs(self, whose, *, ctx_owner=HTLCOwner.LOCAL): def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = HTLCOwner.LOCAL):
""" """
This balance in mSAT, which includes the value of This balance in mSAT, which includes the value of
pending outgoing HTLCs, is used in the UI. pending outgoing HTLCs, is used in the UI.
""" """
assert type(whose) is HTLCOwner assert type(whose) is HTLCOwner
ctn = self.hm.ctn[ctx_owner] + 1 ctn = self.get_next_ctn(ctx_owner)
return self.balance(whose, ctx_owner=ctx_owner, ctn=ctn)\ return self.balance(whose, ctx_owner=ctx_owner, ctn=ctn)\
- htlcsum(self.hm.htlcs_by_direction(ctx_owner, SENT, ctn)) - htlcsum(self.hm.htlcs_by_direction(ctx_owner, SENT, ctn))
@ -482,7 +477,7 @@ class Channel(Logger):
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
assert type(direction) is Direction assert type(direction) is Direction
if ctn is None: if ctn is None:
ctn = self.config[subject].ctn ctn = self.get_oldest_unrevoked_ctn(subject)
feerate = self.get_feerate(subject, ctn) feerate = self.get_feerate(subject, ctn)
conf = self.config[subject] conf = self.config[subject]
if (subject, direction) in [(REMOTE, RECEIVED), (LOCAL, SENT)]: if (subject, direction) in [(REMOTE, RECEIVED), (LOCAL, SENT)]:
@ -541,7 +536,7 @@ class Channel(Logger):
return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address) return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address)
def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int: def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int:
return self.config[subject].ctn return self.hm.ctn_oldest_unrevoked(subject)
def get_latest_ctn(self, subject: HTLCOwner) -> int: def get_latest_ctn(self, subject: HTLCOwner) -> int:
return self.hm.ctn_latest(subject) return self.hm.ctn_latest(subject)

View file

@ -7,9 +7,7 @@ from .util import bh2u, bfh
class HTLCManager: class HTLCManager:
def __init__(self, *, local_ctn=0, remote_ctn=0, log=None, initial_feerate=None): def __init__(self, *, log=None, initial_feerate=None):
# self.ctn[sub] is the ctn for the oldest unrevoked ctx of sub
self.ctn = {LOCAL:local_ctn, REMOTE: remote_ctn}
if log is None: if log is None:
initial = { initial = {
'adds': {}, 'adds': {},
@ -19,6 +17,7 @@ class HTLCManager:
'fee_updates': [], 'fee_updates': [],
'revack_pending': False, 'revack_pending': False,
'next_htlc_id': 0, 'next_htlc_id': 0,
'ctn': -1, # oldest unrevoked ctx of sub
} }
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)} log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else: else:
@ -47,7 +46,11 @@ class HTLCManager:
def ctn_latest(self, sub: HTLCOwner) -> int: def ctn_latest(self, sub: HTLCOwner) -> int:
"""Return the ctn for the latest (newest that has a valid sig) ctx of sub""" """Return the ctn for the latest (newest that has a valid sig) ctx of sub"""
return self.ctn[sub] + int(self.is_revack_pending(sub)) return self.ctn_oldest_unrevoked(sub) + int(self.is_revack_pending(sub))
def ctn_oldest_unrevoked(self, sub: HTLCOwner) -> int:
"""Return the ctn for the oldest unrevoked ctx of sub"""
return self.log[sub]['ctn']
def is_revack_pending(self, sub: HTLCOwner) -> bool: def is_revack_pending(self, sub: HTLCOwner) -> bool:
"""Returns True iff sub was sent commitment_signed but they did not """Returns True iff sub was sent commitment_signed but they did not
@ -77,7 +80,8 @@ class HTLCManager:
##### Actions on channel: ##### Actions on channel:
def channel_open_finished(self): def channel_open_finished(self):
self.ctn = {LOCAL: 0, REMOTE: 0} self.log[LOCAL]['ctn'] = 0
self.log[REMOTE]['ctn'] = 0
self._set_revack_pending(LOCAL, False) self._set_revack_pending(LOCAL, False)
self._set_revack_pending(REMOTE, False) self._set_revack_pending(REMOTE, False)
@ -132,15 +136,15 @@ class HTLCManager:
self.log[subject]['fee_updates'].append(fee_update) self.log[subject]['fee_updates'].append(fee_update)
def send_ctx(self) -> None: def send_ctx(self) -> None:
assert self.ctn_latest(REMOTE) == self.ctn[REMOTE], (self.ctn_latest(REMOTE), self.ctn[REMOTE]) assert self.ctn_latest(REMOTE) == self.ctn_oldest_unrevoked(REMOTE), (self.ctn_latest(REMOTE), self.ctn_oldest_unrevoked(REMOTE))
self._set_revack_pending(REMOTE, True) self._set_revack_pending(REMOTE, True)
def recv_ctx(self) -> None: def recv_ctx(self) -> None:
assert self.ctn_latest(LOCAL) == self.ctn[LOCAL], (self.ctn_latest(LOCAL), self.ctn[LOCAL]) assert self.ctn_latest(LOCAL) == self.ctn_oldest_unrevoked(LOCAL), (self.ctn_latest(LOCAL), self.ctn_oldest_unrevoked(LOCAL))
self._set_revack_pending(LOCAL, True) self._set_revack_pending(LOCAL, True)
def send_rev(self) -> None: def send_rev(self) -> None:
self.ctn[LOCAL] += 1 self.log[LOCAL]['ctn'] += 1
self._set_revack_pending(LOCAL, False) self._set_revack_pending(LOCAL, False)
# htlcs # htlcs
for ctns in self.log[REMOTE]['locked_in'].values(): for ctns in self.log[REMOTE]['locked_in'].values():
@ -156,7 +160,7 @@ class HTLCManager:
fee_update.ctns[REMOTE] = self.ctn_latest(REMOTE) + 1 fee_update.ctns[REMOTE] = self.ctn_latest(REMOTE) + 1
def recv_rev(self) -> None: def recv_rev(self) -> None:
self.ctn[REMOTE] += 1 self.log[REMOTE]['ctn'] += 1
self._set_revack_pending(REMOTE, False) self._set_revack_pending(REMOTE, False)
# htlcs # htlcs
for ctns in self.log[LOCAL]['locked_in'].values(): for ctns in self.log[LOCAL]['locked_in'].values():
@ -210,7 +214,7 @@ class HTLCManager:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
assert type(direction) is Direction assert type(direction) is Direction
if ctn is None: if ctn is None:
ctn = self.ctn[subject] ctn = self.ctn_oldest_unrevoked(subject)
l = [] l = []
# subject's ctx # subject's ctx
# party is the proposer of the HTLCs # party is the proposer of the HTLCs
@ -229,7 +233,7 @@ class HTLCManager:
"""Return the list of HTLCs in subject's ctx at ctn.""" """Return the list of HTLCs in subject's ctx at ctn."""
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
if ctn is None: if ctn is None:
ctn = self.ctn[subject] ctn = self.ctn_oldest_unrevoked(subject)
l = [] l = []
l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)] l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)]
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)] l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
@ -237,7 +241,7 @@ class HTLCManager:
def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
ctn = self.ctn[subject] ctn = self.ctn_oldest_unrevoked(subject)
return self.htlcs(subject, ctn) return self.htlcs(subject, ctn)
def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: def get_htlcs_in_latest_ctx(self, subject: HTLCOwner) -> Sequence[Tuple[Direction, UpdateAddHtlc]]:
@ -257,7 +261,7 @@ class HTLCManager:
""" """
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
if ctn is None: if ctn is None:
ctn = self.ctn[subject] ctn = self.ctn_oldest_unrevoked(subject)
# subject's ctx # subject's ctx
# party is the proposer of the HTLCs # party is the proposer of the HTLCs
party = subject if direction == SENT else subject.inverted() party = subject if direction == SENT else subject.inverted()
@ -274,7 +278,7 @@ class HTLCManager:
""" """
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
if ctn is None: if ctn is None:
ctn = self.ctn[subject] ctn = self.ctn_oldest_unrevoked(subject)
sent = [(SENT, x) for x in self.all_settled_htlcs_ever_by_direction(subject, SENT, ctn)] sent = [(SENT, x) for x in self.all_settled_htlcs_ever_by_direction(subject, SENT, ctn)]
received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)] received = [(RECEIVED, x) for x in self.all_settled_htlcs_ever_by_direction(subject, RECEIVED, ctn)]
return sent + received return sent + received
@ -318,7 +322,7 @@ class HTLCManager:
return fee_log[i].rate return fee_log[i].rate
def get_feerate_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> int: def get_feerate_in_oldest_unrevoked_ctx(self, subject: HTLCOwner) -> int:
return self.get_feerate(subject=subject, ctn=self.ctn[subject]) return self.get_feerate(subject=subject, ctn=self.ctn_oldest_unrevoked(subject))
def get_feerate_in_latest_ctx(self, subject: HTLCOwner) -> int: def get_feerate_in_latest_ctx(self, subject: HTLCOwner) -> int:
return self.get_feerate(subject=subject, ctn=self.ctn_latest(subject)) return self.get_feerate(subject=subject, ctn=self.ctn_latest(subject))

View file

@ -451,7 +451,6 @@ class Peer(Logger):
max_htlc_value_in_flight_msat=funding_sat * 1000, max_htlc_value_in_flight_msat=funding_sat * 1000,
max_accepted_htlcs=5, max_accepted_htlcs=5,
initial_msat=initial_msat, initial_msat=initial_msat,
ctn=-1,
reserve_sat=546, reserve_sat=546,
per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey, per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey,
funding_locked_received=False, funding_locked_received=False,
@ -533,7 +532,6 @@ class Peer(Logger):
max_htlc_value_in_flight_msat=remote_max, max_htlc_value_in_flight_msat=remote_max,
max_accepted_htlcs=max_accepted_htlcs, max_accepted_htlcs=max_accepted_htlcs,
initial_msat=push_msat, initial_msat=push_msat,
ctn = -1,
reserve_sat = remote_reserve_sat, reserve_sat = remote_reserve_sat,
htlc_minimum_msat = htlc_min, htlc_minimum_msat = htlc_min,
@ -633,7 +631,6 @@ class Peer(Logger):
max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], 'big'), # TODO validate max_htlc_value_in_flight_msat=int.from_bytes(payload['max_htlc_value_in_flight_msat'], 'big'), # TODO validate
max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), # TODO validate max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), # TODO validate
initial_msat=remote_balance_sat, initial_msat=remote_balance_sat,
ctn = -1,
reserve_sat = remote_reserve_sat, reserve_sat = remote_reserve_sat,
htlc_minimum_msat=int.from_bytes(payload['htlc_minimum_msat'], 'big'), # TODO validate htlc_minimum_msat=int.from_bytes(payload['htlc_minimum_msat'], 'big'), # TODO validate
next_per_commitment_point=payload['first_per_commitment_point'], next_per_commitment_point=payload['first_per_commitment_point'],

View file

@ -34,7 +34,6 @@ OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
# NamedTuples cannot subclass NamedTuples :'( https://github.com/python/typing/issues/427 # NamedTuples cannot subclass NamedTuples :'( https://github.com/python/typing/issues/427
class LocalConfig(NamedTuple): class LocalConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!) # shared channel config fields (DUPLICATED code!!)
ctn: int
payment_basepoint: 'Keypair' payment_basepoint: 'Keypair'
multisig_key: 'Keypair' multisig_key: 'Keypair'
htlc_basepoint: 'Keypair' htlc_basepoint: 'Keypair'
@ -56,7 +55,6 @@ class LocalConfig(NamedTuple):
class RemoteConfig(NamedTuple): class RemoteConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!) # shared channel config fields (DUPLICATED code!!)
ctn: int
payment_basepoint: 'Keypair' payment_basepoint: 'Keypair'
multisig_key: 'Keypair' multisig_key: 'Keypair'
htlc_basepoint: 'Keypair' htlc_basepoint: 'Keypair'

View file

@ -58,7 +58,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5, max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5,
max_accepted_htlcs=5, max_accepted_htlcs=5,
initial_msat=remote_amount, initial_msat=remote_amount,
ctn = -1,
reserve_sat=0, reserve_sat=0,
htlc_minimum_msat=1, htlc_minimum_msat=1,
@ -77,7 +76,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, is_initiator,
max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5, max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5,
max_accepted_htlcs=5, max_accepted_htlcs=5,
initial_msat=local_amount, initial_msat=local_amount,
ctn = 0,
reserve_sat=0, reserve_sat=0,
per_commitment_secret_seed=seed, per_commitment_secret_seed=seed,
@ -132,6 +130,9 @@ def create_test_channels(feerate=6000, local=None, remote=None):
initial_feerate=feerate) initial_feerate=feerate)
) )
alice.hm.log[LOCAL]['ctn'] = 0
bob.hm.log[LOCAL]['ctn'] = 0
alice.set_state('OPEN') alice.set_state('OPEN')
bob.set_state('OPEN') bob.set_state('OPEN')
@ -154,8 +155,6 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first) alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first)
bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first) bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first)
alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
alice.hm.channel_open_finished() alice.hm.channel_open_finished()
bob.hm.channel_open_finished() bob.hm.channel_open_finished()

View file

@ -12,6 +12,8 @@ class TestHTLCManager(unittest.TestCase):
def test_adding_htlcs_race(self): def test_adding_htlcs_race(self):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
A.channel_open_finished()
B.channel_open_finished()
ah0, bh0 = H('A', 0), H('B', 0) ah0, bh0 = H('A', 0), H('B', 0)
B.recv_htlc(A.send_htlc(ah0)) B.recv_htlc(A.send_htlc(ah0))
self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1) self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
@ -57,6 +59,8 @@ class TestHTLCManager(unittest.TestCase):
def htlc_lifecycle(htlc_success: bool): def htlc_lifecycle(htlc_success: bool):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
A.channel_open_finished()
B.channel_open_finished()
B.recv_htlc(A.send_htlc(H('A', 0))) B.recv_htlc(A.send_htlc(H('A', 0)))
self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0) self.assertEqual(len(B.get_htlcs_in_next_ctx(REMOTE)), 0)
self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1) self.assertEqual(len(A.get_htlcs_in_next_ctx(REMOTE)), 1)
@ -128,6 +132,8 @@ class TestHTLCManager(unittest.TestCase):
def htlc_lifecycle(htlc_success: bool): def htlc_lifecycle(htlc_success: bool):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
A.channel_open_finished()
B.channel_open_finished()
ah0 = H('A', 0) ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0)) B.recv_htlc(A.send_htlc(ah0))
A.send_ctx() A.send_ctx()
@ -163,6 +169,8 @@ class TestHTLCManager(unittest.TestCase):
def test_adding_htlc_between_send_ctx_and_recv_rev(self): def test_adding_htlc_between_send_ctx_and_recv_rev(self):
A = HTLCManager() A = HTLCManager()
B = HTLCManager() B = HTLCManager()
A.channel_open_finished()
B.channel_open_finished()
A.send_ctx() A.send_ctx()
B.recv_ctx() B.recv_ctx()
B.send_rev() B.send_rev()