lnpeer: reestablish_channel - replay un-acked local updates

Replay un-acked local updates (including commitment_signed) byte-for-byte.
If we have sent them a commitment signature that they "lost" (due to disconnect),
we need to make sure we replay the same local updates, as otherwise they could
end up with two (or more) signed valid commitment transactions at the same ctn.
Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns,
e.g. for watchtowers, hence we must ensure these ctxs coincide.
This commit is contained in:
SomberNight 2019-08-02 17:55:45 +02:00 committed by ThomasV
parent e81ae1921b
commit 014b921393
2 changed files with 40 additions and 6 deletions

View file

@ -2,7 +2,7 @@ from copy import deepcopy
from typing import Optional, Sequence, Tuple, List from typing import Optional, Sequence, Tuple, List
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate
from .util import bh2u from .util import bh2u, bfh
class HTLCManager: class HTLCManager:
@ -34,6 +34,9 @@ class HTLCManager:
log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()} log[sub]['fails'] = {int(htlc_id): coerceHtlcOwner2IntMap(ctns) for htlc_id, ctns in log[sub]['fails'].items()}
# "side who initiated fee update" -> action -> list of FeeUpdates # "side who initiated fee update" -> action -> list of FeeUpdates
log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] log[sub]['fee_updates'] = [FeeUpdate.from_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
if 'unacked_local_updates' not in log:
log['unacked_local_updates'] = []
log['unacked_local_updates'] = [bfh(upd) for upd in log['unacked_local_updates']]
# maybe bootstrap fee_updates if initial_feerate was provided # maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None: if initial_feerate is not None:
assert type(initial_feerate) is int assert type(initial_feerate) is int
@ -68,6 +71,7 @@ class HTLCManager:
log[sub]['adds'] = d log[sub]['adds'] = d
# fee_updates # fee_updates
log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']] log[sub]['fee_updates'] = [FeeUpdate.to_dict(fee_upd) for fee_upd in log[sub]['fee_updates']]
log['unacked_local_updates'] = [bh2u(upd) for upd in log['unacked_local_updates']]
return log return log
##### Actions on channel: ##### Actions on channel:
@ -166,6 +170,8 @@ class HTLCManager:
for fee_update in self.log[LOCAL]['fee_updates']: for fee_update in self.log[LOCAL]['fee_updates']:
if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE): if fee_update.ctns[LOCAL] is None and fee_update.ctns[REMOTE] <= self.ctn_latest(REMOTE):
fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1 fee_update.ctns[LOCAL] = self.ctn_latest(LOCAL) + 1
# no need to keep local update raw msgs anymore, they have just been ACKed.
self.log['unacked_local_updates'].clear()
def discard_unsigned_remote_updates(self): def discard_unsigned_remote_updates(self):
"""Discard updates sent by the remote, that the remote itself """Discard updates sent by the remote, that the remote itself
@ -186,6 +192,12 @@ class HTLCManager:
if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL): if fee_update.ctns[LOCAL] > self.ctn_latest(LOCAL):
del self.log[REMOTE]['fee_updates'][i] del self.log[REMOTE]['fee_updates'][i]
def store_local_update_raw_msg(self, raw_update_msg: bytes):
self.log['unacked_local_updates'].append(raw_update_msg)
def get_unacked_local_updates(self) -> Sequence[bytes]:
return self.log['unacked_local_updates']
##### Queries re HTLCs: ##### Queries re HTLCs:
def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction, def htlcs_by_direction(self, subject: HTLCOwner, direction: Direction,

View file

@ -93,7 +93,19 @@ class Peer(Logger):
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert type(message_name) is str assert type(message_name) is str
self.logger.debug(f"Sending {message_name.upper()}") self.logger.debug(f"Sending {message_name.upper()}")
self.transport.send_bytes(encode_msg(message_name, **kwargs)) raw_msg = encode_msg(message_name, **kwargs)
self._store_raw_msg_if_local_update(raw_msg, message_name=message_name, channel_id=kwargs.get("channel_id"))
self.transport.send_bytes(raw_msg)
def _store_raw_msg_if_local_update(self, raw_msg: bytes, *, message_name: str, channel_id: Optional[bytes]):
if not (message_name.startswith("update_") or message_name == "commitment_signed"):
return
assert channel_id
chan = self.lnworker.channels[channel_id] # type: Channel
chan.hm.store_local_update_raw_msg(raw_msg)
if message_name == "commitment_signed":
# saving now, to ensure replaying updates works (in case of channel reestablishment)
self.lnworker.save_channel(chan)
async def initialize(self): async def initialize(self):
if isinstance(self.transport, LNTransport): if isinstance(self.transport, LNTransport):
@ -733,11 +745,21 @@ class Peer(Logger):
should_close_they_are_ahead = False should_close_they_are_ahead = False
# compare remote ctns # compare remote ctns
if next_remote_ctn != their_next_local_ctn: if next_remote_ctn != their_next_local_ctn:
self.logger.warning(f"channel_reestablish: expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}") if their_next_local_ctn == latest_remote_ctn and chan.hm.is_revack_pending(REMOTE):
if their_next_local_ctn < next_remote_ctn: # Replay un-acked local updates (including commitment_signed) byte-for-byte.
should_close_we_are_ahead = True # If we have sent them a commitment signature that they "lost" (due to disconnect),
# we need to make sure we replay the same local updates, as otherwise they could
# end up with two (or more) signed valid commitment transactions at the same ctn.
# Multiple valid ctxs at the same ctn is a major headache for pre-signing spending txns,
# e.g. for watchtowers, hence we must ensure these ctxs coincide.
for raw_upd_msg in chan.hm.get_unacked_local_updates():
self.transport.send_bytes(raw_upd_msg)
else: else:
should_close_they_are_ahead = True self.logger.warning(f"channel_reestablish: expected remote ctn {next_remote_ctn}, got {their_next_local_ctn}")
if their_next_local_ctn < next_remote_ctn:
should_close_we_are_ahead = True
else:
should_close_they_are_ahead = True
# compare local ctns # compare local ctns
if oldest_unrevoked_local_ctn != their_oldest_unrevoked_remote_ctn: if oldest_unrevoked_local_ctn != their_oldest_unrevoked_remote_ctn:
if oldest_unrevoked_local_ctn - 1 == their_oldest_unrevoked_remote_ctn: if oldest_unrevoked_local_ctn - 1 == their_oldest_unrevoked_remote_ctn: