lnhtlc: local update raw messages must not be deleted before acked

In recv_rev() previously all unacked_local_updates were deleted
as it was assumed that all of them have been acked at that point by
the revoke_and_ack itself. However this is not necessarily the case:
see new test case.

renamed log['unacked_local_updates'] to log['unacked_local_updates2']
to avoid breaking existing wallet files
This commit is contained in:
SomberNight 2019-08-12 18:37:13 +02:00 committed by ThomasV
parent 4fc9f243f7
commit a27b03be6d
4 changed files with 63 additions and 16 deletions

View file

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Optional, Sequence, Tuple, List
from typing import Optional, Sequence, Tuple, List, Dict
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u, bfh
@ -33,9 +33,10 @@ class HTLCManager:
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
# "side who initiated fee update" -> action -> list of FeeUpdates
log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
if 'unacked_local_updates' not in log:
log['unacked_local_updates'] = []
log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']]
if 'unacked_local_updates2' not in log:
log['unacked_local_updates2'] = {}
log['unacked_local_updates2'] = {int(ctn): [bfh(msg) for msg in messages]
for ctn, messages in log['unacked_local_updates2'].items()}
# maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None:
assert type(initial_feerate) is int
@ -74,7 +75,8 @@ class HTLCManager:
log[sub]['adds'] = d
# fee_updates
log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']]
log['unacked_local_updates2'] = {ctn: [bh2u(msg) for msg in messages]
for ctn, messages in log['unacked_local_updates2'].items()}
return log
##### Actions on channel:
@ -175,7 +177,7 @@ class HTLCManager:
if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE):
fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
# no need to keep local update raw msgs anymore, they have just been ACKed.
self.log['unacked_local_updates'].clear()
self.log['unacked_local_updates2'].pop(self.log[REMOTE]['ctn'], None)
def discard_unsigned_remote_updates(self):
"""Discard updates sent by the remote, that the remote itself
@ -200,11 +202,22 @@ class HTLCManager:
if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL):
del self.log[REMOTE]['fee_updates'][i]
def store_local_update_raw_msg(self, raw_update_msg: bytes):
self.log['unacked_local_updates'].append(raw_update_msg)
def store_local_update_raw_msg(self, raw_update_msg: bytes, *, is_commitment_signed: bool) -> None:
"""We need to be able to replay unacknowledged updates we sent to the remote
in case of disconnections. Hence, raw update and commitment_signed messages
are stored temporarily (until they are acked)."""
# self.log['unacked_local_updates2'][ctn_idx] is a list of raw messages
# containing some number of updates and then a single commitment_signed
if is_commitment_signed:
ctn_idx = self.ctn_latest(REMOTE)
else:
ctn_idx = self.ctn_latest(REMOTE) + 1
if ctn_idx not in self.log['unacked_local_updates2']:
self.log['unacked_local_updates2'][ctn_idx] = []
self.log['unacked_local_updates2'][ctn_idx].append(raw_update_msg)
def get_unacked_local_updates(self) -> Sequence[bytes]:
return self.log['unacked_local_updates']
def get_unacked_local_updates(self) -> Dict[int, Sequence[bytes]]:
return self.log['unacked_local_updates2']
##### Queries re HTLCs:

View file

@ -96,12 +96,13 @@ class Peer(Logger):
self.transport.send_bytes(raw_msg)
def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]):
if not (message_name.startswith("update_") or message_name == "commitment_signed"):
is_commitment_signed = message_name == "commitment_signed"
if not (message_name.startswith("update_") or is_commitment_signed):
return
assert channel_id
chan = self.lnworker.channels[channel_id] # type: Channel
chan.hm.store_local_update_raw_msg(raw_msg)
if message_name == "commitment_signed":
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
if is_commitment_signed:
# saving now, to ensure replaying updates works (in case of channel reestablishment)
self.lnworker.save_channel(chan)
@ -755,8 +756,9 @@ class Peer(Logger):
# Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns,
# e.g. for watchtowers, hence we must ensure these ctxs coincide.
# We replay the local updates even if they were not yet committed.
for raw_upd_msg in chan.hm.get_unacked_local_updates():
self.transport.send_bytes(raw_upd_msg)
for ctn, messages in chan.hm.get_unacked_local_updates():
for raw_upd_msg in messages:
self.transport.send_bytes(raw_upd_msg)
should_close_we_are_ahead = False
should_close_they_are_ahead = False
@ -831,6 +833,7 @@ class Peer(Logger):
self.lnworker.force_close_channel(chan_id)
return
# note: chan.short_channel_id being set implies the funding txn is already at sufficient depth
if their_next_local_ctn == next_local_ctn == 1 and chan.short_channel_id:
self.send_funding_locked(chan)
# checks done

View file

@ -88,7 +88,7 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
def send_bytes(self, msg):
def send_bytes(self, msg: bytes) -> None:
l = len(msg).to_bytes(2, 'big')
lc = aead_encrypt(self.sk, self.sn(), b'', l)
c = aead_encrypt(self.sk, self.sn(), b'', msg)

View file

@ -211,3 +211,34 @@ class TestHTLCManager(unittest.TestCase):
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_latest_ctx(REMOTE))
self.assertEqual([(Direction.SENT, ah0)], A.get_htlcs_in_next_ctx(LOCAL))
self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE))
def test_unacked_local_updates(self):
A = HTLCManager()
B = HTLCManager()
A.channel_open_finished()
B.channel_open_finished()
self.assertEqual({}, A.get_unacked_local_updates())
ah0 = H('A', 0)
B.recv_htlc(A.send_htlc(ah0))
A.store_local_update_raw_msg(b"upd_msg0", is_commitment_signed=False)
self.assertEqual({1: [b"upd_msg0"]}, A.get_unacked_local_updates())
ah1 = H('A', 1)
B.recv_htlc(A.send_htlc(ah1))
A.store_local_update_raw_msg(b"upd_msg1", is_commitment_signed=False)
self.assertEqual({1: [b"upd_msg0", b"upd_msg1"]}, A.get_unacked_local_updates())
A.send_ctx()
B.recv_ctx()
A.store_local_update_raw_msg(b"ctx1", is_commitment_signed=True)
self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"]}, A.get_unacked_local_updates())
ah2 = H('A', 2)
B.recv_htlc(A.send_htlc(ah2))
A.store_local_update_raw_msg(b"upd_msg2", is_commitment_signed=False)
self.assertEqual({1: [b"upd_msg0", b"upd_msg1", b"ctx1"], 2: [b"upd_msg2"]}, A.get_unacked_local_updates())
B.send_rev()
A.recv_rev()
self.assertEqual({2: [b"upd_msg2"]}, A.get_unacked_local_updates())