lnchannel: handle htlc-address collisions

We were previously generating an incorrect commitment_signed msg if there were
multiple htlcs sharing the same scriptPubKey.
This commit is contained in:
SomberNight 2019-09-07 07:37:13 +02:00
parent 00f15d491b
commit 83fcdbd561
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
4 changed files with 219 additions and 157 deletions

View file

@ -46,7 +46,7 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc, HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc,
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script, ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
ShortChannelID) ShortChannelID, map_htlcs_to_ctx_output_idxs)
from .lnutil import FeeUpdate from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc from .lnsweep import create_sweeptx_for_their_revoked_htlc
@ -291,24 +291,23 @@ class Channel(Logger):
self.config[REMOTE].next_per_commitment_point) self.config[REMOTE].next_per_commitment_point)
their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big') their_remote_htlc_privkey = their_remote_htlc_privkey_number.to_bytes(32, 'big')
for_us = False
htlcsigs = [] htlcsigs = []
# they sent => we receive htlc_to_ctx_output_idx_map = map_htlcs_to_ctx_output_idxs(chan=self,
for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, SENT, ctn=next_remote_ctn), ctx=pending_remote_commitment,
self.included_htlcs(REMOTE, RECEIVED, ctn=next_remote_ctn)]):
for htlc in htlcs:
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=self.config[REMOTE].next_per_commitment_point, pcp=self.config[REMOTE].next_per_commitment_point,
for_us=for_us, subject=REMOTE,
we_receive=we_receive, ctn=next_remote_ctn)
commit=pending_remote_commitment, for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
htlc=htlc) _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey)) pcp=self.config[REMOTE].next_per_commitment_point,
htlc_sig = ecc.sig_string_from_der_sig(sig[:-1]) subject=REMOTE,
htlc_output_idx = htlc_tx.inputs()[0]['prevout_n'] htlc_direction=direction,
htlcsigs.append((htlc_output_idx, htlc_sig)) commit=pending_remote_commitment,
ctx_output_idx=ctx_output_idx,
htlc=htlc)
sig = bfh(htlc_tx.sign_txin(0, their_remote_htlc_privkey))
htlc_sig = ecc.sig_string_from_der_sig(sig[:-1])
htlcsigs.append((ctx_output_idx, htlc_sig))
htlcsigs.sort() htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs] htlcsigs = [x[1] for x in htlcsigs]
@ -329,8 +328,9 @@ class Channel(Logger):
This docstring is from LND. This docstring is from LND.
""" """
# TODO in many failure cases below, we should "fail" the channel (force-close)
next_local_ctn = self.get_next_ctn(LOCAL) next_local_ctn = self.get_next_ctn(LOCAL)
self.logger.info("receive_new_commitment") self.logger.info(f"receive_new_commitment. ctn={next_local_ctn}, len(htlc_sigs)={len(htlc_sigs)}")
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@ -342,45 +342,48 @@ class Channel(Logger):
htlc_sigs_string = b''.join(htlc_sigs) htlc_sigs_string = b''.join(htlc_sigs)
htlc_sigs = htlc_sigs[:] # copy cause we will delete now _secret, pcp = self.get_secret_and_point(subject=LOCAL, ctn=next_local_ctn)
for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=next_local_ctn), False),
(self.included_htlcs(LOCAL, RECEIVED, ctn=next_local_ctn), True)]: htlc_to_ctx_output_idx_map = map_htlcs_to_ctx_output_idxs(chan=self,
# FIXME this is quadratic. BOLT-02: "corresponding to the ordering of the commitment transaction" ctx=pending_local_commitment,
for htlc in htlcs: pcp=pcp,
idx = self.verify_htlc(htlc, htlc_sigs, we_receive, pending_local_commitment) subject=LOCAL,
del htlc_sigs[idx] ctn=next_local_ctn)
if len(htlc_sigs) != 0: # all sigs should have been popped above if len(htlc_to_ctx_output_idx_map) != len(htlc_sigs):
raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures') raise Exception(f'htlc sigs failure. recv {len(htlc_sigs)} sigs, expected {len(htlc_to_ctx_output_idx_map)}')
for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
htlc_sig = htlc_sigs[htlc_relative_idx]
self.verify_htlc(htlc=htlc,
htlc_sig=htlc_sig,
htlc_direction=direction,
pcp=pcp,
ctx=pending_local_commitment,
ctx_output_idx=ctx_output_idx)
self.hm.recv_ctx() self.hm.recv_ctx()
self.config[LOCAL]=self.config[LOCAL]._replace( self.config[LOCAL]=self.config[LOCAL]._replace(
current_commitment_signature=sig, current_commitment_signature=sig,
current_htlc_signatures=htlc_sigs_string) current_htlc_signatures=htlc_sigs_string)
def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int: def verify_htlc(self, *, htlc: UpdateAddHtlc, htlc_sig: bytes, htlc_direction: Direction,
ctn = extract_ctn_from_tx_and_chan(ctx, self) pcp: bytes, ctx: Transaction, ctx_output_idx: int) -> None:
secret = get_per_commitment_secret_from_seed(self.config[LOCAL].per_commitment_secret_seed, RevocationStore.START_INDEX - ctn)
point = secret_to_pubkey(int.from_bytes(secret, 'big'))
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self, _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=point, pcp=pcp,
for_us=True, subject=LOCAL,
we_receive=we_receive, htlc_direction=htlc_direction,
commit=ctx, commit=ctx,
ctx_output_idx=ctx_output_idx,
htlc=htlc) htlc=htlc)
pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0))) pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0)))
remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, point) remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, pcp)
for idx, sig in enumerate(htlc_sigs): if not ecc.verify_signature(remote_htlc_pubkey, htlc_sig, pre_hash):
if ecc.verify_signature(remote_htlc_pubkey, sig, pre_hash): raise Exception(f'failed verifying HTLC signatures: {htlc} {htlc_direction}')
return idx
else:
raise Exception(f'failed verifying HTLC signatures: {htlc}, sigs: {len(htlc_sigs)}, we_receive: {we_receive}')
def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool, ctx) -> bytes: def get_remote_htlc_sig_for_htlc(self, *, htlc_relative_idx: int) -> bytes:
data = self.config[LOCAL].current_htlc_signatures data = self.config[LOCAL].current_htlc_signatures
htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)] htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)]
idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive, ctx=ctx) htlc_sig = htlc_sigs[htlc_relative_idx]
remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sigs[idx]) + b'\x01' remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sig) + b'\x01'
return remote_htlc_sig return remote_htlc_sig
def revoke_current_commitment(self): def revoke_current_commitment(self):
@ -491,10 +494,10 @@ class Channel(Logger):
else: else:
weight = HTLC_TIMEOUT_WEIGHT weight = HTLC_TIMEOUT_WEIGHT
htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn).values() htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn).values()
fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000) htlc_value_after_fees = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000)
return list(filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs)) return list(filter(lambda htlc: htlc_value_after_fees(htlc) >= conf.dust_limit_sat, htlcs))
def get_secret_and_point(self, subject, ctn) -> Tuple[Optional[bytes], bytes]: def get_secret_and_point(self, subject: HTLCOwner, ctn: int) -> Tuple[Optional[bytes], bytes]:
assert type(subject) is HTLCOwner assert type(subject) is HTLCOwner
assert ctn >= 0, ctn assert ctn >= 0, ctn
offset = ctn - self.get_oldest_unrevoked_ctn(subject) offset = ctn - self.get_oldest_unrevoked_ctn(subject)
@ -537,7 +540,7 @@ class Channel(Logger):
ctn = self.get_oldest_unrevoked_ctn(subject) ctn = self.get_oldest_unrevoked_ctn(subject)
return self.get_commitment(subject, ctn) return self.get_commitment(subject, ctn)
def create_sweeptxs(self, ctn): def create_sweeptxs(self, ctn: int) -> List[Transaction]:
from .lnsweep import create_sweeptxs_for_watchtower from .lnsweep import create_sweeptxs_for_watchtower
secret, ctx = self.get_secret_and_commitment(REMOTE, ctn) secret, ctx = self.get_secret_and_commitment(REMOTE, ctn)
return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address) return create_sweeptxs_for_watchtower(self, ctx, secret, self.sweep_address)
@ -754,9 +757,8 @@ class Channel(Logger):
def sweep_ctx(self, ctx: Transaction): def sweep_ctx(self, ctx: Transaction):
txid = ctx.txid() txid = ctx.txid()
if self.sweep_info.get(txid) is None: if self.sweep_info.get(txid) is None:
ctn = extract_ctn_from_tx_and_chan(ctx, self) our_sweep_info = create_sweeptxs_for_our_ctx(chan=self, ctx=ctx, sweep_address=self.sweep_address)
our_sweep_info = create_sweeptxs_for_our_ctx(self, ctx, ctn, self.sweep_address) their_sweep_info = create_sweeptxs_for_their_ctx(chan=self, ctx=ctx, sweep_address=self.sweep_address)
their_sweep_info = create_sweeptxs_for_their_ctx(self, ctx, ctn, self.sweep_address)
if our_sweep_info is not None: if our_sweep_info is not None:
self.sweep_info[txid] = our_sweep_info self.sweep_info[txid] = our_sweep_info
self.logger.info(f'we force closed.') self.logger.info(f'we force closed.')

View file

@ -2,7 +2,7 @@
# Distributed under the MIT software license, see the accompanying # Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php # file LICENCE or http://www.opensource.org/licenses/mit-license.php
from typing import Optional, Dict, List, Tuple, TYPE_CHECKING, NamedTuple from typing import Optional, Dict, List, Tuple, TYPE_CHECKING, NamedTuple, Callable
from enum import Enum, auto from enum import Enum, auto
from .util import bfh, bh2u from .util import bfh, bh2u
@ -13,7 +13,8 @@ from .lnutil import (make_commitment_output_to_remote_address, make_commitment_o
make_htlc_tx_witness, make_htlc_tx_with_open_channel, UpdateAddHtlc, make_htlc_tx_witness, make_htlc_tx_with_open_channel, UpdateAddHtlc,
LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash, LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash,
get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed, get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed,
RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret, SENT, RECEIVED) RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret, SENT, RECEIVED,
map_htlcs_to_ctx_output_idxs, Direction)
from .transaction import Transaction, TxOutput, construct_witness from .transaction import Transaction, TxOutput, construct_witness
from .simple_config import estimate_fee from .simple_config import estimate_fee
from .logging import get_logger from .logging import get_logger
@ -28,7 +29,7 @@ _logger = get_logger(__name__)
def create_sweeptxs_for_watchtower(chan: 'Channel', ctx: Transaction, per_commitment_secret: bytes, def create_sweeptxs_for_watchtower(chan: 'Channel', ctx: Transaction, per_commitment_secret: bytes,
sweep_address: str) -> Dict[str,Transaction]: sweep_address: str) -> List[Transaction]:
"""Presign sweeping transactions using the just received revoked pcs. """Presign sweeping transactions using the just received revoked pcs.
These will only be utilised if the remote breaches. These will only be utilised if the remote breaches.
Sweep 'to_local', and all the HTLCs (two cases: directly from ctx, or from HTLC tx). Sweep 'to_local', and all the HTLCs (two cases: directly from ctx, or from HTLC tx).
@ -46,8 +47,9 @@ def create_sweeptxs_for_watchtower(chan: 'Channel', ctx: Transaction, per_commit
witness_script = bh2u(make_commitment_output_to_local_witness_script( witness_script = bh2u(make_commitment_output_to_local_witness_script(
revocation_pubkey, to_self_delay, this_delayed_pubkey)) revocation_pubkey, to_self_delay, this_delayed_pubkey))
to_local_address = redeem_script_to_address('p2wsh', witness_script) to_local_address = redeem_script_to_address('p2wsh', witness_script)
output_idx = ctx.get_output_idx_from_address(to_local_address) output_idxs = ctx.get_output_idxs_from_address(to_local_address)
if output_idx is not None: if output_idxs:
output_idx = output_idxs.pop()
sweep_tx = create_sweeptx_ctx_to_local( sweep_tx = create_sweeptx_ctx_to_local(
sweep_address=sweep_address, sweep_address=sweep_address,
ctx=ctx, ctx=ctx,
@ -58,15 +60,15 @@ def create_sweeptxs_for_watchtower(chan: 'Channel', ctx: Transaction, per_commit
if sweep_tx: if sweep_tx:
txs.append(sweep_tx) txs.append(sweep_tx)
# HTLCs # HTLCs
def create_sweeptx_for_htlc(htlc: 'UpdateAddHtlc', is_received_htlc: bool) -> Tuple[Optional[Transaction], def create_sweeptx_for_htlc(*, htlc: 'UpdateAddHtlc', htlc_direction: Direction,
Optional[Transaction], ctx_output_idx: int) -> Optional[Transaction]:
Transaction]:
htlc_tx_witness_script, htlc_tx = make_htlc_tx_with_open_channel(chan=chan, htlc_tx_witness_script, htlc_tx = make_htlc_tx_with_open_channel(chan=chan,
pcp=pcp, pcp=pcp,
for_us=False, subject=REMOTE,
we_receive=not is_received_htlc, htlc_direction=htlc_direction,
commit=ctx, commit=ctx,
htlc=htlc) htlc=htlc,
ctx_output_idx=ctx_output_idx)
return create_sweeptx_that_spends_htlctx_that_spends_htlc_in_ctx( return create_sweeptx_that_spends_htlctx_that_spends_htlc_in_ctx(
'sweep_from_their_ctx_htlc_', 'sweep_from_their_ctx_htlc_',
to_self_delay=0, to_self_delay=0,
@ -77,23 +79,22 @@ def create_sweeptxs_for_watchtower(chan: 'Channel', ctx: Transaction, per_commit
is_revocation=True) is_revocation=True)
ctn = extract_ctn_from_tx_and_chan(ctx, chan) ctn = extract_ctn_from_tx_and_chan(ctx, chan)
# received HTLCs, in their ctx htlc_to_ctx_output_idx_map = map_htlcs_to_ctx_output_idxs(chan=chan,
received_htlcs = chan.included_htlcs(REMOTE, RECEIVED, ctn) ctx=ctx,
for htlc in received_htlcs: pcp=pcp,
secondstage_sweep_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=True) subject=REMOTE,
if secondstage_sweep_tx: ctn=ctn)
txs.append(secondstage_sweep_tx) for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
# offered HTLCs, in their ctx secondstage_sweep_tx = create_sweeptx_for_htlc(htlc=htlc,
offered_htlcs = chan.included_htlcs(REMOTE, SENT, ctn) htlc_direction=direction,
for htlc in offered_htlcs: ctx_output_idx=ctx_output_idx)
secondstage_sweep_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=False)
if secondstage_sweep_tx: if secondstage_sweep_tx:
txs.append(secondstage_sweep_tx) txs.append(secondstage_sweep_tx)
return txs return txs
def create_sweeptx_for_their_revoked_ctx(chan: 'Channel', ctx: Transaction, per_commitment_secret: bytes, def create_sweeptx_for_their_revoked_ctx(chan: 'Channel', ctx: Transaction, per_commitment_secret: bytes,
sweep_address: str) -> Dict[str,Transaction]: sweep_address: str) -> Optional[Callable[[], Optional[Transaction]]]:
# prep # prep
pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True) pcp = ecc.ECPrivkey(per_commitment_secret).get_public_key_bytes(compressed=True)
this_conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=False) this_conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=False)
@ -107,8 +108,9 @@ def create_sweeptx_for_their_revoked_ctx(chan: 'Channel', ctx: Transaction, per_
witness_script = bh2u(make_commitment_output_to_local_witness_script( witness_script = bh2u(make_commitment_output_to_local_witness_script(
revocation_pubkey, to_self_delay, this_delayed_pubkey)) revocation_pubkey, to_self_delay, this_delayed_pubkey))
to_local_address = redeem_script_to_address('p2wsh', witness_script) to_local_address = redeem_script_to_address('p2wsh', witness_script)
output_idx = ctx.get_output_idx_from_address(to_local_address) output_idxs = ctx.get_output_idxs_from_address(to_local_address)
if output_idx is not None: if output_idxs:
output_idx = output_idxs.pop()
sweep_tx = lambda: create_sweeptx_ctx_to_local( sweep_tx = lambda: create_sweeptx_ctx_to_local(
sweep_address=sweep_address, sweep_address=sweep_address,
ctx=ctx, ctx=ctx,
@ -117,9 +119,10 @@ def create_sweeptx_for_their_revoked_ctx(chan: 'Channel', ctx: Transaction, per_
privkey=other_revocation_privkey, privkey=other_revocation_privkey,
is_revocation=True) is_revocation=True)
return sweep_tx return sweep_tx
return None
def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htlc_tx: Transaction, def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htlc_tx: Transaction,
sweep_address: str) -> Dict[str,Transaction]: sweep_address: str) -> Optional[Tuple[str, int, int, Callable]]:
x = analyze_ctx(chan, ctx) x = analyze_ctx(chan, ctx)
if not x: if not x:
return return
@ -154,8 +157,8 @@ def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htl
def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int, def create_sweeptxs_for_our_ctx(*, chan: 'Channel', ctx: Transaction,
sweep_address: str) -> Dict[str,Transaction]: sweep_address: str) -> Optional[Dict[str, Tuple]]:
"""Handle the case where we force close unilaterally with our latest ctx. """Handle the case where we force close unilaterally with our latest ctx.
Construct sweep txns for 'to_local', and for all HTLCs (2 txns each). Construct sweep txns for 'to_local', and for all HTLCs (2 txns each).
'to_local' can be swept even if this is a breach (by us), 'to_local' can be swept even if this is a breach (by us),
@ -181,8 +184,8 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
to_remote_address = make_commitment_output_to_remote_address(their_payment_pubkey) to_remote_address = make_commitment_output_to_remote_address(their_payment_pubkey)
# test ctx # test ctx
_logger.debug(f'testing our ctx: {to_local_address} {to_remote_address}') _logger.debug(f'testing our ctx: {to_local_address} {to_remote_address}')
if ctx.get_output_idx_from_address(to_local_address) is None\ if not ctx.get_output_idxs_from_address(to_local_address) \
and ctx.get_output_idx_from_address(to_remote_address) is None: and not ctx.get_output_idxs_from_address(to_remote_address):
return return
# we have to_local, to_remote. # we have to_local, to_remote.
# other outputs are htlcs # other outputs are htlcs
@ -193,8 +196,9 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
return {} return {}
txs = {} txs = {}
# to_local # to_local
output_idx = ctx.get_output_idx_from_address(to_local_address) output_idxs = ctx.get_output_idxs_from_address(to_local_address)
if output_idx is not None: if output_idxs:
output_idx = output_idxs.pop()
sweep_tx = lambda: create_sweeptx_ctx_to_local( sweep_tx = lambda: create_sweeptx_ctx_to_local(
sweep_address=sweep_address, sweep_address=sweep_address,
ctx=ctx, ctx=ctx,
@ -206,13 +210,14 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
prevout = ctx.txid() + ':%d'%output_idx prevout = ctx.txid() + ':%d'%output_idx
txs[prevout] = ('our_ctx_to_local', to_self_delay, 0, sweep_tx) txs[prevout] = ('our_ctx_to_local', to_self_delay, 0, sweep_tx)
# HTLCs # HTLCs
def create_txns_for_htlc(htlc: 'UpdateAddHtlc', is_received_htlc: bool) -> Tuple[Optional[Transaction], Optional[Transaction]]: def create_txns_for_htlc(*, htlc: 'UpdateAddHtlc', htlc_direction: Direction,
if is_received_htlc: ctx_output_idx: int, htlc_relative_idx: int):
if htlc_direction == RECEIVED:
try: try:
preimage = chan.lnworker.get_preimage(htlc.payment_hash) preimage = chan.lnworker.get_preimage(htlc.payment_hash)
except UnknownPaymentHash as e: except UnknownPaymentHash as e:
_logger.info(f'trying to sweep htlc from our latest ctx but getting {repr(e)}') _logger.info(f'trying to sweep htlc from our latest ctx but getting {repr(e)}')
return None, None return
else: else:
preimage = None preimage = None
htlctx_witness_script, htlc_tx = create_htlctx_that_spends_from_our_ctx( htlctx_witness_script, htlc_tx = create_htlctx_that_spends_from_our_ctx(
@ -222,7 +227,9 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
htlc=htlc, htlc=htlc,
local_htlc_privkey=our_htlc_privkey, local_htlc_privkey=our_htlc_privkey,
preimage=preimage, preimage=preimage,
is_received_htlc=is_received_htlc) htlc_direction=htlc_direction,
ctx_output_idx=ctx_output_idx,
htlc_relative_idx=htlc_relative_idx)
sweep_tx = lambda: create_sweeptx_that_spends_htlctx_that_spends_htlc_in_ctx( sweep_tx = lambda: create_sweeptx_that_spends_htlctx_that_spends_htlc_in_ctx(
'our_ctx_htlc_', 'our_ctx_htlc_',
to_self_delay=to_self_delay, to_self_delay=to_self_delay,
@ -237,12 +244,16 @@ def create_sweeptxs_for_our_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
# offered HTLCs, in our ctx --> "timeout" # offered HTLCs, in our ctx --> "timeout"
# received HTLCs, in our ctx --> "success" # received HTLCs, in our ctx --> "success"
offered_htlcs = chan.included_htlcs(LOCAL, SENT, ctn) # type: List[UpdateAddHtlc] htlc_to_ctx_output_idx_map = map_htlcs_to_ctx_output_idxs(chan=chan,
received_htlcs = chan.included_htlcs(LOCAL, RECEIVED, ctn) # type: List[UpdateAddHtlc] ctx=ctx,
for htlc in offered_htlcs: pcp=our_pcp,
create_txns_for_htlc(htlc, is_received_htlc=False) subject=LOCAL,
for htlc in received_htlcs: ctn=ctn)
create_txns_for_htlc(htlc, is_received_htlc=True) for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
create_txns_for_htlc(htlc=htlc,
htlc_direction=direction,
ctx_output_idx=ctx_output_idx,
htlc_relative_idx=htlc_relative_idx)
return txs return txs
def analyze_ctx(chan: 'Channel', ctx: Transaction): def analyze_ctx(chan: 'Channel', ctx: Transaction):
@ -273,8 +284,8 @@ def analyze_ctx(chan: 'Channel', ctx: Transaction):
return return
return ctn, their_pcp, is_revocation, per_commitment_secret return ctn, their_pcp, is_revocation, per_commitment_secret
def create_sweeptxs_for_their_ctx(chan: 'Channel', ctx: Transaction, ctn: int, def create_sweeptxs_for_their_ctx(*, chan: 'Channel', ctx: Transaction,
sweep_address: str) -> Dict[str,Transaction]: sweep_address: str) -> Optional[Dict[str,Tuple]]:
"""Handle the case when the remote force-closes with their ctx. """Handle the case when the remote force-closes with their ctx.
Sweep outputs that do not have a CSV delay ('to_remote' and first-stage HTLCs). Sweep outputs that do not have a CSV delay ('to_remote' and first-stage HTLCs).
Outputs with CSV delay ('to_local' and second-stage HTLCs) are redeemed by LNWatcher. Outputs with CSV delay ('to_local' and second-stage HTLCs) are redeemed by LNWatcher.
@ -295,8 +306,8 @@ def create_sweeptxs_for_their_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
to_remote_address = make_commitment_output_to_remote_address(our_payment_pubkey) to_remote_address = make_commitment_output_to_remote_address(our_payment_pubkey)
# test if this is their ctx # test if this is their ctx
_logger.debug(f'testing their ctx: {to_local_address} {to_remote_address}') _logger.debug(f'testing their ctx: {to_local_address} {to_remote_address}')
if ctx.get_output_idx_from_address(to_local_address) is None \ if not ctx.get_output_idxs_from_address(to_local_address) \
and ctx.get_output_idx_from_address(to_remote_address) is None: and not ctx.get_output_idxs_from_address(to_remote_address):
return return
if is_revocation: if is_revocation:
@ -315,8 +326,9 @@ def create_sweeptxs_for_their_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
assert our_payment_pubkey == our_payment_privkey.get_public_key_bytes(compressed=True) assert our_payment_pubkey == our_payment_privkey.get_public_key_bytes(compressed=True)
# to_local is handled by lnwatcher # to_local is handled by lnwatcher
# to_remote # to_remote
output_idx = ctx.get_output_idx_from_address(to_remote_address) output_idxs = ctx.get_output_idxs_from_address(to_remote_address)
if output_idx is not None: if output_idxs:
output_idx = output_idxs.pop()
prevout = ctx.txid() + ':%d'%output_idx prevout = ctx.txid() + ':%d'%output_idx
sweep_tx = lambda: create_sweeptx_their_ctx_to_remote( sweep_tx = lambda: create_sweeptx_their_ctx_to_remote(
sweep_address=sweep_address, sweep_address=sweep_address,
@ -325,13 +337,14 @@ def create_sweeptxs_for_their_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
our_payment_privkey=our_payment_privkey) our_payment_privkey=our_payment_privkey)
txs[prevout] = ('their_ctx_to_remote', 0, 0, sweep_tx) txs[prevout] = ('their_ctx_to_remote', 0, 0, sweep_tx)
# HTLCs # HTLCs
def create_sweeptx_for_htlc(htlc: 'UpdateAddHtlc', is_received_htlc: bool) -> Optional[Transaction]: def create_sweeptx_for_htlc(htlc: 'UpdateAddHtlc', is_received_htlc: bool,
ctx_output_idx: int) -> None:
if not is_received_htlc and not is_revocation: if not is_received_htlc and not is_revocation:
try: try:
preimage = chan.lnworker.get_preimage(htlc.payment_hash) preimage = chan.lnworker.get_preimage(htlc.payment_hash)
except UnknownPaymentHash as e: except UnknownPaymentHash as e:
_logger.info(f'trying to sweep htlc from their latest ctx but getting {repr(e)}') _logger.info(f'trying to sweep htlc from their latest ctx but getting {repr(e)}')
return None return
else: else:
preimage = None preimage = None
htlc_output_witness_script = make_htlc_output_witness_script( htlc_output_witness_script = make_htlc_output_witness_script(
@ -341,50 +354,51 @@ def create_sweeptxs_for_their_ctx(chan: 'Channel', ctx: Transaction, ctn: int,
local_htlc_pubkey=their_htlc_pubkey, local_htlc_pubkey=their_htlc_pubkey,
payment_hash=htlc.payment_hash, payment_hash=htlc.payment_hash,
cltv_expiry=htlc.cltv_expiry) cltv_expiry=htlc.cltv_expiry)
htlc_address = redeem_script_to_address('p2wsh', bh2u(htlc_output_witness_script))
# FIXME handle htlc_address collision cltv_expiry = htlc.cltv_expiry if is_received_htlc and not is_revocation else 0
# also: https://github.com/lightningnetwork/lightning-rfc/issues/448 prevout = ctx.txid() + ':%d'%ctx_output_idx
output_idx = ctx.get_output_idx_from_address(htlc_address) sweep_tx = lambda: create_sweeptx_their_ctx_htlc(
if output_idx is not None: ctx=ctx,
cltv_expiry = htlc.cltv_expiry if is_received_htlc and not is_revocation else 0 witness_script=htlc_output_witness_script,
prevout = ctx.txid() + ':%d'%output_idx sweep_address=sweep_address,
sweep_tx = lambda: create_sweeptx_their_ctx_htlc( preimage=preimage,
ctx=ctx, output_idx=ctx_output_idx,
witness_script=htlc_output_witness_script, privkey=our_revocation_privkey if is_revocation else our_htlc_privkey.get_secret_bytes(),
sweep_address=sweep_address, is_revocation=is_revocation,
preimage=preimage, cltv_expiry=cltv_expiry)
output_idx=output_idx, name = f'their_ctx_htlc_{ctx_output_idx}'
privkey=our_revocation_privkey if is_revocation else our_htlc_privkey.get_secret_bytes(), txs[prevout] = (name, 0, cltv_expiry, sweep_tx)
is_revocation=is_revocation,
cltv_expiry=cltv_expiry)
name = f'their_ctx_htlc_{output_idx}'
txs[prevout] = (name, 0, cltv_expiry, sweep_tx)
# received HTLCs, in their ctx --> "timeout" # received HTLCs, in their ctx --> "timeout"
received_htlcs = chan.included_htlcs(REMOTE, RECEIVED, ctn=ctn) # type: List[UpdateAddHtlc]
for htlc in received_htlcs:
create_sweeptx_for_htlc(htlc, is_received_htlc=True)
# offered HTLCs, in their ctx --> "success" # offered HTLCs, in their ctx --> "success"
offered_htlcs = chan.included_htlcs(REMOTE, SENT, ctn=ctn) # type: List[UpdateAddHtlc] htlc_to_ctx_output_idx_map = map_htlcs_to_ctx_output_idxs(chan=chan,
for htlc in offered_htlcs: ctx=ctx,
create_sweeptx_for_htlc(htlc, is_received_htlc=False) pcp=their_pcp,
subject=REMOTE,
ctn=ctn)
for (direction, htlc), (ctx_output_idx, htlc_relative_idx) in htlc_to_ctx_output_idx_map.items():
create_sweeptx_for_htlc(htlc=htlc,
is_received_htlc=direction == RECEIVED,
ctx_output_idx=ctx_output_idx)
return txs return txs
def create_htlctx_that_spends_from_our_ctx(chan: 'Channel', our_pcp: bytes, def create_htlctx_that_spends_from_our_ctx(chan: 'Channel', our_pcp: bytes,
ctx: Transaction, htlc: 'UpdateAddHtlc', ctx: Transaction, htlc: 'UpdateAddHtlc',
local_htlc_privkey: bytes, preimage: Optional[bytes], local_htlc_privkey: bytes, preimage: Optional[bytes],
is_received_htlc: bool) -> Tuple[bytes, Transaction]: htlc_direction: Direction, htlc_relative_idx: int,
assert is_received_htlc == bool(preimage), 'preimage is required iff htlc is received' ctx_output_idx: int) -> Tuple[bytes, Transaction]:
assert (htlc_direction == RECEIVED) == bool(preimage), 'preimage is required iff htlc is received'
preimage = preimage or b'' preimage = preimage or b''
witness_script, htlc_tx = make_htlc_tx_with_open_channel(chan=chan, witness_script, htlc_tx = make_htlc_tx_with_open_channel(chan=chan,
pcp=our_pcp, pcp=our_pcp,
for_us=True, subject=LOCAL,
we_receive=is_received_htlc, htlc_direction=htlc_direction,
commit=ctx, commit=ctx,
htlc=htlc, htlc=htlc,
name=f'our_ctx_htlc_tx_{bh2u(htlc.payment_hash)}', ctx_output_idx=ctx_output_idx,
cltv_expiry=0 if is_received_htlc else htlc.cltv_expiry) name=f'our_ctx_{ctx_output_idx}_htlc_tx_{bh2u(htlc.payment_hash)}')
remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc, ctx=ctx) remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc_relative_idx=htlc_relative_idx)
local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey)) local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey))
txin = htlc_tx.inputs()[0] txin = htlc_tx.inputs()[0]
witness_program = bfh(Transaction.get_preimage_script(txin)) witness_program = bfh(Transaction.get_preimage_script(txin))

View file

@ -5,7 +5,7 @@
from enum import IntFlag, IntEnum from enum import IntFlag, IntEnum
import json import json
from collections import namedtuple from collections import namedtuple
from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict from typing import NamedTuple, List, Tuple, Mapping, Optional, TYPE_CHECKING, Union, Dict, Set
import re import re
from .util import bfh, bh2u, inv_dict from .util import bfh, bh2u, inv_dict
@ -50,7 +50,7 @@ class LocalConfig(NamedTuple):
funding_locked_received: bool funding_locked_received: bool
was_announced: bool was_announced: bool
current_commitment_signature: Optional[bytes] current_commitment_signature: Optional[bytes]
current_htlc_signatures: List[bytes] current_htlc_signatures: bytes
class RemoteConfig(NamedTuple): class RemoteConfig(NamedTuple):
@ -383,10 +383,63 @@ def get_ordered_channel_configs(chan: 'Channel', for_us: bool) -> Tuple[Union[Lo
return conf, other_conf return conf, other_conf
def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool, def possible_output_idxs_of_htlc_in_ctx(*, chan: 'Channel', pcp: bytes, subject: 'HTLCOwner',
we_receive: bool, commit: Transaction, htlc_direction: 'Direction', ctx: Transaction,
htlc: 'UpdateAddHtlc', name: str = None, cltv_expiry: int = 0) -> Tuple[bytes, Transaction]: htlc: 'UpdateAddHtlc') -> Set[int]:
amount_msat, cltv_expiry, payment_hash = htlc.amount_msat, htlc.cltv_expiry, htlc.payment_hash amount_msat, cltv_expiry, payment_hash = htlc.amount_msat, htlc.cltv_expiry, htlc.payment_hash
for_us = subject == LOCAL
conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=for_us)
other_revocation_pubkey = derive_blinded_pubkey(other_conf.revocation_basepoint.pubkey, pcp)
other_htlc_pubkey = derive_pubkey(other_conf.htlc_basepoint.pubkey, pcp)
htlc_pubkey = derive_pubkey(conf.htlc_basepoint.pubkey, pcp)
preimage_script = make_htlc_output_witness_script(is_received_htlc=htlc_direction == RECEIVED,
remote_revocation_pubkey=other_revocation_pubkey,
remote_htlc_pubkey=other_htlc_pubkey,
local_htlc_pubkey=htlc_pubkey,
payment_hash=payment_hash,
cltv_expiry=cltv_expiry)
htlc_address = redeem_script_to_address('p2wsh', bh2u(preimage_script))
candidates = ctx.get_output_idxs_from_address(htlc_address)
return {output_idx for output_idx in candidates
if ctx.outputs()[output_idx].value == htlc.amount_msat // 1000}
def map_htlcs_to_ctx_output_idxs(*, chan: 'Channel', ctx: Transaction, pcp: bytes,
subject: 'HTLCOwner', ctn: int) -> Dict[Tuple['Direction', 'UpdateAddHtlc'], Tuple[int, int]]:
"""Returns a dict from (htlc_dir, htlc) to (ctx_output_idx, htlc_relative_idx)"""
htlc_to_ctx_output_idx_map = {} # type: Dict[Tuple[Direction, UpdateAddHtlc], int]
unclaimed_ctx_output_idxs = set(range(len(ctx.outputs())))
offered_htlcs = chan.included_htlcs(subject, SENT, ctn=ctn)
offered_htlcs.sort(key=lambda htlc: htlc.cltv_expiry)
received_htlcs = chan.included_htlcs(subject, RECEIVED, ctn=ctn)
received_htlcs.sort(key=lambda htlc: htlc.cltv_expiry)
for direction, htlcs in zip([SENT, RECEIVED], [offered_htlcs, received_htlcs]):
for htlc in htlcs:
cands = sorted(possible_output_idxs_of_htlc_in_ctx(chan=chan,
pcp=pcp,
subject=subject,
htlc_direction=direction,
ctx=ctx,
htlc=htlc))
for ctx_output_idx in cands:
if ctx_output_idx in unclaimed_ctx_output_idxs:
unclaimed_ctx_output_idxs.discard(ctx_output_idx)
htlc_to_ctx_output_idx_map[(direction, htlc)] = ctx_output_idx
break
# calc htlc_relative_idx
inverse_map = {ctx_output_idx: (direction, htlc)
for ((direction, htlc), ctx_output_idx) in htlc_to_ctx_output_idx_map.items()}
return {inverse_map[ctx_output_idx]: (ctx_output_idx, htlc_relative_idx)
for htlc_relative_idx, ctx_output_idx in enumerate(sorted(inverse_map))}
def make_htlc_tx_with_open_channel(*, chan: 'Channel', pcp: bytes, subject: 'HTLCOwner',
htlc_direction: 'Direction', commit: Transaction, ctx_output_idx: int,
htlc: 'UpdateAddHtlc', name: str = None) -> Tuple[bytes, Transaction]:
amount_msat, cltv_expiry, payment_hash = htlc.amount_msat, htlc.cltv_expiry, htlc.payment_hash
for_us = subject == LOCAL
conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=for_us) conf, other_conf = get_ordered_channel_configs(chan=chan, for_us=for_us)
delayedpubkey = derive_pubkey(conf.delayed_basepoint.pubkey, pcp) delayedpubkey = derive_pubkey(conf.delayed_basepoint.pubkey, pcp)
@ -395,8 +448,8 @@ def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool,
htlc_pubkey = derive_pubkey(conf.htlc_basepoint.pubkey, pcp) htlc_pubkey = derive_pubkey(conf.htlc_basepoint.pubkey, pcp)
# HTLC-success for the HTLC spending from a received HTLC output # HTLC-success for the HTLC spending from a received HTLC output
# if we do not receive, and the commitment tx is not for us, they receive, so it is also an HTLC-success # if we do not receive, and the commitment tx is not for us, they receive, so it is also an HTLC-success
is_htlc_success = for_us == we_receive is_htlc_success = htlc_direction == RECEIVED
script, htlc_tx_output = make_htlc_tx_output( witness_script_of_htlc_tx_output, htlc_tx_output = make_htlc_tx_output(
amount_msat = amount_msat, amount_msat = amount_msat,
local_feerate = chan.get_next_feerate(LOCAL if for_us else REMOTE), local_feerate = chan.get_next_feerate(LOCAL if for_us else REMOTE),
revocationpubkey=other_revocation_pubkey, revocationpubkey=other_revocation_pubkey,
@ -409,20 +462,15 @@ def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool,
local_htlc_pubkey=htlc_pubkey, local_htlc_pubkey=htlc_pubkey,
payment_hash=payment_hash, payment_hash=payment_hash,
cltv_expiry=cltv_expiry) cltv_expiry=cltv_expiry)
htlc_address = redeem_script_to_address('p2wsh', bh2u(preimage_script))
# FIXME handle htlc_address collision
# also: https://github.com/lightningnetwork/lightning-rfc/issues/448
prevout_idx = commit.get_output_idx_from_address(htlc_address)
assert prevout_idx is not None, (htlc_address, commit.outputs(), extract_ctn_from_tx_and_chan(commit, chan))
htlc_tx_inputs = make_htlc_tx_inputs( htlc_tx_inputs = make_htlc_tx_inputs(
commit.txid(), prevout_idx, commit.txid(), ctx_output_idx,
amount_msat=amount_msat, amount_msat=amount_msat,
witness_script=bh2u(preimage_script)) witness_script=bh2u(preimage_script))
if is_htlc_success: if is_htlc_success:
cltv_expiry = 0 cltv_expiry = 0
htlc_tx = make_htlc_tx(cltv_expiry, inputs=htlc_tx_inputs, output=htlc_tx_output, htlc_tx = make_htlc_tx(cltv_expiry, inputs=htlc_tx_inputs, output=htlc_tx_output,
name=name, cltv_expiry=cltv_expiry) name=name, cltv_expiry=cltv_expiry)
return script, htlc_tx return witness_script_of_htlc_tx_output, htlc_tx
def make_funding_input(local_funding_pubkey: bytes, remote_funding_pubkey: bytes, def make_funding_input(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
funding_pos: int, funding_txid: bytes, funding_sat: int): funding_pos: int, funding_txid: bytes, funding_sat: int):

View file

@ -31,7 +31,8 @@ import struct
import traceback import traceback
import sys import sys
from typing import (Sequence, Union, NamedTuple, Tuple, Optional, Iterable, from typing import (Sequence, Union, NamedTuple, Tuple, Optional, Iterable,
Callable, List, Dict) Callable, List, Dict, Set)
from collections import defaultdict
from . import ecc, bitcoin, constants, segwit_addr from . import ecc, bitcoin, constants, segwit_addr
from .util import profiler, to_bytes, bh2u, bfh from .util import profiler, to_bytes, bh2u, bfh
@ -1174,7 +1175,7 @@ class Transaction:
sig = self.sign_txin(i, sec, bip143_shared_txdigest_fields=bip143_shared_txdigest_fields) sig = self.sign_txin(i, sec, bip143_shared_txdigest_fields=bip143_shared_txdigest_fields)
self.add_signature_to_txin(i, j, sig) self.add_signature_to_txin(i, j, sig)
_logger.info(f"is_complete {self.is_complete()}") _logger.debug(f"is_complete {self.is_complete()}")
self.raw = self.serialize() self.raw = self.serialize()
def sign_txin(self, txin_index, privkey_bytes, *, bip143_shared_txdigest_fields=None) -> str: def sign_txin(self, txin_index, privkey_bytes, *, bip143_shared_txdigest_fields=None) -> str:
@ -1201,27 +1202,24 @@ class Transaction:
return (addr in (o.address for o in self.outputs())) \ return (addr in (o.address for o in self.outputs())) \
or (addr in (txin.get("address") for txin in self.inputs())) or (addr in (txin.get("address") for txin in self.inputs()))
def get_output_idx_from_scriptpubkey(self, script: str) -> Optional[int]: def get_output_idxs_from_scriptpubkey(self, script: str) -> Set[int]:
"""Returns the index of an output with given script. """Returns the set indices of outputs with given script."""
If there are no such outputs, returns None;
if there are multiple, returns one of them.
"""
assert isinstance(script, str) # hex assert isinstance(script, str) # hex
# build cache if there isn't one yet # build cache if there isn't one yet
# note: can become stale and return incorrect data # note: can become stale and return incorrect data
# if the tx is modified later; that's out of scope. # if the tx is modified later; that's out of scope.
if not hasattr(self, '_script_to_output_idx'): if not hasattr(self, '_script_to_output_idx'):
d = {} d = defaultdict(set)
for output_idx, o in enumerate(self.outputs()): for output_idx, o in enumerate(self.outputs()):
o_script = self.pay_script(o.type, o.address) o_script = self.pay_script(o.type, o.address)
assert isinstance(o_script, str) assert isinstance(o_script, str)
d[o_script] = output_idx d[o_script].add(output_idx)
self._script_to_output_idx = d self._script_to_output_idx = d
return self._script_to_output_idx.get(script) return set(self._script_to_output_idx[script]) # copy
def get_output_idx_from_address(self, addr: str) -> Optional: def get_output_idxs_from_address(self, addr: str) -> Set[int]:
script = bitcoin.address_to_script(addr) script = bitcoin.address_to_script(addr)
return self.get_output_idx_from_scriptpubkey(script) return self.get_output_idxs_from_scriptpubkey(script)
def as_dict(self): def as_dict(self):
if self.raw is None: if self.raw is None: