diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index bc495dd85..f656f9465 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Optional, Sequence, Tuple, List from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate -from .util import bh2u +from .util import bh2u, bfh class HTLCManager: @@ -34,6 +34,9 @@ class HTLCManager: log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} # "side who initiated fee update" -> action -> list of FeeUpdates log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] + if 'unacked_local_updates' not in log: + log['unacked_local_updates'] = [] + log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']] # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: assert type(initial_feerate) is int @@ -68,6 +71,7 @@ class HTLCManager: log[sub]['adds'] = d # fee_updates log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] + log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']] return log ##### Actions on channel: @@ -166,6 +170,8 @@ class HTLCManager: for fee_update in self.log[LOCAL]['fee_updates']: if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE): fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 + # no need to keep local update raw msgs anymore, they have just been ACKed. + self.log['unacked_local_updates'].clear() def discard_unsigned_remote_updates(self): """Discard updates sent by the remote, that the remote itself @@ -186,6 +192,12 @@ class HTLCManager: if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL): del self.log[REMOTE]['fee_updates'][i] + def store_local_update_raw_msg(self, raw_update_msg: bytes): + self.log['unacked_local_updates'].append(raw_update_msg) + + def get_unacked_local_updates(self) -> Sequence[bytes]: + return self.log['unacked_local_updates'] + ##### Queries re HTLCs: def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 4f70ef475..b22dfb00b 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -93,7 +93,19 @@ class Peer(Logger): def send_message(self, message_name: str, **kwargs): assert type(message_name) is str self.logger.debug(f"Sending {message_name.upper()}") - self.transport.send_bytes(encode_msg(message_name, **kwargs)) + raw_msg = encode_msg(message_name, **kwargs) + self._store_raw_msg_if_local_update(raw_msg, message_name=message_name, channel_id=kwargs.get("channel_id")) + self.transport.send_bytes(raw_msg) + + def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]): + if not (message_name.startswith("update_") or message_name == "commitment_signed"): + return + assert channel_id + chan = self.lnworker.channels[channel_id] # type: Channel + chan.hm.store_local_update_raw_msg(raw_msg) + if message_name == "commitment_signed": + # saving now, to ensure replaying updates works (in case of channel reestablishment) + self.lnworker.save_channel(chan) async def initialize(self): if isinstance(self.transport, LNTransport): @@ -733,11 +745,21 @@ class Peer(Logger): should_close_they_are_ahead = False # compare remote ctns if next_remote_ctn != their_next_local_ctn: - self.logger.warning(f"channel_reestablish: expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}") - if their_next_local_ctn < next_remote_ctn: - should_close_we_are_ahead = True + if their_next_local_ctn == latest_remote_ctn and chan.hm.is_revack_pending(REMOTE): + # Replay un-acked local updates (including commitment_signed) byte-for-byte. + # If we have sent them a commitment signature that they "lost" (due to disconnect), + # we need to make sure we replay the same local updates, as otherwise they could + # end up with two (or more) signed valid commitment transactions at the same ctn. + # Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns, + # e.g. for watchtowers, hence we must ensure these ctxs coincide. + for raw_upd_msg in chan.hm.get_unacked_local_updates(): + self.transport.send_bytes(raw_upd_msg) else: - should_close_they_are_ahead = True + self.logger.warning(f"channel_reestablish: expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}") + if their_next_local_ctn < next_remote_ctn: + should_close_we_are_ahead = True + else: + should_close_they_are_ahead = True # compare local ctns if oldest_unrevoked_local_ctn != their_oldest_unrevoked_remote_ctn: if oldest_unrevoked_local_ctn - 1 == their_oldest_unrevoked_remote_ctn: