lnchan refactor

- replace undoing logic with new HTLCManager class
- separate SENT/RECEIVED
- move UpdateAddHtlc to lnutil
This commit is contained in:
Janus 2019-01-21 21:27:27 +01:00 committed by ThomasV
parent ef88bb1c28
commit e56e849505
11 changed files with 706 additions and 398 deletions

View file

@ -5,9 +5,9 @@ import PyQt5.QtWidgets as QtWidgets
import PyQt5.QtCore as QtCore import PyQt5.QtCore as QtCore
from electrum.i18n import _ from electrum.i18n import _
from electrum.lnchan import UpdateAddHtlc, HTLCOwner
from electrum.util import bh2u, format_time from electrum.util import bh2u, format_time
from electrum.lnutil import format_short_channel_id, SENT, RECEIVED from electrum.lnutil import format_short_channel_id, LOCAL, REMOTE, UpdateAddHtlc, Direction
from electrum.lnchan import htlcsum
from electrum.lnaddr import LnAddr, lndecode from electrum.lnaddr import LnAddr, lndecode
from electrum.bitcoin import COIN from electrum.bitcoin import COIN
@ -30,8 +30,8 @@ class LinkedLabel(QtWidgets.QLabel):
self.linkActivated.connect(on_clicked) self.linkActivated.connect(on_clicked)
class ChannelDetailsDialog(QtWidgets.QDialog): class ChannelDetailsDialog(QtWidgets.QDialog):
def make_htlc_item(self, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem: def make_htlc_item(self, i: UpdateAddHtlc, direction: Direction) -> HTLCItem:
it = HTLCItem(_('Sent HTLC with ID {}' if SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id)) it = HTLCItem(_('Sent HTLC with ID {}' if Direction.SENT == direction else 'Received HTLC with ID {}').format(i.htlc_id))
it.appendRow([HTLCItem(_('Amount')),HTLCItem(self.format(i.amount_msat))]) it.appendRow([HTLCItem(_('Amount')),HTLCItem(self.format(i.amount_msat))])
it.appendRow([HTLCItem(_('CLTV expiry')),HTLCItem(str(i.cltv_expiry))]) it.appendRow([HTLCItem(_('CLTV expiry')),HTLCItem(str(i.cltv_expiry))])
it.appendRow([HTLCItem(_('Payment hash')),HTLCItem(bh2u(i.payment_hash))]) it.appendRow([HTLCItem(_('Payment hash')),HTLCItem(bh2u(i.payment_hash))])
@ -45,7 +45,7 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
invoice.appendRow([HTLCItem(_('Date')), HTLCItem(format_time(lnaddr.date))]) invoice.appendRow([HTLCItem(_('Date')), HTLCItem(format_time(lnaddr.date))])
it.appendRow([invoice]) it.appendRow([invoice])
def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: HTLCOwner) -> HTLCItem: def make_inflight(self, lnaddr, i: UpdateAddHtlc, direction: Direction) -> HTLCItem:
it = self.make_htlc_item(i, direction) it = self.make_htlc_item(i, direction)
self.append_lnaddr(it, lnaddr) self.append_lnaddr(it, lnaddr)
return it return it
@ -99,23 +99,23 @@ class ChannelDetailsDialog(QtWidgets.QDialog):
dest_mapping = self.keyname_rows[to] dest_mapping = self.keyname_rows[to]
dest_mapping[payment_hash] = len(dest_mapping) dest_mapping[payment_hash] = len(dest_mapping)
ln_payment_completed = QtCore.pyqtSignal(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes) ln_payment_completed = QtCore.pyqtSignal(str, float, Direction, UpdateAddHtlc, bytes, bytes)
htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, HTLCOwner) htlc_added = QtCore.pyqtSignal(str, UpdateAddHtlc, LnAddr, Direction)
@QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, HTLCOwner) @QtCore.pyqtSlot(str, UpdateAddHtlc, LnAddr, Direction)
def do_htlc_added(self, evtname, htlc, lnaddr, direction): def do_htlc_added(self, evtname, htlc, lnaddr, direction):
mapping = self.keyname_rows['inflight'] mapping = self.keyname_rows['inflight']
mapping[htlc.payment_hash] = len(mapping) mapping[htlc.payment_hash] = len(mapping)
self.folders['inflight'].appendRow(self.make_inflight(lnaddr, htlc, direction)) self.folders['inflight'].appendRow(self.make_inflight(lnaddr, htlc, direction))
@QtCore.pyqtSlot(str, float, HTLCOwner, UpdateAddHtlc, bytes, bytes) @QtCore.pyqtSlot(str, float, Direction, UpdateAddHtlc, bytes, bytes)
def do_ln_payment_completed(self, evtname, date, direction, htlc, preimage, chan_id): def do_ln_payment_completed(self, evtname, date, direction, htlc, preimage, chan_id):
self.move('inflight', 'settled', htlc.payment_hash) self.move('inflight', 'settled', htlc.payment_hash)
self.update_sent_received() self.update_sent_received()
def update_sent_received(self): def update_sent_received(self):
self.sent_label.setText(str(sum(self.chan.settled[SENT]))) self.sent_label.setText(str(htlcsum(self.hm.settled_htlcs_by(LOCAL))))
self.received_label.setText(str(sum(self.chan.settled[RECEIVED]))) self.received_label.setText(str(htlcsum(self.hm.settled_htlcs_by(REMOTE))))
@QtCore.pyqtSlot(str) @QtCore.pyqtSlot(str)
def show_tx(self, link_text: str): def show_tx(self, link_text: str):

View file

@ -30,8 +30,9 @@ class ChannelsList(MyTreeView):
for subject in (REMOTE, LOCAL): for subject in (REMOTE, LOCAL):
bal_minus_htlcs = chan.balance_minus_outgoing_htlcs(subject)//1000 bal_minus_htlcs = chan.balance_minus_outgoing_htlcs(subject)//1000
label = self.parent.format_amount(bal_minus_htlcs) label = self.parent.format_amount(bal_minus_htlcs)
bal_other = chan.balance(-subject)//1000 other = subject.inverted()
bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(-subject)//1000 bal_other = chan.balance(other)//1000
bal_minus_htlcs_other = chan.balance_minus_outgoing_htlcs(other)//1000
if bal_other != bal_minus_htlcs_other: if bal_other != bal_minus_htlcs_other:
label += ' (+' + self.parent.format_amount(bal_other - bal_minus_htlcs_other) + ')' label += ' (+' + self.parent.format_amount(bal_other - bal_minus_htlcs_other) + ')'
labels[subject] = label labels[subject] = label

View file

@ -25,8 +25,8 @@ from .util import PrintError, bh2u, print_error, bfh, log_exceptions, list_enabl
from .transaction import Transaction, TxOutput from .transaction import Transaction, TxOutput
from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment, from .lnonion import (new_onion_packet, decode_onion_error, OnionFailureCode, calc_hops_data_for_payment,
process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage) process_onion_packet, OnionPacket, construct_onion_error, OnionRoutingFailureMessage)
from .lnchan import Channel, RevokeAndAck, htlcsum, UpdateAddHtlc from .lnchan import Channel, RevokeAndAck, htlcsum
from .lnutil import (Outpoint, LocalConfig, RECEIVED, from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore, RemoteConfig, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore,
funding_output_script, get_per_commitment_secret_from_seed, funding_output_script, get_per_commitment_secret_from_seed,
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures, secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
@ -397,20 +397,20 @@ class Peer(PrintError):
htlc_basepoint=keypair_generator(LnKeyFamily.HTLC_BASE), htlc_basepoint=keypair_generator(LnKeyFamily.HTLC_BASE),
delayed_basepoint=keypair_generator(LnKeyFamily.DELAY_BASE), delayed_basepoint=keypair_generator(LnKeyFamily.DELAY_BASE),
revocation_basepoint=keypair_generator(LnKeyFamily.REVOCATION_BASE), revocation_basepoint=keypair_generator(LnKeyFamily.REVOCATION_BASE),
to_self_delay=143, to_self_delay=9,
dust_limit_sat=546, dust_limit_sat=546,
max_htlc_value_in_flight_msat=0xffffffffffffffff, max_htlc_value_in_flight_msat=0xffffffffffffffff,
max_accepted_htlcs=5, max_accepted_htlcs=5,
initial_msat=initial_msat, initial_msat=initial_msat,
ctn=-1, ctn=-1,
next_htlc_id=0, next_htlc_id=0,
amount_msat=initial_msat,
reserve_sat=546, reserve_sat=546,
per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey, per_commitment_secret_seed=keypair_generator(LnKeyFamily.REVOCATION_ROOT).privkey,
funding_locked_received=False, funding_locked_received=False,
was_announced=False, was_announced=False,
current_commitment_signature=None, current_commitment_signature=None,
current_htlc_signatures=[], current_htlc_signatures=[],
got_sig_for_next=False,
) )
return local_config return local_config
@ -472,7 +472,6 @@ class Peer(PrintError):
max_accepted_htlcs=int.from_bytes(payload["max_accepted_htlcs"], 'big'), max_accepted_htlcs=int.from_bytes(payload["max_accepted_htlcs"], 'big'),
initial_msat=push_msat, initial_msat=push_msat,
ctn = -1, ctn = -1,
amount_msat=push_msat,
next_htlc_id = 0, next_htlc_id = 0,
reserve_sat = remote_reserve_sat, reserve_sat = remote_reserve_sat,
@ -517,9 +516,11 @@ class Peer(PrintError):
# broadcast funding tx # broadcast funding tx
await self.network.broadcast_transaction(funding_tx) await self.network.broadcast_transaction(funding_tx)
chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE) chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE)
chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=remote_per_commitment_point, next_per_commitment_point=None)
chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig, got_sig_for_next=False)
chan.set_state('OPENING') chan.set_state('OPENING')
chan.set_remote_commitment()
chan.set_local_commitment(chan.current_commitment(LOCAL))
return chan return chan
async def on_open_channel(self, payload): async def on_open_channel(self, payload):
@ -579,7 +580,6 @@ class Peer(PrintError):
max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'), max_accepted_htlcs=int.from_bytes(payload['max_accepted_htlcs'], 'big'),
initial_msat=remote_balance_sat, initial_msat=remote_balance_sat,
ctn = -1, ctn = -1,
amount_msat=remote_balance_sat,
next_htlc_id = 0, next_htlc_id = 0,
reserve_sat = remote_reserve_sat, reserve_sat = remote_reserve_sat,
@ -605,7 +605,7 @@ class Peer(PrintError):
) )
chan.set_state('OPENING') chan.set_state('OPENING')
chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE) chan.remote_commitment_to_be_revoked = chan.pending_commitment(REMOTE)
chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0) chan.config[REMOTE] = chan.config[REMOTE]._replace(ctn=0, current_per_commitment_point=payload['first_per_commitment_point'], next_per_commitment_point=None)
chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig) chan.config[LOCAL] = chan.config[LOCAL]._replace(ctn=0, current_commitment_signature=remote_sig)
self.lnworker.save_channel(chan) self.lnworker.save_channel(chan)
self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) self.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
@ -732,7 +732,7 @@ class Peer(PrintError):
if not chan.config[LOCAL].funding_locked_received: if not chan.config[LOCAL].funding_locked_received:
our_next_point = chan.config[REMOTE].next_per_commitment_point our_next_point = chan.config[REMOTE].next_per_commitment_point
their_next_point = payload["next_per_commitment_point"] their_next_point = payload["next_per_commitment_point"]
new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point, current_per_commitment_point=our_next_point) new_remote_state = chan.config[REMOTE]._replace(next_per_commitment_point=their_next_point)
new_local_state = chan.config[LOCAL]._replace(funding_locked_received = True) new_local_state = chan.config[LOCAL]._replace(funding_locked_received = True)
chan.config[REMOTE]=new_remote_state chan.config[REMOTE]=new_remote_state
chan.config[LOCAL]=new_local_state chan.config[LOCAL]=new_local_state

View file

@ -27,24 +27,25 @@ import binascii
import json import json
from enum import Enum, auto from enum import Enum, auto
from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Iterable, Sequence from typing import Optional, Dict, List, Tuple, NamedTuple, Set, Callable, Iterable, Sequence
from copy import deepcopy
from . import ecc
from .util import bfh, PrintError, bh2u from .util import bfh, PrintError, bh2u
from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS from .bitcoin import TYPE_SCRIPT, TYPE_ADDRESS
from .bitcoin import redeem_script_to_address from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from . import ecc from .simple_config import get_config
from .lnutil import Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, RevocationStore
from .lnutil import get_per_commitment_secret_from_seed
from .lnutil import secret_to_pubkey, derive_privkey, derive_pubkey, derive_blinded_pubkey
from .lnutil import sign_and_get_sig_string
from .lnutil import make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc
from .lnutil import HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT
from .lnutil import funding_output_script, LOCAL, REMOTE, HTLCOwner, make_closing_tx, make_commitment_outputs
from .lnutil import ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script
from .transaction import Transaction from .transaction import Transaction
from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
make_htlc_tx_with_open_channel, make_commitment, make_received_htlc, make_offered_htlc,
HTLC_TIMEOUT_WEIGHT, HTLC_SUCCESS_WEIGHT, extract_ctn_from_tx_and_chan, UpdateAddHtlc,
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script)
from .lnsweep import create_sweeptxs_for_their_just_revoked_ctx from .lnsweep import create_sweeptxs_for_their_just_revoked_ctx
from .lnsweep import create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx from .lnsweep import create_sweeptxs_for_our_latest_ctx, create_sweeptxs_for_their_latest_ctx
from .lnhtlc import HTLCManager
class ChannelJsonEncoder(json.JSONEncoder): class ChannelJsonEncoder(json.JSONEncoder):
@ -83,22 +84,6 @@ class FeeUpdate(defaultdict):
return self.rate return self.rate
# implicit return None # implicit return None
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
"""
This whole class body is so that if you pass a hex-string as payment_hash,
it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
"""
__slots__ = ()
def __new__(cls, *args, **kwargs):
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
return super().__new__(cls, **kwargs)
def decodeAll(d, local): def decodeAll(d, local):
for k, v in d.items(): for k, v in d.items():
if k == 'revocation_store': if k == 'revocation_store':
@ -124,20 +109,6 @@ def str_bytes_dict_from_save(x):
def str_bytes_dict_to_save(x): def str_bytes_dict_to_save(x):
return {str(k): bh2u(v) for k, v in x.items()} return {str(k): bh2u(v) for k, v in x.items()}
class HtlcChanges(NamedTuple):
# ints are htlc ids
adds: Dict[int, UpdateAddHtlc]
settles: Set[int]
fails: Set[int]
locked_in: Set[int]
@staticmethod
def new():
"""
Since we can't use default arguments for these types (they would be shared among instances)
"""
return HtlcChanges({}, set(), set(), set())
class Channel(PrintError): class Channel(PrintError):
def diagnostic_name(self): def diagnostic_name(self):
if self.name: if self.name:
@ -147,7 +118,7 @@ class Channel(PrintError):
except: except:
return super().diagnostic_name() return super().diagnostic_name()
def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[HTLCOwner, UpdateAddHtlc, bytes], None]] = None): def __init__(self, state, sweep_address = None, name = None, payment_completed : Optional[Callable[[Direction, UpdateAddHtlc, bytes], None]] = None):
self.preimages = {} self.preimages = {}
if not payment_completed: if not payment_completed:
payment_completed = lambda this, x, y, z: None payment_completed = lambda this, x, y, z: None
@ -179,13 +150,9 @@ class Channel(PrintError):
# we should not persist txns in this format. we should persist htlcs, and be able to derive # we should not persist txns in this format. we should persist htlcs, and be able to derive
# any past commitment transaction and use that instead; until then... # any past commitment transaction and use that instead; until then...
self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"]) self.remote_commitment_to_be_revoked = Transaction(state["remote_commitment_to_be_revoked"])
self.remote_commitment_to_be_revoked.deserialize(True)
self.log = {LOCAL: HtlcChanges.new(), REMOTE: HtlcChanges.new()} self.hm = HTLCManager(state.get('log'))
for strname, subject in [('remote_log', REMOTE), ('local_log', LOCAL)]:
if strname not in state: continue
for y in state[strname]:
htlc = UpdateAddHtlc(**y)
self.log[subject].adds[htlc.htlc_id] = htlc
self.name = name self.name = name
@ -200,23 +167,18 @@ class Channel(PrintError):
self.lnwatcher = None self.lnwatcher = None
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])} self.local_commitment = None
self.remote_commitment = None
for sub in (LOCAL, REMOTE):
self.log[sub].locked_in.update(self.log[sub].adds.keys())
self.set_local_commitment(self.current_commitment(LOCAL))
self.set_remote_commitment(self.current_commitment(REMOTE))
def set_local_commitment(self, ctx): def set_local_commitment(self, ctx):
ctn = extract_ctn_from_tx_and_chan(ctx, self)
assert self.signature_fits(ctx), (self.log[LOCAL])
self.local_commitment = ctx self.local_commitment = ctx
if self.sweep_address is not None: if self.sweep_address is not None:
self.local_sweeptxs = create_sweeptxs_for_our_latest_ctx(self, self.local_commitment, self.sweep_address) self.local_sweeptxs = create_sweeptxs_for_our_latest_ctx(self, self.local_commitment, self.sweep_address)
self.assert_signature_fits(ctx) def set_remote_commitment(self):
self.remote_commitment = self.current_commitment(REMOTE)
def set_remote_commitment(self, ctx):
self.remote_commitment = ctx
if self.sweep_address is not None: if self.sweep_address is not None:
self.remote_sweeptxs = create_sweeptxs_for_their_latest_ctx(self, self.remote_commitment, self.sweep_address) self.remote_sweeptxs = create_sweeptxs_for_their_latest_ctx(self, self.remote_commitment, self.sweep_address)
@ -233,9 +195,9 @@ class Channel(PrintError):
raise PaymentFailure('Channel not open') raise PaymentFailure('Channel not open')
if self.available_to_spend(LOCAL) < amount_msat: if self.available_to_spend(LOCAL) < amount_msat:
raise PaymentFailure(f'Not enough local balance. Have: {self.available_to_spend(LOCAL)}, Need: {amount_msat}') raise PaymentFailure(f'Not enough local balance. Have: {self.available_to_spend(LOCAL)}, Need: {amount_msat}')
if len(self.htlcs(LOCAL, only_pending=True)) + 1 > self.config[REMOTE].max_accepted_htlcs: if len(self.hm.htlcs(LOCAL)) + 1 > self.config[REMOTE].max_accepted_htlcs:
raise PaymentFailure('Too many HTLCs already in channel') raise PaymentFailure('Too many HTLCs already in channel')
current_htlc_sum = htlcsum(self.htlcs(LOCAL, only_pending=True)) current_htlc_sum = htlcsum(self.hm.htlcs_by_direction(LOCAL, SENT)) + htlcsum(self.hm.htlcs_by_direction(LOCAL, RECEIVED))
if current_htlc_sum + amount_msat > self.config[REMOTE].max_htlc_value_in_flight_msat: if current_htlc_sum + amount_msat > self.config[REMOTE].max_htlc_value_in_flight_msat:
raise PaymentFailure(f'HTLC value sum (sum of pending htlcs: {current_htlc_sum/1000} sat plus new htlc: {amount_msat/1000} sat) would exceed max allowed: {self.config[REMOTE].max_htlc_value_in_flight_msat/1000} sat') raise PaymentFailure(f'HTLC value sum (sum of pending htlcs: {current_htlc_sum/1000} sat plus new htlc: {amount_msat/1000} sat) would exceed max allowed: {self.config[REMOTE].max_htlc_value_in_flight_msat/1000} sat')
if amount_msat <= 0: # FIXME htlc_minimum_msat if amount_msat <= 0: # FIXME htlc_minimum_msat
@ -269,7 +231,7 @@ class Channel(PrintError):
assert type(htlc) is dict assert type(htlc) is dict
self._check_can_pay(htlc['amount_msat']) self._check_can_pay(htlc['amount_msat'])
htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id) htlc = UpdateAddHtlc(**htlc, htlc_id=self.config[LOCAL].next_htlc_id)
self.log[LOCAL].adds[htlc.htlc_id] = htlc self.hm.send_htlc(htlc)
self.print_error("add_htlc") self.print_error("add_htlc")
self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1) self.config[LOCAL]=self.config[LOCAL]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id return htlc.htlc_id
@ -288,8 +250,7 @@ class Channel(PrintError):
raise RemoteMisbehaving('Remote dipped below channel reserve.' +\ raise RemoteMisbehaving('Remote dipped below channel reserve.' +\
f' Available at remote: {self.available_to_spend(REMOTE)},' +\ f' Available at remote: {self.available_to_spend(REMOTE)},' +\
f' HTLC amount: {htlc.amount_msat}') f' HTLC amount: {htlc.amount_msat}')
adds = self.log[REMOTE].adds self.hm.recv_htlc(htlc)
adds[htlc.htlc_id] = htlc
self.print_error("receive_htlc") self.print_error("receive_htlc")
self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1) self.config[REMOTE]=self.config[REMOTE]._replace(next_htlc_id=htlc.htlc_id + 1)
return htlc.htlc_id return htlc.htlc_id
@ -308,7 +269,7 @@ class Channel(PrintError):
""" """
self.print_error("sign_next_commitment") self.print_error("sign_next_commitment")
old_logs = dict(self.lock_in_htlc_changes(LOCAL)) self.hm.send_ctx()
pending_remote_commitment = self.pending_commitment(REMOTE) pending_remote_commitment = self.pending_commitment(REMOTE)
sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE]) sig_64 = sign_and_get_sig_string(pending_remote_commitment, self.config[LOCAL], self.config[REMOTE])
@ -321,7 +282,8 @@ class Channel(PrintError):
for_us = False for_us = False
htlcsigs = [] htlcsigs = []
for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, REMOTE), self.included_htlcs(REMOTE, LOCAL)]): # they sent => we receive
for we_receive, htlcs in zip([True, False], [self.included_htlcs(REMOTE, SENT, ctn=self.config[REMOTE].ctn+1), self.included_htlcs(REMOTE, RECEIVED, ctn=self.config[REMOTE].ctn+1)]):
for htlc in htlcs: for htlc in htlcs:
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self, _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=self.config[REMOTE].next_per_commitment_point, pcp=self.config[REMOTE].next_per_commitment_point,
@ -337,26 +299,11 @@ class Channel(PrintError):
htlcsigs.sort() htlcsigs.sort()
htlcsigs = [x[1] for x in htlcsigs] htlcsigs = [x[1] for x in htlcsigs]
self.remote_commitment = self.pending_commitment(REMOTE) # TODO should add remote_commitment here and handle
# both valid ctx'es in lnwatcher at the same time...
# we can't know if this message arrives.
# since we shouldn't actually throw away
# failed htlcs yet (or mark htlc locked in),
# roll back the changes that were made
self.log = old_logs
return sig_64, htlcsigs return sig_64, htlcsigs
def lock_in_htlc_changes(self, subject):
for sub in (LOCAL, REMOTE):
log = self.log[sub]
yield (sub, deepcopy(log))
for htlc_id in log.fails:
log.adds.pop(htlc_id)
log.fails.clear()
self.log[subject].locked_in.update(self.log[subject].adds.keys())
def receive_new_commitment(self, sig, htlc_sigs): def receive_new_commitment(self, sig, htlc_sigs):
""" """
ReceiveNewCommitment process a signature for a new commitment state sent by ReceiveNewCommitment process a signature for a new commitment state sent by
@ -372,7 +319,7 @@ class Channel(PrintError):
""" """
self.print_error("receive_new_commitment") self.print_error("receive_new_commitment")
for _ in self.lock_in_htlc_changes(REMOTE): pass self.hm.recv_ctx()
assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes assert len(htlc_sigs) == 0 or type(htlc_sigs[0]) is bytes
@ -385,16 +332,18 @@ class Channel(PrintError):
htlc_sigs_string = b''.join(htlc_sigs) htlc_sigs_string = b''.join(htlc_sigs)
htlc_sigs = htlc_sigs[:] # copy cause we will delete now htlc_sigs = htlc_sigs[:] # copy cause we will delete now
for htlcs, we_receive in [(self.included_htlcs(LOCAL, REMOTE), True), (self.included_htlcs(LOCAL, LOCAL), False)]: ctn = self.config[LOCAL].ctn+1
for htlcs, we_receive in [(self.included_htlcs(LOCAL, SENT, ctn=ctn), False), (self.included_htlcs(LOCAL, RECEIVED, ctn=ctn), True)]:
for htlc in htlcs: for htlc in htlcs:
idx = self.verify_htlc(htlc, htlc_sigs, we_receive) idx = self.verify_htlc(htlc, htlc_sigs, we_receive, pending_local_commitment)
del htlc_sigs[idx] del htlc_sigs[idx]
if len(htlc_sigs) != 0: # all sigs should have been popped above if len(htlc_sigs) != 0: # all sigs should have been popped above
raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures') raise Exception('failed verifying HTLC signatures: invalid amount of correct signatures')
self.config[LOCAL]=self.config[LOCAL]._replace( self.config[LOCAL]=self.config[LOCAL]._replace(
current_commitment_signature=sig, current_commitment_signature=sig,
current_htlc_signatures=htlc_sigs_string) current_htlc_signatures=htlc_sigs_string,
got_sig_for_next=True)
if self.pending_fee is not None: if self.pending_fee is not None:
if not self.constraints.is_initiator: if not self.constraints.is_initiator:
@ -402,15 +351,15 @@ class Channel(PrintError):
if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]: if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]:
self.pending_fee[FUNDER_SIGNED] = True self.pending_fee[FUNDER_SIGNED] = True
self.set_local_commitment(self.pending_commitment(LOCAL)) self.set_local_commitment(pending_local_commitment)
def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool) -> int: def verify_htlc(self, htlc: UpdateAddHtlc, htlc_sigs: Sequence[bytes], we_receive: bool, ctx) -> int:
_, this_point, _ = self.points() _, this_point, _, _ = self.points()
_script, htlc_tx = make_htlc_tx_with_open_channel(chan=self, _script, htlc_tx = make_htlc_tx_with_open_channel(chan=self,
pcp=this_point, pcp=this_point,
for_us=True, for_us=True,
we_receive=we_receive, we_receive=we_receive,
commit=self.pending_commitment(LOCAL), commit=ctx,
htlc=htlc) htlc=htlc)
pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0))) pre_hash = sha256d(bfh(htlc_tx.serialize_preimage(0)))
remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, this_point) remote_htlc_pubkey = derive_pubkey(self.config[REMOTE].htlc_basepoint.pubkey, this_point)
@ -418,19 +367,19 @@ class Channel(PrintError):
if ecc.verify_signature(remote_htlc_pubkey, sig, pre_hash): if ecc.verify_signature(remote_htlc_pubkey, sig, pre_hash):
return idx return idx
else: else:
raise Exception(f'failed verifying HTLC signatures: {htlc}') raise Exception(f'failed verifying HTLC signatures: {htlc}, sigs: {len(htlc_sigs)}, we_receive: {we_receive}')
def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool) -> bytes: def get_remote_htlc_sig_for_htlc(self, htlc: UpdateAddHtlc, we_receive: bool, ctx) -> bytes:
data = self.config[LOCAL].current_htlc_signatures data = self.config[LOCAL].current_htlc_signatures
htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)] htlc_sigs = [data[i:i + 64] for i in range(0, len(data), 64)]
idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive) idx = self.verify_htlc(htlc, htlc_sigs, we_receive=we_receive, ctx=ctx)
remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sigs[idx]) + b'\x01' remote_htlc_sig = ecc.der_sig_from_sig_string(htlc_sigs[idx]) + b'\x01'
return remote_htlc_sig return remote_htlc_sig
def revoke_current_commitment(self): def revoke_current_commitment(self):
self.print_error("revoke_current_commitment") self.print_error("revoke_current_commitment")
last_secret, this_point, next_point = self.points() last_secret, this_point, next_point, _ = self.points()
new_feerate = self.constraints.feerate new_feerate = self.constraints.feerate
@ -444,16 +393,18 @@ class Channel(PrintError):
self.pending_fee = None self.pending_fee = None
print("FEERATE CHANGE COMPLETE (initiator)") print("FEERATE CHANGE COMPLETE (initiator)")
self.config[LOCAL]=self.config[LOCAL]._replace( assert self.config[LOCAL].got_sig_for_next
ctn=self.config[LOCAL].ctn + 1,
)
self.constraints=self.constraints._replace( self.constraints=self.constraints._replace(
feerate=new_feerate feerate=new_feerate
) )
self.set_local_commitment(self.pending_commitment(LOCAL))
# since we should not revoke our latest commitment tx, ctx = self.pending_commitment(LOCAL)
# we do not update self.local_commitment here, self.hm.send_rev()
# it should instead be updated when we receive a new sig self.config[LOCAL]=self.config[LOCAL]._replace(
ctn=self.config[LOCAL].ctn + 1,
got_sig_for_next=False,
)
assert self.signature_fits(ctx)
return RevokeAndAck(last_secret, next_point), "current htlcs" return RevokeAndAck(last_secret, next_point), "current htlcs"
@ -466,7 +417,8 @@ class Channel(PrintError):
this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big')) this_point = secret_to_pubkey(int.from_bytes(this_secret, 'big'))
next_secret = get_per_commitment_secret_from_seed(self.config[LOCAL].per_commitment_secret_seed, RevocationStore.START_INDEX - next_small_num) next_secret = get_per_commitment_secret_from_seed(self.config[LOCAL].per_commitment_secret_seed, RevocationStore.START_INDEX - next_small_num)
next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big')) next_point = secret_to_pubkey(int.from_bytes(next_secret, 'big'))
return last_secret, this_point, next_point last_point = secret_to_pubkey(int.from_bytes(last_secret, 'big'))
return last_secret, this_point, next_point, last_point
def process_new_revocation_secret(self, per_commitment_secret: bytes): def process_new_revocation_secret(self, per_commitment_secret: bytes):
if not self.lnwatcher: if not self.lnwatcher:
@ -481,12 +433,9 @@ class Channel(PrintError):
def receive_revocation(self, revocation) -> Tuple[int, int]: def receive_revocation(self, revocation) -> Tuple[int, int]:
self.print_error("receive_revocation") self.print_error("receive_revocation")
old_logs = dict(self.lock_in_htlc_changes(LOCAL))
cur_point = self.config[REMOTE].current_per_commitment_point cur_point = self.config[REMOTE].current_per_commitment_point
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True) derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
if cur_point != derived_point: if cur_point != derived_point:
self.log = old_logs
raise Exception('revoked secret not for current point') raise Exception('revoked secret not for current point')
# FIXME not sure this is correct... but it seems to work # FIXME not sure this is correct... but it seems to work
@ -505,51 +454,36 @@ class Channel(PrintError):
if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]: if self.constraints.is_initiator and self.pending_fee[FUNDEE_ACKED]:
self.pending_fee[FUNDER_SIGNED] = True self.pending_fee[FUNDER_SIGNED] = True
def mark_settled(subject): received = self.hm.received_in_ctn(self.config[REMOTE].ctn + 1)
""" sent = self.hm.sent_in_ctn(self.config[REMOTE].ctn + 1)
find pending settlements for subject (LOCAL or REMOTE) and mark them settled, return value of settled htlcs for htlc in received:
""" self.payment_completed(self, RECEIVED, htlc, None)
old_amount = htlcsum(self.htlcs(subject, False)) for htlc in sent:
preimage = self.preimages.pop(htlc.htlc_id)
for htlc_id in self.log[subject].settles: self.payment_completed(self, SENT, htlc, preimage)
adds = self.log[subject].adds received_this_batch = htlcsum(received)
htlc = adds.pop(htlc_id) sent_this_batch = htlcsum(sent)
self.settled[subject].append(htlc.amount_msat)
if subject == LOCAL:
preimage = self.preimages.pop(htlc_id)
else:
preimage = None
self.payment_completed(self, subject, htlc, preimage)
self.log[subject].settles.clear()
return old_amount - htlcsum(self.htlcs(subject, False))
sent_this_batch = mark_settled(LOCAL)
received_this_batch = mark_settled(REMOTE)
next_point = self.config[REMOTE].next_per_commitment_point next_point = self.config[REMOTE].next_per_commitment_point
print("RECEIVED", received_this_batch) self.hm.recv_rev()
print("SENT", sent_this_batch)
self.config[REMOTE]=self.config[REMOTE]._replace( self.config[REMOTE]=self.config[REMOTE]._replace(
ctn=self.config[REMOTE].ctn + 1, ctn=self.config[REMOTE].ctn + 1,
current_per_commitment_point=next_point, current_per_commitment_point=next_point,
next_per_commitment_point=revocation.next_per_commitment_point, next_per_commitment_point=revocation.next_per_commitment_point,
amount_msat=self.config[REMOTE].amount_msat + (sent_this_batch - received_this_batch)
)
self.config[LOCAL]=self.config[LOCAL]._replace(
amount_msat = self.config[LOCAL].amount_msat + (received_this_batch - sent_this_batch)
) )
if self.pending_fee is not None: if self.pending_fee is not None:
if self.constraints.is_initiator: if self.constraints.is_initiator:
self.pending_fee[FUNDEE_ACKED] = True self.pending_fee[FUNDEE_ACKED] = True
self.set_remote_commitment(self.pending_commitment(REMOTE)) self.set_remote_commitment()
self.remote_commitment_to_be_revoked = prev_remote_commitment self.remote_commitment_to_be_revoked = prev_remote_commitment
return received_this_batch, sent_this_batch return received_this_batch, sent_this_batch
def balance(self, subject): def balance(self, subject, ctn=None):
""" """
This balance in mSAT is not including reserve and fees. This balance in mSAT is not including reserve and fees.
So a node cannot actually use it's whole balance. So a node cannot actually use it's whole balance.
@ -560,12 +494,15 @@ class Channel(PrintError):
commited to later when the respective commitment commited to later when the respective commitment
transaction as been revoked. transaction as been revoked.
""" """
assert type(subject) is HTLCOwner
initial = self.config[subject].initial_msat initial = self.config[subject].initial_msat
initial -= sum(self.settled[subject]) for direction, htlc in self.hm.settled_htlcs(subject, ctn):
initial += sum(self.settled[-subject]) if direction == SENT:
initial -= htlc.amount_msat
else:
initial += htlc.amount_msat
assert initial == self.config[subject].amount_msat
return initial return initial
def balance_minus_outgoing_htlcs(self, subject): def balance_minus_outgoing_htlcs(self, subject):
@ -573,48 +510,46 @@ class Channel(PrintError):
This balance in mSAT, which includes the value of This balance in mSAT, which includes the value of
pending outgoing HTLCs, is used in the UI. pending outgoing HTLCs, is used in the UI.
""" """
return self.balance(subject)\ assert type(subject) is HTLCOwner
- htlcsum(self.log[subject].adds.values()) ctn = self.hm.log[subject]['ctn'] + 1
return self.balance(subject, ctn)\
- htlcsum(self.hm.htlcs_by_direction(subject, SENT, ctn))
def available_to_spend(self, subject): def available_to_spend(self, subject):
""" """
This balance in mSAT, while technically correct, can This balance in mSAT, while technically correct, can
not be used in the UI cause it fluctuates (commit fee) not be used in the UI cause it fluctuates (commit fee)
""" """
assert type(subject) is HTLCOwner
return self.balance_minus_outgoing_htlcs(subject)\ return self.balance_minus_outgoing_htlcs(subject)\
- htlcsum(self.log[subject].adds.values())\
- self.config[-subject].reserve_sat * 1000\ - self.config[-subject].reserve_sat * 1000\
- calc_onchain_fees( - calc_onchain_fees(
# TODO should we include a potential new htlc, when we are called from receive_htlc? # TODO should we include a potential new htlc, when we are called from receive_htlc?
len(list(self.included_htlcs(subject, LOCAL)) + list(self.included_htlcs(subject, REMOTE))), len(self.included_htlcs(subject, SENT) + self.included_htlcs(subject, RECEIVED)),
self.pending_feerate(subject), self.pending_feerate(subject),
True, # for_us
self.constraints.is_initiator, self.constraints.is_initiator,
)[subject] )[subject]
def amounts(self): def included_htlcs(self, subject, direction, ctn=None):
remote_settled= htlcsum(self.htlcs(REMOTE, False))
local_settled= htlcsum(self.htlcs(LOCAL, False))
unsettled_local = htlcsum(self.htlcs(LOCAL, True))
unsettled_remote = htlcsum(self.htlcs(REMOTE, True))
remote_msat = self.config[REMOTE].amount_msat -\
unsettled_remote + local_settled - remote_settled
local_msat = self.config[LOCAL].amount_msat -\
unsettled_local + remote_settled - local_settled
return remote_msat, local_msat
def included_htlcs(self, subject, htlc_initiator, only_pending=True):
""" """
return filter of non-dust htlcs for subjects commitment transaction, initiated by given party return filter of non-dust htlcs for subjects commitment transaction, initiated by given party
""" """
assert type(subject) is HTLCOwner
assert type(direction) is Direction
if ctn is None:
ctn = self.config[subject].ctn
feerate = self.pending_feerate(subject) feerate = self.pending_feerate(subject)
conf = self.config[subject] conf = self.config[subject]
weight = HTLC_SUCCESS_WEIGHT if subject != htlc_initiator else HTLC_TIMEOUT_WEIGHT if (subject, direction) in [(REMOTE, RECEIVED), (LOCAL, SENT)]:
htlcs = self.htlcs(htlc_initiator, only_pending=only_pending) weight = HTLC_SUCCESS_WEIGHT
else:
weight = HTLC_TIMEOUT_WEIGHT
htlcs = self.hm.htlcs_by_direction(subject, direction, ctn=ctn)
fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000) fee_for_htlc = lambda htlc: htlc.amount_msat // 1000 - (weight * feerate // 1000)
return filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs) return list(filter(lambda htlc: fee_for_htlc(htlc) >= conf.dust_limit_sat, htlcs))
def pending_feerate(self, subject): def pending_feerate(self, subject):
assert type(subject) is HTLCOwner
candidate = self.constraints.feerate candidate = self.constraints.feerate
if self.pending_fee is not None: if self.pending_fee is not None:
x = self.pending_fee.pending_feerate(subject) x = self.pending_fee.pending_feerate(subject)
@ -623,81 +558,53 @@ class Channel(PrintError):
return candidate return candidate
def pending_commitment(self, subject): def pending_commitment(self, subject):
assert type(subject) is HTLCOwner
this_point = self.config[REMOTE].next_per_commitment_point if subject == REMOTE else self.points()[1] this_point = self.config[REMOTE].next_per_commitment_point if subject == REMOTE else self.points()[1]
return self.make_commitment(subject, this_point) ctn = self.config[subject].ctn + 1
feerate = self.pending_feerate(subject)
return self.make_commitment(subject, this_point, ctn, feerate, True)
def current_commitment(self, subject): def current_commitment(self, subject):
old_local_state = self.config[subject] assert type(subject) is HTLCOwner
self.config[subject]=self.config[subject]._replace(ctn=self.config[subject].ctn - 1) this_point = self.config[REMOTE].current_per_commitment_point if subject == REMOTE else self.points()[3]
r = self.pending_commitment(subject) ctn = self.config[subject].ctn
self.config[subject] = old_local_state feerate = self.constraints.feerate
return r return self.make_commitment(subject, this_point, ctn, feerate, False)
def total_msat(self, sub): def total_msat(self, direction):
return sum(self.settled[sub]) assert type(direction) is Direction
sub = LOCAL if direction == SENT else REMOTE
def htlcs(self, subject, only_pending): return htlcsum(self.hm.settled_htlcs_by(sub, self.config[sub].ctn))
"""
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
sets returned with True and False are disjunct
only_pending true:
skipped if settled or failed
<=>
included if not settled and not failed
only_pending false:
skipped if not (settled or failed)
<=>
included if not not (settled or failed)
included if settled or failed
"""
update_log = self.log[subject]
res = []
for htlc in update_log.adds.values():
locked_in = htlc.htlc_id in update_log.locked_in
settled = htlc.htlc_id in update_log.settles
failed = htlc.htlc_id in update_log.fails
if not locked_in:
continue
if only_pending == (settled or failed):
continue
res.append(htlc)
return res
def settle_htlc(self, preimage, htlc_id): def settle_htlc(self, preimage, htlc_id):
""" """
SettleHTLC attempts to settle an existing outstanding received HTLC. SettleHTLC attempts to settle an existing outstanding received HTLC.
""" """
self.print_error("settle_htlc") self.print_error("settle_htlc")
log = self.log[REMOTE] log = self.hm.log[REMOTE]
htlc = log.adds[htlc_id] htlc = log['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
assert htlc_id not in log.settles assert htlc_id not in log['settles']
log.settles.add(htlc_id) self.hm.send_settle(htlc_id)
# not saving preimage because it's already saved in LNWorker.invoices # not saving preimage because it's already saved in LNWorker.invoices
def receive_htlc_settle(self, preimage, htlc_id): def receive_htlc_settle(self, preimage, htlc_id):
self.print_error("receive_htlc_settle") self.print_error("receive_htlc_settle")
log = self.log[LOCAL] log = self.hm.log[LOCAL]
htlc = log.adds[htlc_id] htlc = log['adds'][htlc_id]
assert htlc.payment_hash == sha256(preimage) assert htlc.payment_hash == sha256(preimage)
assert htlc_id not in log.settles assert htlc_id not in log['settles']
self.hm.recv_settle(htlc_id)
self.preimages[htlc_id] = preimage self.preimages[htlc_id] = preimage
log.settles.add(htlc_id)
# we don't save the preimage because we don't need to forward it anyway # we don't save the preimage because we don't need to forward it anyway
def fail_htlc(self, htlc_id): def fail_htlc(self, htlc_id):
self.print_error("fail_htlc") self.print_error("fail_htlc")
log = self.log[REMOTE] self.hm.send_fail(htlc_id)
assert htlc_id not in log.fails
log.fails.add(htlc_id)
def receive_fail_htlc(self, htlc_id): def receive_fail_htlc(self, htlc_id):
self.print_error("receive_fail_htlc") self.print_error("receive_fail_htlc")
log = self.log[LOCAL] self.hm.recv_fail(htlc_id)
assert htlc_id not in log.fails
log.fails.add(htlc_id)
@property @property
def current_height(self): def current_height(self):
@ -713,29 +620,7 @@ class Channel(PrintError):
raise Exception("a fee update is already in progress") raise Exception("a fee update is already in progress")
self.pending_fee = FeeUpdate(self, rate=feerate) self.pending_fee = FeeUpdate(self, rate=feerate)
def remove_uncommitted_htlcs_from_log(self, subject):
"""
returns
- the htlcs with uncommited (not locked in) htlcs removed
- a list of htlc_ids that were removed
"""
removed = []
htlcs = []
log = self.log[subject]
for i in log.adds.values():
locked_in = i.htlc_id in log.locked_in
if locked_in:
htlcs.append(i._asdict())
else:
removed.append(i.htlc_id)
return htlcs, removed
def to_save(self): def to_save(self):
# need to forget about uncommited htlcs
# since we must assume they don't know about it,
# if it was not acked
remote_filtered, remote_removed = self.remove_uncommitted_htlcs_from_log(REMOTE)
local_filtered, local_removed = self.remove_uncommitted_htlcs_from_log(LOCAL)
to_save = { to_save = {
"local_config": self.config[LOCAL], "local_config": self.config[LOCAL],
"remote_config": self.config[REMOTE], "remote_config": self.config[REMOTE],
@ -745,24 +630,10 @@ class Channel(PrintError):
"funding_outpoint": self.funding_outpoint, "funding_outpoint": self.funding_outpoint,
"node_id": self.node_id, "node_id": self.node_id,
"remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked), "remote_commitment_to_be_revoked": str(self.remote_commitment_to_be_revoked),
"remote_log": remote_filtered, "log": self.hm.to_save(),
"local_log": local_filtered,
"onion_keys": str_bytes_dict_to_save(self.onion_keys), "onion_keys": str_bytes_dict_to_save(self.onion_keys),
"settled_local": self.settled[LOCAL],
"settled_remote": self.settled[REMOTE],
"force_closed": self.get_state() == 'FORCE_CLOSING', "force_closed": self.get_state() == 'FORCE_CLOSING',
} }
# htlcs number must be monotonically increasing,
# so we have to decrease the counter
if len(remote_removed) != 0:
assert min(remote_removed) < to_save['remote_config'].next_htlc_id
to_save['remote_config'] = to_save['remote_config']._replace(next_htlc_id = min(remote_removed))
if len(local_removed) != 0:
assert min(local_removed) < to_save['local_config'].next_htlc_id
to_save['local_config'] = to_save['local_config']._replace(next_htlc_id = min(local_removed))
return to_save return to_save
def serialize(self): def serialize(self):
@ -792,33 +663,49 @@ class Channel(PrintError):
def __str__(self): def __str__(self):
return str(self.serialize()) return str(self.serialize())
def make_commitment(self, subject, this_point) -> Transaction: def make_commitment(self, subject, this_point, ctn, feerate, pending) -> Transaction:
remote_msat, local_msat = self.amounts() #if subject == REMOTE and not pending:
assert local_msat >= 0, local_msat # ctn -= 1
assert remote_msat >= 0, remote_msat assert type(subject) is HTLCOwner
other = REMOTE if LOCAL == subject else LOCAL
remote_msat, local_msat = self.balance(other, ctn), self.balance(subject, ctn)
received_htlcs = self.hm.htlcs_by_direction(subject, SENT if subject == LOCAL else RECEIVED, ctn)
sent_htlcs = self.hm.htlcs_by_direction(subject, RECEIVED if subject == LOCAL else SENT, ctn)
if subject != LOCAL:
remote_msat -= htlcsum(received_htlcs)
local_msat -= htlcsum(sent_htlcs)
else:
remote_msat -= htlcsum(sent_htlcs)
local_msat -= htlcsum(received_htlcs)
assert remote_msat >= 0
assert local_msat >= 0
# same htlcs as before, but now without dust.
received_htlcs = self.included_htlcs(subject, SENT if subject == LOCAL else RECEIVED, ctn)
sent_htlcs = self.included_htlcs(subject, RECEIVED if subject == LOCAL else SENT, ctn)
this_config = self.config[subject] this_config = self.config[subject]
other_config = self.config[-subject] other_config = self.config[-subject]
other_htlc_pubkey = derive_pubkey(other_config.htlc_basepoint.pubkey, this_point) other_htlc_pubkey = derive_pubkey(other_config.htlc_basepoint.pubkey, this_point)
this_htlc_pubkey = derive_pubkey(this_config.htlc_basepoint.pubkey, this_point) this_htlc_pubkey = derive_pubkey(this_config.htlc_basepoint.pubkey, this_point)
other_revocation_pubkey = derive_blinded_pubkey(other_config.revocation_basepoint.pubkey, this_point) other_revocation_pubkey = derive_blinded_pubkey(other_config.revocation_basepoint.pubkey, this_point)
htlcs = [] # type: List[ScriptHtlc] htlcs = [] # type: List[ScriptHtlc]
def append_htlc(htlc: UpdateAddHtlc, is_received_htlc: bool): for is_received_htlc, htlc_list in zip((subject != LOCAL, subject == LOCAL), (received_htlcs, sent_htlcs)):
htlcs.append(ScriptHtlc(make_htlc_output_witness_script( for htlc in htlc_list:
is_received_htlc=is_received_htlc, htlcs.append(ScriptHtlc(make_htlc_output_witness_script(
remote_revocation_pubkey=other_revocation_pubkey, is_received_htlc=is_received_htlc,
remote_htlc_pubkey=other_htlc_pubkey, remote_revocation_pubkey=other_revocation_pubkey,
local_htlc_pubkey=this_htlc_pubkey, remote_htlc_pubkey=other_htlc_pubkey,
payment_hash=htlc.payment_hash, local_htlc_pubkey=this_htlc_pubkey,
cltv_expiry=htlc.cltv_expiry), htlc)) payment_hash=htlc.payment_hash,
for htlc in self.included_htlcs(subject, -subject): cltv_expiry=htlc.cltv_expiry), htlc))
append_htlc(htlc, is_received_htlc=True) onchain_fees = calc_onchain_fees(
for htlc in self.included_htlcs(subject, subject): len(htlcs),
append_htlc(htlc, is_received_htlc=False) feerate,
if subject != LOCAL: self.constraints.is_initiator == (subject == LOCAL),
remote_msat, local_msat = local_msat, remote_msat )
payment_pubkey = derive_pubkey(other_config.payment_basepoint.pubkey, this_point) payment_pubkey = derive_pubkey(other_config.payment_basepoint.pubkey, this_point)
return make_commitment( return make_commitment(
self.config[subject].ctn + 1, ctn,
this_config.multisig_key.pubkey, this_config.multisig_key.pubkey,
other_config.multisig_key.pubkey, other_config.multisig_key.pubkey,
payment_pubkey, payment_pubkey,
@ -832,12 +719,7 @@ class Channel(PrintError):
local_msat, local_msat,
remote_msat, remote_msat,
this_config.dust_limit_sat, this_config.dust_limit_sat,
calc_onchain_fees( onchain_fees,
len(htlcs),
self.pending_feerate(subject),
subject == LOCAL,
self.constraints.is_initiator,
),
htlcs=htlcs) htlcs=htlcs)
def get_local_index(self): def get_local_index(self):
@ -850,8 +732,8 @@ class Channel(PrintError):
LOCAL: fee_sat * 1000 if self.constraints.is_initiator else 0, LOCAL: fee_sat * 1000 if self.constraints.is_initiator else 0,
REMOTE: fee_sat * 1000 if not self.constraints.is_initiator else 0, REMOTE: fee_sat * 1000 if not self.constraints.is_initiator else 0,
}, },
self.config[LOCAL].amount_msat, self.balance(LOCAL),
self.config[REMOTE].amount_msat, self.balance(REMOTE),
(TYPE_SCRIPT, bh2u(local_script)), (TYPE_SCRIPT, bh2u(local_script)),
(TYPE_SCRIPT, bh2u(remote_script)), (TYPE_SCRIPT, bh2u(remote_script)),
[], self.config[LOCAL].dust_limit_sat) [], self.config[LOCAL].dust_limit_sat)
@ -867,38 +749,39 @@ class Channel(PrintError):
sig = ecc.sig_string_from_der_sig(der_sig[:-1]) sig = ecc.sig_string_from_der_sig(der_sig[:-1])
return sig, closing_tx return sig, closing_tx
def assert_signature_fits(self, tx): def signature_fits(self, tx):
remote_sig = self.config[LOCAL].current_commitment_signature remote_sig = self.config[LOCAL].current_commitment_signature
if remote_sig: # only None in test preimage_hex = tx.serialize_preimage(0)
preimage_hex = tx.serialize_preimage(0) pre_hash = sha256d(bfh(preimage_hex))
pre_hash = sha256d(bfh(preimage_hex)) assert remote_sig
if not ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash): res = ecc.verify_signature(self.config[REMOTE].multisig_key.pubkey, remote_sig, pre_hash)
self.print_error("WARNING: commitment signature inconsistency, cannot force close") return res
def force_close_tx(self): def force_close_tx(self):
tx = self.local_commitment tx = self.local_commitment
assert self.signature_fits(tx)
tx = Transaction(str(tx)) tx = Transaction(str(tx))
tx.deserialize(True) tx.deserialize(True)
self.assert_signature_fits(tx)
tx.sign({bh2u(self.config[LOCAL].multisig_key.pubkey): (self.config[LOCAL].multisig_key.privkey, True)}) tx.sign({bh2u(self.config[LOCAL].multisig_key.pubkey): (self.config[LOCAL].multisig_key.privkey, True)})
remote_sig = self.config[LOCAL].current_commitment_signature remote_sig = self.config[LOCAL].current_commitment_signature
if remote_sig: # only None in test remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01"
remote_sig = ecc.der_sig_from_sig_string(remote_sig) + b"\x01" sigs = tx._inputs[0]["signatures"]
sigs = tx._inputs[0]["signatures"] none_idx = sigs.index(None)
none_idx = sigs.index(None) tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig))
tx.add_signature_to_txin(0, none_idx, bh2u(remote_sig)) assert tx.is_complete()
assert tx.is_complete()
return tx return tx
def included_htlcs_in_their_latest_ctxs(self, htlc_initiator) -> Dict[int, List[UpdateAddHtlc]]: def included_htlcs_in_their_latest_ctxs(self, htlc_initiator) -> Dict[int, List[UpdateAddHtlc]]:
""" A map from commitment number to list of HTLCs in """ A map from commitment number to list of HTLCs in
their latest two commitment transactions. their latest two commitment transactions.
The oldest might have been revoked. """ The oldest might have been revoked. """
old_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator, only_pending=False)) assert type(htlc_initiator) is HTLCOwner
direction = RECEIVED if htlc_initiator == LOCAL else SENT
old_ctn = self.config[REMOTE].ctn
old_htlcs = self.included_htlcs(REMOTE, direction, ctn=old_ctn)
old_logs = dict(self.lock_in_htlc_changes(LOCAL)) new_ctn = self.config[REMOTE].ctn+1
new_htlcs = list(self.included_htlcs(REMOTE, htlc_initiator)) new_htlcs = self.included_htlcs(REMOTE, direction, ctn=new_ctn)
self.log = old_logs
return {self.config[REMOTE].ctn: old_htlcs, return {old_ctn: old_htlcs,
self.config[REMOTE].ctn+1: new_htlcs, } new_ctn: new_htlcs, }

159
electrum/lnhtlc.py Normal file
View file

@ -0,0 +1,159 @@
from copy import deepcopy
from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction
from .util import bh2u
class HTLCManager:
def __init__(self, log=None):
self.expect_sig = {SENT: False, RECEIVED: False}
if log is None:
initial = {'ctn': 0, 'adds': {}, 'locked_in': {}, 'settles': {}, 'fails': {}}
log = {LOCAL: deepcopy(initial), REMOTE: deepcopy(initial)}
else:
assert type(log) is dict
log = {HTLCOwner(int(x)): y for x, y in deepcopy(log).items()}
for sub in (LOCAL, REMOTE):
log[sub]['adds'] = {int(x): UpdateAddHtlc(*y) for x, y in log[sub]['adds'].items()}
coerceHtlcOwner2IntMap = lambda x: {HTLCOwner(int(y)): z for y, z in x.items()}
log[sub]['locked_in'] = {int(x): coerceHtlcOwner2IntMap(y) for x, y in log[sub]['locked_in'].items()}
log[sub]['settles'] = {int(x): y for x, y in log[sub]['settles'].items()}
log[sub]['fails'] = {int(x): y for x, y in log[sub]['fails'].items()}
self.log = log
def to_save(self):
x = deepcopy(self.log)
for sub in (LOCAL, REMOTE):
d = {}
for htlc_id, htlc in x[sub]['adds'].items():
d[htlc_id] = (htlc[0], bh2u(htlc[1])) + htlc[2:]
x[sub]['adds'] = d
return x
def send_htlc(self, htlc):
htlc_id = htlc.htlc_id
adds = self.log[LOCAL]['adds']
assert type(adds) is not str
adds[htlc_id] = htlc
self.log[LOCAL]['locked_in'][htlc_id] = {LOCAL: None, REMOTE: self.log[REMOTE]['ctn']+1}
self.expect_sig[SENT] = True
return htlc
def recv_htlc(self, htlc):
htlc_id = htlc.htlc_id
self.log[REMOTE]['htlc_id'] = htlc_id
self.log[REMOTE]['adds'][htlc_id] = htlc
l = self.log[REMOTE]['locked_in'][htlc_id] = {LOCAL: self.log[LOCAL]['ctn']+1, REMOTE: None}
self.expect_sig[RECEIVED] = True
def send_ctx(self):
next_ctn = self.log[REMOTE]['ctn'] + 1
for locked_in in self.log[REMOTE]['locked_in'].values():
if locked_in[REMOTE] is None:
locked_in[REMOTE] = next_ctn
self.expect_sig[SENT] = False
#return Sig(self.pending_htlcs(REMOTE), next_ctn)
def recv_ctx(self):
next_ctn = self.log[LOCAL]['ctn'] + 1
for locked_in in self.log[LOCAL]['locked_in'].values():
if locked_in[LOCAL] is None:
locked_in[LOCAL] = next_ctn
self.expect_sig[SENT] = False
def send_rev(self):
self.log[LOCAL]['ctn'] += 1
def recv_rev(self):
self.log[REMOTE]['ctn'] += 1
did_set_htlc_height = False
for htlc_id, ctnheights in self.log[LOCAL]['locked_in'].items():
if ctnheights[LOCAL] is None:
did_set_htlc_height = True
assert ctnheights[REMOTE] == self.log[REMOTE]['ctn']
ctnheights[LOCAL] = ctnheights[REMOTE]
return did_set_htlc_height
def htlcs_by_direction(self, subject, direction, ctn=None):
"""
direction is relative to subject!
"""
assert type(subject) is HTLCOwner
assert type(direction) is Direction
if ctn is None:
ctn = self.log[subject]['ctn']
l = []
if direction == SENT and subject == LOCAL:
party = LOCAL
elif direction == RECEIVED and subject == REMOTE:
party = LOCAL
else:
party = REMOTE
for htlc_id, ctnheights in self.log[party]['locked_in'].items():
htlc_height = ctnheights[subject]
if htlc_height is None:
include = not self.expect_sig[RECEIVED if party == LOCAL else SENT] and ctnheights[-subject] <= ctn
else:
include = htlc_height <= ctn
if include:
settles = self.log[party]['settles']
if htlc_id not in settles or settles[htlc_id] > ctn:
fails = self.log[party]['fails']
if htlc_id not in fails or fails[htlc_id] > ctn:
l.append(self.log[party]['adds'][htlc_id])
return l
def htlcs(self, subject, ctn=None):
assert type(subject) is HTLCOwner
if ctn is None:
ctn = self.log[subject]['ctn']
l = []
l += [(SENT, x) for x in self.htlcs_by_direction(subject, SENT, ctn)]
l += [(RECEIVED, x) for x in self.htlcs_by_direction(subject, RECEIVED, ctn)]
return l
def current_htlcs(self, subject):
assert type(subject) is HTLCOwner
ctn = self.log[subject]['ctn']
return self.htlcs(subject, ctn)
def pending_htlcs(self, subject):
assert type(subject) is HTLCOwner
ctn = self.log[subject]['ctn'] + 1
return self.htlcs(subject, ctn)
def send_settle(self, htlc_id):
self.log[REMOTE]['settles'][htlc_id] = self.log[REMOTE]['ctn'] + 1
def recv_settle(self, htlc_id):
self.log[LOCAL]['settles'][htlc_id] = self.log[LOCAL]['ctn'] + 1
def settled_htlcs_by(self, subject, ctn=None):
assert type(subject) is HTLCOwner
if ctn is None:
ctn = self.log[subject]['ctn']
return [self.log[subject]['adds'][htlc_id] for htlc_id, height in self.log[subject]['settles'].items() if height <= ctn]
def settled_htlcs(self, subject, ctn=None):
assert type(subject) is HTLCOwner
if ctn is None:
ctn = self.log[subject]['ctn']
sent = [(SENT, x) for x in self.settled_htlcs_by(subject, ctn)]
other = subject.inverted()
received = [(RECEIVED, x) for x in self.settled_htlcs_by(other, ctn)]
return sent + received
def received_in_ctn(self, ctn):
return [self.log[REMOTE]['adds'][htlc_id] for htlc_id, height in self.log[REMOTE]['settles'].items() if height == ctn]
def sent_in_ctn(self, ctn):
return [self.log[LOCAL]['adds'][htlc_id] for htlc_id, height in self.log[LOCAL]['settles'].items() if height == ctn]
def send_fail(self, htlc_id):
self.log[REMOTE]['fails'][htlc_id] = self.log[REMOTE]['ctn'] + 1
def recv_fail(self, htlc_id):
self.log[LOCAL]['fails'][htlc_id] = self.log[LOCAL]['ctn'] + 1

View file

@ -9,15 +9,15 @@ from .bitcoin import TYPE_ADDRESS, redeem_script_to_address, dust_threshold
from . import ecc from . import ecc
from .lnutil import (make_commitment_output_to_remote_address, make_commitment_output_to_local_witness_script, from .lnutil import (make_commitment_output_to_remote_address, make_commitment_output_to_local_witness_script,
derive_privkey, derive_pubkey, derive_blinded_pubkey, derive_blinded_privkey, derive_privkey, derive_pubkey, derive_blinded_pubkey, derive_blinded_privkey,
make_htlc_tx_witness, make_htlc_tx_with_open_channel, make_htlc_tx_witness, make_htlc_tx_with_open_channel, UpdateAddHtlc,
LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash, LOCAL, REMOTE, make_htlc_output_witness_script, UnknownPaymentHash,
get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed, get_ordered_channel_configs, privkey_to_pubkey, get_per_commitment_secret_from_seed,
RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret) RevocationStore, extract_ctn_from_tx_and_chan, UnableToDeriveSecret, SENT, RECEIVED)
from .transaction import Transaction, TxOutput, construct_witness from .transaction import Transaction, TxOutput, construct_witness
from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE from .simple_config import SimpleConfig, FEERATE_FALLBACK_STATIC_FEE
if TYPE_CHECKING: if TYPE_CHECKING:
from .lnchan import Channel, UpdateAddHtlc from .lnchan import Channel
def maybe_create_sweeptx_for_their_ctx_to_remote(ctx: Transaction, sweep_address: str, def maybe_create_sweeptx_for_their_ctx_to_remote(ctx: Transaction, sweep_address: str,
@ -106,7 +106,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction
ctn = extract_ctn_from_tx_and_chan(ctx, chan) ctn = extract_ctn_from_tx_and_chan(ctx, chan)
assert ctn == chan.config[REMOTE].ctn assert ctn == chan.config[REMOTE].ctn
# received HTLCs, in their ctx # received HTLCs, in their ctx
received_htlcs = chan.included_htlcs(REMOTE, LOCAL, False) received_htlcs = chan.included_htlcs(REMOTE, RECEIVED, ctn)
for htlc in received_htlcs: for htlc in received_htlcs:
direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=True) direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=True)
if direct_sweep_tx: if direct_sweep_tx:
@ -114,7 +114,7 @@ def create_sweeptxs_for_their_just_revoked_ctx(chan: 'Channel', ctx: Transaction
if secondstage_sweep_tx: if secondstage_sweep_tx:
txs[htlc_tx.txid()] = secondstage_sweep_tx txs[htlc_tx.txid()] = secondstage_sweep_tx
# offered HTLCs, in their ctx # offered HTLCs, in their ctx
offered_htlcs = chan.included_htlcs(REMOTE, REMOTE, False) offered_htlcs = chan.included_htlcs(REMOTE, SENT, ctn)
for htlc in offered_htlcs: for htlc in offered_htlcs:
direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=False) direct_sweep_tx, secondstage_sweep_tx, htlc_tx = create_sweeptx_for_htlc(htlc, is_received_htlc=False)
if direct_sweep_tx: if direct_sweep_tx:
@ -181,16 +181,14 @@ def create_sweeptxs_for_our_latest_ctx(chan: 'Channel', ctx: Transaction,
is_revocation=False) is_revocation=False)
return htlc_tx, to_wallet_tx return htlc_tx, to_wallet_tx
# offered HTLCs, in our ctx --> "timeout" # offered HTLCs, in our ctx --> "timeout"
# TODO consider carefully if "included_htlcs" is what we need here # received HTLCs, in our ctx --> "success"
offered_htlcs = list(chan.included_htlcs(LOCAL, LOCAL)) # type: List[UpdateAddHtlc] offered_htlcs = chan.included_htlcs(LOCAL, SENT, ctn) # type: List[UpdateAddHtlc]
received_htlcs = chan.included_htlcs(LOCAL, RECEIVED, ctn) # type: List[UpdateAddHtlc]
for htlc in offered_htlcs: for htlc in offered_htlcs:
htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=False) htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=False)
if htlc_tx and to_wallet_tx: if htlc_tx and to_wallet_tx:
txs[to_wallet_tx.prevout(0)] = to_wallet_tx txs[to_wallet_tx.prevout(0)] = to_wallet_tx
txs[htlc_tx.prevout(0)] = htlc_tx txs[htlc_tx.prevout(0)] = htlc_tx
# received HTLCs, in our ctx --> "success"
# TODO consider carefully if "included_htlcs" is what we need here
received_htlcs = list(chan.included_htlcs(LOCAL, REMOTE)) # type: List[UpdateAddHtlc]
for htlc in received_htlcs: for htlc in received_htlcs:
htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=True) htlc_tx, to_wallet_tx = create_txns_for_htlc(htlc, is_received_htlc=True)
if htlc_tx and to_wallet_tx: if htlc_tx and to_wallet_tx:
@ -332,7 +330,7 @@ def create_htlctx_that_spends_from_our_ctx(chan: 'Channel', our_pcp: bytes,
htlc=htlc, htlc=htlc,
name=f'our_ctx_htlc_tx_{bh2u(htlc.payment_hash)}', name=f'our_ctx_htlc_tx_{bh2u(htlc.payment_hash)}',
cltv_expiry=0 if is_received_htlc else htlc.cltv_expiry) cltv_expiry=0 if is_received_htlc else htlc.cltv_expiry)
remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc) remote_htlc_sig = chan.get_remote_htlc_sig_for_htlc(htlc, we_receive=is_received_htlc, ctx=ctx)
local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey)) local_htlc_sig = bfh(htlc_tx.sign_txin(0, local_htlc_privkey))
txin = htlc_tx.inputs()[0] txin = htlc_tx.inputs()[0]
witness_program = bfh(Transaction.get_preimage_script(txin)) witness_program = bfh(Transaction.get_preimage_script(txin))

View file

@ -21,7 +21,7 @@ from .lnaddr import lndecode
from .keystore import BIP32_KeyStore from .keystore import BIP32_KeyStore
if TYPE_CHECKING: if TYPE_CHECKING:
from .lnchan import Channel, UpdateAddHtlc from .lnchan import Channel
HTLC_TIMEOUT_WEIGHT = 663 HTLC_TIMEOUT_WEIGHT = 663
@ -35,7 +35,6 @@ OnlyPubkeyKeypair = namedtuple("OnlyPubkeyKeypair", ["pubkey"])
class LocalConfig(NamedTuple): class LocalConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!) # shared channel config fields (DUPLICATED code!!)
ctn: int ctn: int
amount_msat: int
next_htlc_id: int next_htlc_id: int
payment_basepoint: 'Keypair' payment_basepoint: 'Keypair'
multisig_key: 'Keypair' multisig_key: 'Keypair'
@ -54,12 +53,12 @@ class LocalConfig(NamedTuple):
was_announced: bool was_announced: bool
current_commitment_signature: Optional[bytes] current_commitment_signature: Optional[bytes]
current_htlc_signatures: List[bytes] current_htlc_signatures: List[bytes]
got_sig_for_next: bool
class RemoteConfig(NamedTuple): class RemoteConfig(NamedTuple):
# shared channel config fields (DUPLICATED code!!) # shared channel config fields (DUPLICATED code!!)
ctn: int ctn: int
amount_msat: int
next_htlc_id: int next_htlc_id: int
payment_basepoint: 'Keypair' payment_basepoint: 'Keypair'
multisig_key: 'Keypair' multisig_key: 'Keypair'
@ -364,7 +363,7 @@ def make_htlc_tx_with_open_channel(chan: 'Channel', pcp: bytes, for_us: bool,
# FIXME handle htlc_address collision # FIXME handle htlc_address collision
# also: https://github.com/lightningnetwork/lightning-rfc/issues/448 # also: https://github.com/lightningnetwork/lightning-rfc/issues/448
prevout_idx = commit.get_output_idx_from_address(htlc_address) prevout_idx = commit.get_output_idx_from_address(htlc_address)
assert prevout_idx is not None assert prevout_idx is not None, (htlc_address, commit.outputs(), extract_ctn_from_tx_and_chan(commit, chan))
htlc_tx_inputs = make_htlc_tx_inputs( htlc_tx_inputs = make_htlc_tx_inputs(
commit.txid(), prevout_idx, commit.txid(), prevout_idx,
amount_msat=amount_msat, amount_msat=amount_msat,
@ -395,11 +394,16 @@ class HTLCOwner(IntFlag):
LOCAL = 1 LOCAL = 1
REMOTE = -LOCAL REMOTE = -LOCAL
SENT = LOCAL def inverted(self):
RECEIVED = REMOTE return HTLCOwner(-self)
class Direction(IntFlag):
SENT = 3
RECEIVED = 4
SENT = Direction.SENT
RECEIVED = Direction.RECEIVED
SENT = HTLCOwner.SENT
RECEIVED = HTLCOwner.RECEIVED
LOCAL = HTLCOwner.LOCAL LOCAL = HTLCOwner.LOCAL
REMOTE = HTLCOwner.REMOTE REMOTE = HTLCOwner.REMOTE
@ -420,8 +424,7 @@ def make_commitment_outputs(fees_per_participant: Mapping[HTLCOwner, int], local
c_outputs_filtered = list(filter(lambda x: x.value >= dust_limit_sat, non_htlc_outputs + htlc_outputs)) c_outputs_filtered = list(filter(lambda x: x.value >= dust_limit_sat, non_htlc_outputs + htlc_outputs))
return htlc_outputs, c_outputs_filtered return htlc_outputs, c_outputs_filtered
def calc_onchain_fees(num_htlcs, feerate, for_us, we_are_initiator): def calc_onchain_fees(num_htlcs, feerate, we_pay_fee):
we_pay_fee = for_us == we_are_initiator
overall_weight = 500 + 172 * num_htlcs + 224 overall_weight = 500 + 172 * num_htlcs + 224
fee = feerate * overall_weight fee = feerate * overall_weight
fee = fee // 1000 * 1000 fee = fee // 1000 * 1000
@ -451,7 +454,7 @@ def make_commitment(ctn, local_funding_pubkey, remote_funding_pubkey,
htlc_outputs, c_outputs_filtered = make_commitment_outputs(fees_per_participant, local_amount, remote_amount, htlc_outputs, c_outputs_filtered = make_commitment_outputs(fees_per_participant, local_amount, remote_amount,
(bitcoin.TYPE_ADDRESS, local_address), (bitcoin.TYPE_ADDRESS, remote_address), htlcs, dust_limit_sat) (bitcoin.TYPE_ADDRESS, local_address), (bitcoin.TYPE_ADDRESS, remote_address), htlcs, dust_limit_sat)
assert sum(x.value for x in c_outputs_filtered) <= funding_sat assert sum(x.value for x in c_outputs_filtered) <= funding_sat, (c_outputs_filtered, funding_sat)
# create commitment tx # create commitment tx
tx = Transaction.from_io(c_inputs, c_outputs_filtered, locktime=locktime, version=2) tx = Transaction.from_io(c_inputs, c_outputs_filtered, locktime=locktime, version=2)
@ -649,3 +652,20 @@ def format_short_channel_id(short_channel_id: Optional[bytes]):
return str(int.from_bytes(short_channel_id[:3], 'big')) \ return str(int.from_bytes(short_channel_id[:3], 'big')) \
+ 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \ + 'x' + str(int.from_bytes(short_channel_id[3:6], 'big')) \
+ 'x' + str(int.from_bytes(short_channel_id[6:], 'big')) + 'x' + str(int.from_bytes(short_channel_id[6:], 'big'))
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'htlc_id'])):
"""
This whole class body is so that if you pass a hex-string as payment_hash,
it is decoded to bytes. Bytes can't be saved to disk, so we save hex-strings.
"""
__slots__ = ()
def __new__(cls, *args, **kwargs):
if len(args) > 0:
args = list(args)
if type(args[1]) is str:
args[1] = bfh(args[1])
return super().__new__(cls, *args)
if type(kwargs['payment_hash']) is str:
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
return super().__new__(cls, **kwargs)

View file

@ -29,13 +29,14 @@ from .lntransport import LNResponderTransport
from .lnbase import Peer from .lnbase import Peer
from .lnaddr import lnencode, LnAddr, lndecode from .lnaddr import lnencode, LnAddr, lndecode
from .ecc import der_sig_from_sig_string from .ecc import der_sig_from_sig_string
from .lnchan import Channel, ChannelJsonEncoder, UpdateAddHtlc from .lnchan import Channel, ChannelJsonEncoder
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid, get_compressed_pubkey_from_bech32, extract_nodeid,
PaymentFailure, split_host_port, ConnStringFormatError, PaymentFailure, split_host_port, ConnStringFormatError,
generate_keypair, LnKeyFamily, LOCAL, REMOTE, generate_keypair, LnKeyFamily, LOCAL, REMOTE,
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner) NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner,
UpdateAddHtlc, Direction)
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, is_route_sane_to_use from .lnrouter import RouteEdge, is_route_sane_to_use
from .address_synchronizer import TX_HEIGHT_LOCAL from .address_synchronizer import TX_HEIGHT_LOCAL
@ -66,7 +67,7 @@ class LNWorker(PrintError):
def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'): def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
self.wallet = wallet self.wallet = wallet
# invoices we are currently trying to pay (might be pending HTLCs on a commitment transaction) # invoices we are currently trying to pay (might be pending HTLCs on a commitment transaction)
self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], bytes]] self.paying = self.wallet.storage.get('lightning_payments_inflight', {}) # type: Dict[bytes, Tuple[str, Optional[int], str]]
self.sweep_address = wallet.get_receiving_address() self.sweep_address = wallet.get_receiving_address()
self.network = network self.network = network
self.channel_db = self.network.channel_db self.channel_db = self.network.channel_db
@ -75,12 +76,15 @@ class LNWorker(PrintError):
self.node_keypair = generate_keypair(self.ln_keystore, LnKeyFamily.NODE_KEY, 0) self.node_keypair = generate_keypair(self.ln_keystore, LnKeyFamily.NODE_KEY, 0)
self.config = network.config self.config = network.config
self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer self.peers = {} # type: Dict[bytes, Peer] # pubkey -> Peer
self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice)
self.channels = {} # type: Dict[bytes, Channel] self.channels = {} # type: Dict[bytes, Channel]
for x in wallet.storage.get("channels", []): for x in wallet.storage.get("channels", []):
c = Channel(x, sweep_address=self.sweep_address, payment_completed=self.payment_completed) c = Channel(x, sweep_address=self.sweep_address, payment_completed=self.payment_completed)
self.channels[c.channel_id] = c
c.lnwatcher = network.lnwatcher c.lnwatcher = network.lnwatcher
self.invoices = wallet.storage.get('lightning_invoices', {}) # type: Dict[str, Tuple[str,str]] # RHASH -> (preimage, invoice) c.get_preimage_and_invoice = self.get_invoice
self.channels[c.channel_id] = c
c.set_remote_commitment()
c.set_local_commitment(c.current_commitment(LOCAL))
for chan_id, chan in self.channels.items(): for chan_id, chan in self.channels.items():
self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str()) self.network.lnwatcher.watch_channel(chan.get_funding_address(), chan.funding_outpoint.to_str())
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
@ -116,6 +120,7 @@ class LNWorker(PrintError):
self.print_error('saved lightning gossip timestamp') self.print_error('saved lightning gossip timestamp')
def payment_completed(self, chan, direction, htlc, preimage): def payment_completed(self, chan, direction, htlc, preimage):
assert type(direction) is Direction
chan_id = chan.channel_id chan_id = chan.channel_id
if direction == SENT: if direction == SENT:
assert htlc.payment_hash not in self.invoices assert htlc.payment_hash not in self.invoices
@ -166,6 +171,7 @@ class LNWorker(PrintError):
unsettled = [] unsettled = []
inflight = [] inflight = []
for date, direction, htlc, hex_preimage, hex_chan_id in completed: for date, direction, htlc, hex_preimage, hex_chan_id in completed:
direction = Direction(direction)
if chan_id is not None: if chan_id is not None:
if bfh(hex_chan_id) != chan_id: if bfh(hex_chan_id) != chan_id:
continue continue
@ -175,12 +181,12 @@ class LNWorker(PrintError):
else: else:
preimage = bfh(hex_preimage) preimage = bfh(hex_preimage)
# FIXME use fromisoformat when minimum Python is 3.7 # FIXME use fromisoformat when minimum Python is 3.7
settled.append((datetime.fromtimestamp(date, timezone.utc), HTLCOwner(direction), htlcobj, preimage)) settled.append((datetime.fromtimestamp(date, timezone.utc), direction, htlcobj, preimage))
for preimage, pay_req in invoices.values(): for preimage, pay_req in invoices.values():
addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP) addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
unsettled.append((addr, bfh(preimage), pay_req)) unsettled.append((addr, bfh(preimage), pay_req))
for pay_req, amount_sat, this_chan_id in self.paying.values(): for pay_req, amount_sat, this_chan_id in self.paying.values():
if chan_id is not None and this_chan_id != chan_id: if chan_id is not None and bfh(this_chan_id) != chan_id:
continue continue
addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP) addr = lndecode(pay_req, expected_hrp=constants.net.SEGWIT_HRP)
if amount_sat is not None: if amount_sat is not None:
@ -194,7 +200,7 @@ class LNWorker(PrintError):
def find_htlc_for_addr(self, addr, whitelist=None): def find_htlc_for_addr(self, addr, whitelist=None):
channels = [y for x,y in self.channels.items() if x in whitelist or whitelist is None] channels = [y for x,y in self.channels.items() if x in whitelist or whitelist is None]
for chan in channels: for chan in channels:
for htlc in chan.log[LOCAL].adds.values(): for htlc in chan.hm.log[LOCAL]['adds'].values():
if htlc.payment_hash == addr.paymenthash: if htlc.payment_hash == addr.paymenthash:
return htlc return htlc
@ -319,7 +325,7 @@ class LNWorker(PrintError):
self.print_error('they force closed', funding_outpoint) self.print_error('they force closed', funding_outpoint)
encumbered_sweeptxs = chan.remote_sweeptxs encumbered_sweeptxs = chan.remote_sweeptxs
else: else:
self.print_error('not sure who closed', funding_outpoint) self.print_error('not sure who closed', funding_outpoint, txid)
return return
# sweep # sweep
for prevout, spender in spenders.items(): for prevout, spender in spenders.items():
@ -456,7 +462,7 @@ class LNWorker(PrintError):
break break
else: else:
assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id) assert False, 'Found route with short channel ID we don\'t have: ' + repr(route[0].short_channel_id)
self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, chan_id) self.paying[bh2u(addr.paymenthash)] = (invoice, amount_sat, bh2u(chan_id))
self.wallet.storage.put('lightning_payments_inflight', self.paying) self.wallet.storage.put('lightning_payments_inflight', self.paying)
self.wallet.storage.write() self.wallet.storage.write()
return addr, peer, self._pay_to_route(route, addr) return addr, peer, self._pay_to_route(route, addr)
@ -623,8 +629,8 @@ class LNWorker(PrintError):
# we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels # we output the funding_outpoint instead of the channel_id because lnd uses channel_point (funding outpoint) to identify channels
for channel_id, chan in self.channels.items(): for channel_id, chan in self.channels.items():
yield { yield {
'local_htlcs': json.loads(encoder.encode(chan.log[LOCAL ]._asdict())), 'local_htlcs': json.loads(encoder.encode(chan.hm.log[LOCAL ])),
'remote_htlcs': json.loads(encoder.encode(chan.log[REMOTE]._asdict())), 'remote_htlcs': json.loads(encoder.encode(chan.hm.log[REMOTE])),
'channel_id': bh2u(chan.short_channel_id), 'channel_id': bh2u(chan.short_channel_id),
'channel_point': chan.funding_outpoint.to_str(), 'channel_point': chan.funding_outpoint.to_str(),
'state': chan.get_state(), 'state': chan.get_state(),

View file

@ -22,6 +22,7 @@
import unittest import unittest
import os import os
import binascii import binascii
from pprint import pformat
from electrum import bitcoin from electrum import bitcoin
from electrum import lnbase from electrum import lnbase
@ -30,6 +31,7 @@ from electrum import lnutil
from electrum import bip32 as bip32_utils from electrum import bip32 as bip32_utils
from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED from electrum.lnutil import SENT, LOCAL, REMOTE, RECEIVED
from electrum.ecc import sig_string_from_der_sig from electrum.ecc import sig_string_from_der_sig
from electrum.util import set_verbosity
one_bitcoin_in_msat = bitcoin.COIN * 1000 one_bitcoin_in_msat = bitcoin.COIN * 1000
@ -54,9 +56,8 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5, max_htlc_value_in_flight_msat=one_bitcoin_in_msat * 5,
max_accepted_htlcs=5, max_accepted_htlcs=5,
initial_msat=remote_amount, initial_msat=remote_amount,
ctn = 0, ctn = -1,
next_htlc_id = 0, next_htlc_id = 0,
amount_msat=remote_amount,
reserve_sat=0, reserve_sat=0,
next_per_commitment_point=nex, next_per_commitment_point=nex,
@ -76,7 +77,6 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
initial_msat=local_amount, initial_msat=local_amount,
ctn = 0, ctn = 0,
next_htlc_id = 0, next_htlc_id = 0,
amount_msat=local_amount,
reserve_sat=0, reserve_sat=0,
per_commitment_secret_seed=seed, per_commitment_secret_seed=seed,
@ -84,6 +84,7 @@ def create_channel_state(funding_txid, funding_index, funding_sat, local_feerate
was_announced=False, was_announced=False,
current_commitment_signature=None, current_commitment_signature=None,
current_htlc_signatures=None, current_htlc_signatures=None,
got_sig_for_next=False,
), ),
"constraints":lnbase.ChannelConstraints( "constraints":lnbase.ChannelConstraints(
capacity=funding_sat, capacity=funding_sat,
@ -105,7 +106,7 @@ def bip32(sequence):
return k return k
def create_test_channels(feerate=6000, local=None, remote=None): def create_test_channels(feerate=6000, local=None, remote=None):
funding_txid = binascii.hexlify(os.urandom(32)).decode("ascii") funding_txid = binascii.hexlify(b"\x01"*32).decode("ascii")
funding_index = 0 funding_index = 0
funding_sat = ((local + remote) // 1000) if local is not None and remote is not None else (bitcoin.COIN * 10) funding_sat = ((local + remote) // 1000) if local is not None and remote is not None else (bitcoin.COIN * 10)
local_amount = local if local is not None else (funding_sat * 1000 // 2) local_amount = local if local is not None else (funding_sat * 1000 // 2)
@ -117,23 +118,52 @@ def create_test_channels(feerate=6000, local=None, remote=None):
alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys] alice_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in alice_privkeys]
bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys] bob_pubkeys = [lnutil.OnlyPubkeyKeypair(x.pubkey) for x in bob_privkeys]
alice_seed = os.urandom(32) alice_seed = b"\x01" * 32
bob_seed = os.urandom(32) bob_seed = b"\x02" * 32
alice_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big")) alice_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX), "big"))
alice_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big")) bob_first = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
bob_cur = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX), "big"))
bob_next = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
alice, bob = \ alice, bob = \
lnchan.Channel( lnchan.Channel(
create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, bob_cur, bob_next, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \ create_channel_state(funding_txid, funding_index, funding_sat, feerate, True, local_amount, remote_amount, alice_privkeys, bob_pubkeys, alice_seed, None, bob_first, b"\x02"*33, l_dust=200, r_dust=1300, l_csv=5, r_csv=4), name="alice"), \
lnchan.Channel( lnchan.Channel(
create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, alice_cur, alice_next, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob") create_channel_state(funding_txid, funding_index, funding_sat, feerate, False, remote_amount, local_amount, bob_privkeys, alice_pubkeys, bob_seed, None, alice_first, b"\x01"*33, l_dust=1300, r_dust=200, l_csv=4, r_csv=5), name="bob")
alice.set_state('OPEN') alice.set_state('OPEN')
bob.set_state('OPEN') bob.set_state('OPEN')
a_out = alice.current_commitment(LOCAL).outputs()
b_out = bob.pending_commitment(REMOTE).outputs()
assert a_out == b_out, "\n" + pformat((a_out, b_out))
sig_from_bob, a_htlc_sigs = bob.sign_next_commitment()
sig_from_alice, b_htlc_sigs = alice.sign_next_commitment()
assert len(a_htlc_sigs) == 0
assert len(b_htlc_sigs) == 0
alice.config[LOCAL] = alice.config[LOCAL]._replace(current_commitment_signature=sig_from_bob)
bob.config[LOCAL] = bob.config[LOCAL]._replace(current_commitment_signature=sig_from_alice)
alice.set_local_commitment(alice.current_commitment(LOCAL))
bob.set_local_commitment(bob.current_commitment(LOCAL))
alice_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(alice_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
bob_second = lnutil.secret_to_pubkey(int.from_bytes(lnutil.get_per_commitment_secret_from_seed(bob_seed, lnutil.RevocationStore.START_INDEX - 1), "big"))
alice.config[REMOTE] = alice.config[REMOTE]._replace(next_per_commitment_point=bob_second, current_per_commitment_point=bob_first)
bob.config[REMOTE] = bob.config[REMOTE]._replace(next_per_commitment_point=alice_second, current_per_commitment_point=alice_first)
alice.set_remote_commitment()
bob.set_remote_commitment()
alice.remote_commitment_to_be_revoked = alice.remote_commitment
bob.remote_commitment_to_be_revoked = bob.remote_commitment
alice.config[REMOTE] = alice.config[REMOTE]._replace(ctn=0)
bob.config[REMOTE] = bob.config[REMOTE]._replace(ctn=0)
return alice, bob return alice, bob
class TestFee(unittest.TestCase): class TestFee(unittest.TestCase):
@ -141,11 +171,13 @@ class TestFee(unittest.TestCase):
test test
https://github.com/lightningnetwork/lightning-rfc/blob/e0c436bd7a3ed6a028e1cb472908224658a14eca/03-transactions.md#requirements-2 https://github.com/lightningnetwork/lightning-rfc/blob/e0c436bd7a3ed6a028e1cb472908224658a14eca/03-transactions.md#requirements-2
""" """
def test_SimpleAddSettleWorkflow(self): def test_fee(self):
alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000) alice_channel, bob_channel = create_test_channels(253, 10000000000, 5000000000)
self.assertIn(9999817, [x[2] for x in alice_channel.local_commitment.outputs()]) self.assertIn(9999817, [x[2] for x in alice_channel.local_commitment.outputs()])
class TestChannel(unittest.TestCase): class TestChannel(unittest.TestCase):
maxDiff = 999
def assertOutputExistsByValue(self, tx, amt_sat): def assertOutputExistsByValue(self, tx, amt_sat):
for typ, scr, val in tx.outputs(): for typ, scr, val in tx.outputs():
if val == amt_sat: if val == amt_sat:
@ -153,6 +185,10 @@ class TestChannel(unittest.TestCase):
else: else:
self.assertFalse() self.assertFalse()
@staticmethod
def setUpClass():
set_verbosity(True)
def setUp(self): def setUp(self):
# Create a test channel which will be used for the duration of this # Create a test channel which will be used for the duration of this
# unittest. The channel will be funded evenly with Alice having 5 BTC, # unittest. The channel will be funded evenly with Alice having 5 BTC,
@ -171,12 +207,15 @@ class TestChannel(unittest.TestCase):
# update log. Then Alice sends this wire message over to Bob who adds # update log. Then Alice sends this wire message over to Bob who adds
# this htlc to his remote state update log. # this htlc to his remote state update log.
self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict) self.aliceHtlcIndex = self.alice_channel.add_htlc(self.htlc_dict)
self.assertNotEqual(self.alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED, 1), set())
before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) before = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) beforeLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)
self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict) self.bobHtlcIndex = self.bob_channel.receive_htlc(self.htlc_dict)
self.assertEqual(1, self.bob_channel.hm.log[LOCAL]['ctn'] + 1)
self.assertNotEqual(self.bob_channel.hm.htlcs_by_direction(LOCAL, RECEIVED, 1), set())
after = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE) after = self.bob_channel.balance_minus_outgoing_htlcs(REMOTE)
afterLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL) afterLocal = self.bob_channel.balance_minus_outgoing_htlcs(LOCAL)
@ -185,7 +224,7 @@ class TestChannel(unittest.TestCase):
self.bob_pending_remote_balance = after self.bob_pending_remote_balance = after
self.htlc = self.bob_channel.log[lnutil.REMOTE].adds[0] self.htlc = self.bob_channel.hm.log[REMOTE]['adds'][0]
def test_concurrent_reversed_payment(self): def test_concurrent_reversed_payment(self):
self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02') self.htlc_dict['payment_hash'] = bitcoin.sha256(32 * b'\x02')
@ -193,32 +232,65 @@ class TestChannel(unittest.TestCase):
bob_idx = self.bob_channel.add_htlc(self.htlc_dict) bob_idx = self.bob_channel.add_htlc(self.htlc_dict)
alice_idx = self.alice_channel.receive_htlc(self.htlc_dict) alice_idx = self.alice_channel.receive_htlc(self.htlc_dict)
self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment()) self.alice_channel.receive_new_commitment(*self.bob_channel.sign_next_commitment())
self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 3) self.assertEqual(len(self.alice_channel.pending_commitment(REMOTE).outputs()), 4)
def test_SimpleAddSettleWorkflow(self): def test_SimpleAddSettleWorkflow(self):
alice_channel, bob_channel = self.alice_channel, self.bob_channel alice_channel, bob_channel = self.alice_channel, self.bob_channel
htlc = self.htlc htlc = self.htlc
alice_out = alice_channel.current_commitment(LOCAL).outputs()
short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
self.assertLess(alice_out[long_idx].value, 5 * 10**8, alice_out)
self.assertEqual(alice_out[short_idx].value, 5 * 10**8, alice_out)
alice_out = alice_channel.current_commitment(REMOTE).outputs()
short_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 42]
long_idx, = [idx for idx, x in enumerate(alice_out) if len(x.address) == 62]
self.assertLess(alice_out[short_idx].value, 5 * 10**8)
self.assertEqual(alice_out[long_idx].value, 5 * 10**8)
def com():
return alice_channel.local_commitment
self.assertTrue(alice_channel.signature_fits(com()))
self.assertNotEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, 1), [])
self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertNotEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [])
self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({0: [], 1: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({0: [], 1: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({0: [], 1: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
# this wouldn't work since we put None in the remote_sig from electrum.lnutil import extract_ctn_from_tx_and_chan
# alice_channel.force_close_tx() tx0 = str(alice_channel.force_close_tx())
self.assertEqual(alice_channel.config[LOCAL].ctn, 0)
self.assertEqual(extract_ctn_from_tx_and_chan(alice_channel.force_close_tx(), alice_channel), 0)
self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL)))
# Next alice commits this change by sending a signature message. Since # Next alice commits this change by sending a signature message. Since
# we expect the messages to be ordered, Bob will receive the HTLC we # we expect the messages to be ordered, Bob will receive the HTLC we
# just sent before he receives this signature, so the signature will # just sent before he receives this signature, so the signature will
# cover the HTLC. # cover the HTLC.
aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment() aliceSig, aliceHtlcSigs = alice_channel.sign_next_commitment()
self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature") self.assertEqual(len(aliceHtlcSigs), 1, "alice should generate one htlc signature")
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(next(iter(alice_channel.hm.pending_htlcs(REMOTE)))[0], RECEIVED)
self.assertEqual(alice_channel.hm.pending_htlcs(REMOTE), bob_channel.hm.pending_htlcs(LOCAL))
self.assertEqual(alice_channel.pending_commitment(REMOTE).outputs(), bob_channel.pending_commitment(LOCAL).outputs())
# Bob receives this signature message, and checks that this covers the # Bob receives this signature message, and checks that this covers the
# state he has in his remote log. This includes the HTLC just sent # state he has in his remote log. This includes the HTLC just sent
# from Alice. # from Alice.
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs) bob_channel.receive_new_commitment(aliceSig, aliceHtlcSigs)
self.assertTrue(bob_channel.signature_fits(bob_channel.pending_commitment(LOCAL)))
self.assertEqual(bob_channel.config[REMOTE].ctn, 0)
self.assertEqual(bob_channel.included_htlcs(REMOTE, SENT, 1), [htlc])
self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({0: [], 1: [htlc]}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({0: [], 1: [htlc]}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
@ -228,31 +300,68 @@ class TestChannel(unittest.TestCase):
# Bob revokes his prior commitment given to him by Alice, since he now # Bob revokes his prior commitment given to him by Alice, since he now
# has a valid signature for a newer commitment. # has a valid signature for a newer commitment.
bobRevocation, _ = bob_channel.revoke_current_commitment() bobRevocation, _ = bob_channel.revoke_current_commitment()
bob_channel.serialize()
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
# Bob finally send a signature for Alice's commitment transaction. # Bob finally sends a signature for Alice's commitment transaction.
# This signature will cover the HTLC, since Bob will first send the # This signature will cover the HTLC, since Bob will first send the
# revocation just created. The revocation also acks every received # revocation just created. The revocation also acks every received
# HTLC up to the point where Alice sent here signature. # HTLC up to the point where Alice sent her signature.
bobSig, bobHtlcSigs = bob_channel.sign_next_commitment() bobSig, bobHtlcSigs = bob_channel.sign_next_commitment()
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
self.assertEqual(len(bobHtlcSigs), 1)
self.assertTrue(alice_channel.signature_fits(com()))
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 3)
# Alice then processes this revocation, sending her own revocation for # Alice then processes this revocation, sending her own revocation for
# her prior commitment transaction. Alice shouldn't have any HTLCs to # her prior commitment transaction. Alice shouldn't have any HTLCs to
# forward since she's sending an outgoing HTLC. # forward since she's sending an outgoing HTLC.
alice_channel.receive_revocation(bobRevocation) alice_channel.receive_revocation(bobRevocation)
# test serializing with locked_in htlc
self.assertEqual(len(alice_channel.to_save()['local_log']), 1)
alice_channel.serialize() alice_channel.serialize()
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
self.assertTrue(alice_channel.signature_fits(com()))
self.assertTrue(alice_channel.signature_fits(alice_channel.current_commitment(LOCAL)))
alice_channel.serialize()
self.assertEqual(str(alice_channel.current_commitment(LOCAL)), str(com()))
self.assertEqual(len(alice_channel.current_commitment(LOCAL).outputs()), 2)
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(com().outputs()), 2)
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 2)
self.assertEqual(alice_channel.hm.log.keys(), set([LOCAL, REMOTE]))
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
alice_channel.serialize()
self.assertEqual(alice_channel.pending_commitment(LOCAL).outputs(),
bob_channel.pending_commitment(REMOTE).outputs())
# Alice then processes bob's signature, and since she just received # Alice then processes bob's signature, and since she just received
# the revocation, she expect this signature to cover everything up to # the revocation, she expect this signature to cover everything up to
# the point where she sent her signature, including the HTLC. # the point where she sent her signature, including the HTLC.
alice_channel.receive_new_commitment(bobSig, bobHtlcSigs) alice_channel.receive_new_commitment(bobSig, bobHtlcSigs)
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
self.assertEqual(len(alice_channel.current_commitment(REMOTE).outputs()), 3)
self.assertEqual(len(com().outputs()), 3)
self.assertEqual(len(alice_channel.force_close_tx().outputs()), 3)
self.assertEqual(len(alice_channel.hm.log[LOCAL]['adds']), 1)
alice_channel.serialize()
tx1 = str(alice_channel.force_close_tx()) tx1 = str(alice_channel.force_close_tx())
self.assertNotEqual(tx0, tx1)
# Alice then generates a revocation for bob. # Alice then generates a revocation for bob.
self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
aliceRevocation, _ = alice_channel.revoke_current_commitment() aliceRevocation, _ = alice_channel.revoke_current_commitment()
alice_channel.serialize()
#self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
tx2 = str(alice_channel.force_close_tx()) tx2 = str(alice_channel.force_close_tx())
# since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one) # since alice already has the signature for the next one, it doesn't change her force close tx (it was already the newer one)
@ -262,7 +371,9 @@ class TestChannel(unittest.TestCase):
# is fully locked in within both commitment transactions. Bob should # is fully locked in within both commitment transactions. Bob should
# also be able to forward an HTLC now that the HTLC has been locked # also be able to forward an HTLC now that the HTLC has been locked
# into both commitment transactions. # into both commitment transactions.
self.assertTrue(bob_channel.signature_fits(bob_channel.current_commitment(LOCAL)))
bob_channel.receive_revocation(aliceRevocation) bob_channel.receive_revocation(aliceRevocation)
bob_channel.serialize()
# At this point, both sides should have the proper number of satoshis # At this point, both sides should have the proper number of satoshis
# sent, and commitment height updated within their local channel # sent, and commitment height updated within their local channel
@ -279,16 +390,19 @@ class TestChannel(unittest.TestCase):
# Both commitment transactions should have three outputs, and one of # Both commitment transactions should have three outputs, and one of
# them should be exactly the amount of the HTLC. # them should be exactly the amount of the HTLC.
self.assertEqual(len(alice_channel.local_commitment.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_channel.local_commitment.outputs())) alice_ctx = alice_channel.pending_commitment(LOCAL)
self.assertEqual(len(bob_channel.local_commitment.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_channel.local_commitment.outputs())) bob_ctx = bob_channel.pending_commitment(LOCAL)
self.assertOutputExistsByValue(alice_channel.local_commitment, htlc.amount_msat // 1000) self.assertEqual(len(alice_ctx.outputs()), 3, "alice should have three commitment outputs, instead have %s"% len(alice_ctx.outputs()))
self.assertOutputExistsByValue(bob_channel.local_commitment, htlc.amount_msat // 1000) self.assertEqual(len(bob_ctx.outputs()), 3, "bob should have three commitment outputs, instead have %s"% len(bob_ctx.outputs()))
self.assertOutputExistsByValue(alice_ctx, htlc.amount_msat // 1000)
self.assertOutputExistsByValue(bob_ctx, htlc.amount_msat // 1000)
# Now we'll repeat a similar exchange, this time with Bob settling the # Now we'll repeat a similar exchange, this time with Bob settling the
# HTLC once he learns of the preimage. # HTLC once he learns of the preimage.
preimage = self.paymentPreimage preimage = self.paymentPreimage
bob_channel.settle_htlc(preimage, self.bobHtlcIndex) bob_channel.settle_htlc(preimage, self.bobHtlcIndex)
#self.assertEqual(alice_channel.remote_commitment.outputs(), alice_channel.current_commitment(REMOTE).outputs())
alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex) alice_channel.receive_htlc_settle(preimage, self.aliceHtlcIndex)
tx3 = str(alice_channel.force_close_tx()) tx3 = str(alice_channel.force_close_tx())
@ -296,28 +410,43 @@ class TestChannel(unittest.TestCase):
self.assertEqual(tx2, tx3) self.assertEqual(tx2, tx3)
bobSig2, bobHtlcSigs2 = bob_channel.sign_next_commitment() bobSig2, bobHtlcSigs2 = bob_channel.sign_next_commitment()
self.assertEqual(len(bobHtlcSigs2), 0)
self.assertEqual(alice_channel.hm.htlcs_by_direction(REMOTE, RECEIVED), [htlc])
self.assertEqual(alice_channel.included_htlcs(REMOTE, RECEIVED, alice_channel.config[REMOTE].ctn), [htlc])
self.assertEqual({1: [htlc], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({1: [htlc], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
self.assertEqual({1: [htlc], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({1: [htlc], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({1: [], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE)) self.assertEqual({1: [], 2: []}, alice_channel.included_htlcs_in_their_latest_ctxs(REMOTE))
self.assertEqual({1: [], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL)) self.assertEqual({1: [], 2: []}, bob_channel.included_htlcs_in_their_latest_ctxs(LOCAL))
alice_ctx_bob_version = bob_channel.pending_commitment(REMOTE).outputs()
alice_ctx_alice_version = alice_channel.pending_commitment(LOCAL).outputs()
self.assertEqual(alice_ctx_alice_version, alice_ctx_bob_version)
alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2) alice_channel.receive_new_commitment(bobSig2, bobHtlcSigs2)
tx4 = str(alice_channel.force_close_tx()) tx4 = str(alice_channel.force_close_tx())
self.assertNotEqual(tx3, tx4) self.assertNotEqual(tx3, tx4)
self.assertEqual(alice_channel.balance(LOCAL), 500000000000)
self.assertEqual(1, alice_channel.config[LOCAL].ctn)
self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0)
aliceRevocation2, _ = alice_channel.revoke_current_commitment() aliceRevocation2, _ = alice_channel.revoke_current_commitment()
alice_channel.serialize()
aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment() aliceSig2, aliceHtlcSigs2 = alice_channel.sign_next_commitment()
self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures") self.assertEqual(aliceHtlcSigs2, [], "alice should generate no htlc signatures")
self.assertEqual(len(bob_channel.current_commitment(LOCAL).outputs()), 3)
self.assertEqual(len(bob_channel.pending_commitment(LOCAL).outputs()), 2)
received, sent = bob_channel.receive_revocation(aliceRevocation2) received, sent = bob_channel.receive_revocation(aliceRevocation2)
bob_channel.serialize()
self.assertEqual(received, one_bitcoin_in_msat) self.assertEqual(received, one_bitcoin_in_msat)
bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2) bob_channel.receive_new_commitment(aliceSig2, aliceHtlcSigs2)
bobRevocation2, _ = bob_channel.revoke_current_commitment() bobRevocation2, _ = bob_channel.revoke_current_commitment()
bob_channel.serialize()
alice_channel.receive_revocation(bobRevocation2) alice_channel.receive_revocation(bobRevocation2)
alice_channel.serialize()
# At this point, Bob should have 6 BTC settled, with Alice still having # At this point, Bob should have 6 BTC settled, with Alice still having
# 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel # 4 BTC. Alice's channel should show 1 BTC sent and Bob's channel
@ -331,15 +460,15 @@ class TestChannel(unittest.TestCase):
self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height") self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height")
self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height") self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height")
# The logs of both sides should now be cleared since the entry adding
# the HTLC should have been removed once both sides receive the
# revocation.
#self.assertEqual(alice_channel.local_update_log, [], "alice's local not updated, should be empty, has %s entries instead"% len(alice_channel.local_update_log))
#self.assertEqual(alice_channel.remote_update_log, [], "alice's remote not updated, should be empty, has %s entries instead"% len(alice_channel.remote_update_log))
self.assertEqual(self.bob_pending_remote_balance, self.alice_channel.balance(LOCAL)) self.assertEqual(self.bob_pending_remote_balance, self.alice_channel.balance(LOCAL))
alice_channel.update_fee(100000, True) alice_channel.update_fee(100000, True)
alice_outputs = alice_channel.pending_commitment(REMOTE).outputs()
old_outputs = bob_channel.pending_commitment(LOCAL).outputs()
bob_channel.update_fee(100000, False) bob_channel.update_fee(100000, False)
new_outputs = bob_channel.pending_commitment(LOCAL).outputs()
self.assertNotEqual(old_outputs, new_outputs)
self.assertEqual(alice_outputs, new_outputs)
tx5 = str(alice_channel.force_close_tx()) tx5 = str(alice_channel.force_close_tx())
# sending a fee update does not change her force close tx # sending a fee update does not change her force close tx
@ -353,10 +482,17 @@ class TestChannel(unittest.TestCase):
self.htlc_dict['amount_msat'] *= 5 self.htlc_dict['amount_msat'] *= 5
bob_index = bob_channel.add_htlc(self.htlc_dict) bob_index = bob_channel.add_htlc(self.htlc_dict)
alice_index = alice_channel.receive_htlc(self.htlc_dict) alice_index = alice_channel.receive_htlc(self.htlc_dict)
force_state_transition(alice_channel, bob_channel)
bob_channel.pending_commitment(REMOTE)
alice_channel.pending_commitment(LOCAL)
alice_channel.pending_commitment(REMOTE)
bob_channel.pending_commitment(LOCAL)
force_state_transition(bob_channel, alice_channel)
alice_channel.settle_htlc(self.paymentPreimage, alice_index) alice_channel.settle_htlc(self.paymentPreimage, alice_index)
bob_channel.receive_htlc_settle(self.paymentPreimage, bob_index) bob_channel.receive_htlc_settle(self.paymentPreimage, bob_index)
force_state_transition(alice_channel, bob_channel) force_state_transition(bob_channel, alice_channel)
self.assertEqual(alice_channel.total_msat(SENT), one_bitcoin_in_msat, "alice satoshis sent incorrect") self.assertEqual(alice_channel.total_msat(SENT), one_bitcoin_in_msat, "alice satoshis sent incorrect")
self.assertEqual(alice_channel.total_msat(RECEIVED), 5 * one_bitcoin_in_msat, "alice satoshis received incorrect") self.assertEqual(alice_channel.total_msat(RECEIVED), 5 * one_bitcoin_in_msat, "alice satoshis received incorrect")
self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect") self.assertEqual(bob_channel.total_msat(RECEIVED), one_bitcoin_in_msat, "bob satoshis received incorrect")
@ -366,8 +502,15 @@ class TestChannel(unittest.TestCase):
def alice_to_bob_fee_update(self, fee=111): def alice_to_bob_fee_update(self, fee=111):
aoldctx = self.alice_channel.pending_commitment(REMOTE).outputs()
self.alice_channel.update_fee(fee, True) self.alice_channel.update_fee(fee, True)
anewctx = self.alice_channel.pending_commitment(REMOTE).outputs()
self.assertNotEqual(aoldctx, anewctx)
boldctx = self.bob_channel.pending_commitment(LOCAL).outputs()
self.bob_channel.update_fee(fee, False) self.bob_channel.update_fee(fee, False)
bnewctx = self.bob_channel.pending_commitment(LOCAL).outputs()
self.assertNotEqual(boldctx, bnewctx)
self.assertEqual(anewctx, bnewctx)
return fee return fee
def test_UpdateFeeSenderCommits(self): def test_UpdateFeeSenderCommits(self):
@ -444,7 +587,7 @@ class TestChannel(unittest.TestCase):
# value 2 BTC, which should make Alice's balance negative (since she # value 2 BTC, which should make Alice's balance negative (since she
# has to pay a commitment fee). # has to pay a commitment fee).
new = dict(self.htlc_dict) new = dict(self.htlc_dict)
new['amount_msat'] *= 2 new['amount_msat'] *= 2.5
new['payment_hash'] = bitcoin.sha256(32 * b'\x04') new['payment_hash'] = bitcoin.sha256(32 * b'\x04')
with self.assertRaises(lnutil.PaymentFailure) as cm: with self.assertRaises(lnutil.PaymentFailure) as cm:
self.alice_channel.add_htlc(new) self.alice_channel.add_htlc(new)
@ -462,7 +605,6 @@ class TestChannel(unittest.TestCase):
except: except:
try: try:
from deepdiff import DeepDiff from deepdiff import DeepDiff
from pprint import pformat
except ImportError: except ImportError:
raise raise
raise Exception(pformat(DeepDiff(before_signing, after_signing))) raise Exception(pformat(DeepDiff(before_signing, after_signing)))
@ -549,9 +691,9 @@ class TestChanReserve(unittest.TestCase):
force_state_transition(self.alice_channel, self.bob_channel) force_state_transition(self.alice_channel, self.bob_channel)
aliceSelfBalance = self.alice_channel.balance(LOCAL)\ aliceSelfBalance = self.alice_channel.balance(LOCAL)\
- lnchan.htlcsum(self.alice_channel.htlcs(LOCAL, True)) - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(LOCAL, SENT))
bobBalance = self.bob_channel.balance(REMOTE)\ bobBalance = self.bob_channel.balance(REMOTE)\
- lnchan.htlcsum(self.alice_channel.htlcs(REMOTE, True)) - lnchan.htlcsum(self.alice_channel.hm.htlcs_by_direction(REMOTE, SENT))
self.assertEqual(aliceSelfBalance, one_bitcoin_in_msat*4.5) self.assertEqual(aliceSelfBalance, one_bitcoin_in_msat*4.5)
self.assertEqual(bobBalance, one_bitcoin_in_msat*5) self.assertEqual(bobBalance, one_bitcoin_in_msat*5)
# Now let Bob try to add an HTLC. This should fail, since it will # Now let Bob try to add an HTLC. This should fail, since it will
@ -647,17 +789,22 @@ class TestDust(unittest.TestCase):
'cltv_expiry' : 5, # also in create_test_channels 'cltv_expiry' : 5, # also in create_test_channels
} }
old_values = [x.value for x in bob_channel.current_commitment(LOCAL).outputs() ]
aliceHtlcIndex = alice_channel.add_htlc(htlc) aliceHtlcIndex = alice_channel.add_htlc(htlc)
bobHtlcIndex = bob_channel.receive_htlc(htlc) bobHtlcIndex = bob_channel.receive_htlc(htlc)
force_state_transition(alice_channel, bob_channel) force_state_transition(alice_channel, bob_channel)
self.assertEqual(len(alice_channel.local_commitment.outputs()), 3) alice_ctx = alice_channel.current_commitment(LOCAL)
self.assertEqual(len(bob_channel.local_commitment.outputs()), 2) bob_ctx = bob_channel.current_commitment(LOCAL)
new_values = [x.value for x in bob_ctx.outputs() ]
self.assertNotEqual(old_values, new_values)
self.assertEqual(len(alice_ctx.outputs()), 3)
self.assertEqual(len(bob_ctx.outputs()), 2)
default_fee = calc_static_fee(0) default_fee = calc_static_fee(0)
self.assertEqual(bob_channel.pending_local_fee(), default_fee + htlcAmt) self.assertEqual(bob_channel.pending_local_fee(), default_fee + htlcAmt)
bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex) bob_channel.settle_htlc(paymentPreimage, bobHtlcIndex)
alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex) alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
force_state_transition(bob_channel, alice_channel) force_state_transition(bob_channel, alice_channel)
self.assertEqual(len(alice_channel.local_commitment.outputs()), 2) self.assertEqual(len(alice_channel.pending_commitment(LOCAL).outputs()), 2)
self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt) self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt)
def force_state_transition(chanA, chanB): def force_state_transition(chanA, chanB):

View file

@ -0,0 +1,95 @@
import unittest
from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner
from electrum.lnhtlc import HTLCManager
from typing import NamedTuple
class H(NamedTuple):
owner : str
htlc_id : int
class TestHTLCManager(unittest.TestCase):
def test_race(self):
A = HTLCManager()
B = HTLCManager()
ah0, bh0 = H('A', 0), H('B', 0)
B.recv_htlc(A.send_htlc(ah0))
self.assertTrue(B.expect_sig[RECEIVED])
self.assertTrue(A.expect_sig[SENT])
self.assertFalse(B.expect_sig[SENT])
self.assertFalse(A.expect_sig[RECEIVED])
self.assertEqual(B.log[REMOTE]['locked_in'][0][LOCAL], 1)
A.recv_htlc(B.send_htlc(bh0))
self.assertTrue(B.expect_sig[RECEIVED])
self.assertTrue(A.expect_sig[SENT])
self.assertTrue(A.expect_sig[SENT])
self.assertTrue(B.expect_sig[RECEIVED])
self.assertEqual(B.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), [])
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0)])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0)])
A.send_ctx()
B.recv_ctx()
B.send_ctx()
A.recv_ctx()
self.assertEqual(B.pending_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.pending_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
B.send_rev()
A.recv_rev()
A.send_rev()
B.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), [(RECEIVED, ah0), (SENT, bh0)][::-1])
self.assertEqual(A.current_htlcs(LOCAL), [(RECEIVED, bh0), (SENT, ah0)][::-1])
def test_no_race(self):
A = HTLCManager()
B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0)))
self.assertEqual(len(B.pending_htlcs(REMOTE)), 1)
A.send_ctx()
B.recv_ctx()
B.send_rev()
A.recv_rev()
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
self.assertEqual(len(A.current_htlcs(LOCAL)), 1)
self.assertEqual(len(B.current_htlcs(LOCAL)), 1)
B.send_settle(0)
A.recv_settle(0)
self.assertEqual(A.htlcs_by_direction(REMOTE, RECEIVED), [H('A', 0)])
self.assertNotEqual(A.current_htlcs(LOCAL), [])
self.assertNotEqual(B.current_htlcs(REMOTE), [])
self.assertEqual(A.pending_htlcs(LOCAL), [])
self.assertEqual(B.pending_htlcs(REMOTE), [])
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()
A.send_ctx()
B.recv_ctx()
B.send_rev()
A.recv_rev()
self.assertEqual(B.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(LOCAL), [])
self.assertEqual(A.current_htlcs(REMOTE), [])
self.assertEqual(B.current_htlcs(REMOTE), [])
self.assertEqual(len(A.settled_htlcs(LOCAL)), 1)
self.assertEqual(len(A.sent_in_ctn(2)), 1)
self.assertEqual(len(B.received_in_ctn(2)), 1)
def test_settle_while_owing(self):
A = HTLCManager()
B = HTLCManager()
B.recv_htlc(A.send_htlc(H('A', 0)))
A.send_ctx()
B.recv_ctx()
B.send_rev()
A.recv_rev()
B.send_settle(0)
A.recv_settle(0)
self.assertEqual(B.pending_htlcs(REMOTE), [])
B.send_ctx()
A.recv_ctx()
A.send_rev()
B.recv_rev()

View file

@ -6,8 +6,7 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey, make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret, derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError, get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
ScriptHtlc, extract_nodeid, calc_onchain_fees) ScriptHtlc, extract_nodeid, calc_onchain_fees, UpdateAddHtlc)
from electrum import lnchan
from electrum.util import bh2u, bfh from electrum.util import bh2u, bfh
from electrum.transaction import Transaction from electrum.transaction import Transaction
@ -496,7 +495,7 @@ class TestLNUtil(unittest.TestCase):
(1, 2000 * 1000), (1, 2000 * 1000),
(3, 3000 * 1000), (3, 3000 * 1000),
(4, 4000 * 1000)]: (4, 4000 * 1000)]:
htlc_obj[num] = lnchan.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None) htlc_obj[num] = UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)] htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)]
our_commit_tx = make_commitment( our_commit_tx = make_commitment(
@ -506,7 +505,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay, local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi, funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi,
calc_onchain_fees(len(htlcs), local_feerate_per_kw, True, we_are_initiator=True), htlcs=htlcs) calc_onchain_fees(len(htlcs), local_feerate_per_kw, True), htlcs=htlcs)
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx) self.assertEqual(str(our_commit_tx), output_commit_tx)
@ -584,7 +583,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay, local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi, funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi,
calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx) self.assertEqual(str(our_commit_tx), output_commit_tx)
@ -603,7 +602,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay, local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi, funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi,
calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
self.assertEqual(str(our_commit_tx), output_commit_tx) self.assertEqual(str(our_commit_tx), output_commit_tx)
@ -661,7 +660,7 @@ class TestLNUtil(unittest.TestCase):
local_revocation_pubkey, local_delayedpubkey, local_delay, local_revocation_pubkey, local_delayedpubkey, local_delay,
funding_tx_id, funding_output_index, funding_amount_satoshi, funding_tx_id, funding_output_index, funding_amount_satoshi,
to_local_msat, to_remote_msat, local_dust_limit_satoshi, to_local_msat, to_remote_msat, local_dust_limit_satoshi,
calc_onchain_fees(0, local_feerate_per_kw, True, we_are_initiator=True), htlcs=[]) calc_onchain_fees(0, local_feerate_per_kw, True), htlcs=[])
self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey) self.sign_and_insert_remote_sig(our_commit_tx, remote_funding_pubkey, remote_signature, local_funding_pubkey, local_funding_privkey)
ref_commit_tx_str = '02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8002c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de84311054a56a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e0400473044022051b75c73198c6deee1a875871c3961832909acd297c6b908d59e3319e5185a46022055c419379c5051a78d00dbbce11b5b664a0c22815fbcc6fcef6b1937c383693901483045022100f51d2e566a70ba740fc5d8c0f07b9b93d2ed741c3c0860c613173de7d39e7968022041376d520e9c0e1ad52248ddf4b22e12be8763007df977253ef45a4ca3bdb7c001475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220' ref_commit_tx_str = '02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8002c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de84311054a56a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e0400473044022051b75c73198c6deee1a875871c3961832909acd297c6b908d59e3319e5185a46022055c419379c5051a78d00dbbce11b5b664a0c22815fbcc6fcef6b1937c383693901483045022100f51d2e566a70ba740fc5d8c0f07b9b93d2ed741c3c0860c613173de7d39e7968022041376d520e9c0e1ad52248ddf4b22e12be8763007df977253ef45a4ca3bdb7c001475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220'
self.assertEqual(str(our_commit_tx), ref_commit_tx_str) self.assertEqual(str(our_commit_tx), ref_commit_tx_str)