From 1be0a710c3686d6a2b13805e0cf67218034ac1ed Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 24 Mar 2020 20:07:00 +0100 Subject: [PATCH] ln: implement option payment_secret --- electrum/lnaddr.py | 20 +++++++++++-- electrum/lnonion.py | 5 +++- electrum/lnpeer.py | 21 +++++++++++--- electrum/lnutil.py | 14 +++++++++ electrum/lnworker.py | 17 ++++++++--- electrum/tests/test_bolt11.py | 53 +++++++++++++++++++++-------------- electrum/tests/test_lnpeer.py | 11 ++++++-- 7 files changed, 105 insertions(+), 36 deletions(-) diff --git a/electrum/lnaddr.py b/electrum/lnaddr.py index 08cc216eb..55265c111 100644 --- a/electrum/lnaddr.py +++ b/electrum/lnaddr.py @@ -180,16 +180,22 @@ def lnencode(addr: 'LnAddr', privkey): # Start with the timestamp data = bitstring.pack('uint:35', addr.date) + tags_set = set() + # Payment hash data += tagged_bytes('p', addr.paymenthash) - tags_set = set() + tags_set.add('p') + + if addr.payment_secret is not None: + data += tagged_bytes('s', addr.payment_secret) + tags_set.add('s') for k, v in addr.tags: # BOLT #11: # # A writer MUST NOT include more than one `d`, `h`, `n` or `x` fields, - if k in ('d', 'h', 'n', 'x'): + if k in ('d', 'h', 'n', 'x', 'p', 's'): if k in tags_set: raise ValueError("Duplicate '{}' tag".format(k)) @@ -248,11 +254,13 @@ def lnencode(addr: 'LnAddr', privkey): return bech32_encode(hrp, bitarray_to_u5(data)) class LnAddr(object): - def __init__(self, paymenthash: bytes = None, amount=None, currency=None, tags=None, date=None): + def __init__(self, *, paymenthash: bytes = None, amount=None, currency=None, tags=None, date=None, + payment_secret: bytes = None): self.date = int(time.time()) if not date else int(date) self.tags = [] if not tags else tags self.unknown_tags = [] self.paymenthash = paymenthash + self.payment_secret = payment_secret self.signature = None self.pubkey = None self.currency = constants.net.SEGWIT_HRP if currency is None else currency @@ -392,6 +400,12 @@ def lndecode(invoice: str, *, verbose=False, expected_hrp=None) -> LnAddr: continue addr.paymenthash = trim_to_bytes(tagdata) + elif tag == 's': + if data_length != 52: + addr.unknown_tags.append((tag, tagdata)) + continue + addr.payment_secret = trim_to_bytes(tagdata) + elif tag == 'n': if data_length != 53: addr.unknown_tags.append((tag, tagdata)) diff --git a/electrum/lnonion.py b/electrum/lnonion.py index 7dc6d73ea..bba2041a5 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -260,7 +260,8 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes, hmac=next_hmac) -def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, final_cltv: int) \ +def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, + final_cltv: int, *, payment_secret: bytes = None) \ -> Tuple[List[OnionHopsDataSingle], int, int]: """Returns the hops_data to be used for constructing an onion packet, and the amount_msat and cltv to be used on our immediate channel. @@ -275,6 +276,8 @@ def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, final_ "amt_to_forward": {"amt_to_forward": amt}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv}, } + if payment_secret is not None: + hop_payload["payment_data"] = {"payment_secret": payment_secret, "total_msat": amt} hops_data = [OnionHopsDataSingle(is_tlv_payload=route[-1].has_feature_varonion(), payload=hop_payload)] # payloads, backwards from last hop (but excluding the first edge): diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 0b770b0ef..b93cc20e3 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -44,7 +44,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, MINIMUM_MAX_HTLC_VALUE_IN_FLIGHT_ACCEPTED, MAXIMUM_HTLC_MINIMUM_MSAT_ACCEPTED, MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY, NBLOCK_OUR_CLTV_EXPIRY_DELTA, format_short_channel_id, ShortChannelID, - IncompatibleLightningFeatures) + IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage) from .lnutil import FeeUpdate from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg @@ -1037,8 +1037,8 @@ class Peer(Logger): sig_64, htlc_sigs = chan.sign_next_commitment() self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) - def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, - payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc: + def pay(self, *, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, + payment_hash: bytes, min_final_cltv_expiry: int, payment_secret: bytes = None) -> UpdateAddHtlc: assert amount_msat > 0, "amount_msat is not greater zero" assert len(route) > 0 if not chan.can_send_update_add_htlc(): @@ -1048,7 +1048,8 @@ class Peer(Logger): local_height = self.network.get_local_height() # create onion packet final_cltv = local_height + min_final_cltv_expiry - hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv) + hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv, + payment_secret=payment_secret) assert final_cltv <= cltv, (final_cltv, cltv) secret_key = os.urandom(32) onion = new_onion_packet([x.node_id for x in route], secret_key, hops_data, associated_data=payment_hash) @@ -1221,6 +1222,14 @@ class Peer(Logger): except UnknownPaymentHash: reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') return False, reason + try: + payment_secret_from_onion = processed_onion.hop_data.payload["payment_data"]["payment_secret"] + except: + pass # skip + else: + if payment_secret_from_onion != derive_payment_secret_from_payment_preimage(preimage): + reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') + return False, reason expected_received_msat = int(info.amount * 1000) if info.amount is not None else None if expected_received_msat is not None and \ not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): @@ -1244,6 +1253,10 @@ class Peer(Logger): except: reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00') return False, reason + try: + amount_from_onion = processed_onion.hop_data.payload["payment_data"]["total_msat"] + except: + pass # fall back to "amt_to_forward" if amount_from_onion > htlc.amount_msat: reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_HTLC_AMOUNT, data=htlc.amount_msat.to_bytes(8, byteorder="big")) diff --git a/electrum/lnutil.py b/electrum/lnutil.py index e4855d3dd..f29a08eec 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -839,6 +839,7 @@ LN_FEATURES_IMPLEMENTED = ( | LnFeatures.GOSSIP_QUERIES_OPT | LnFeatures.GOSSIP_QUERIES_REQ | LnFeatures.OPTION_STATIC_REMOTEKEY_OPT | LnFeatures.OPTION_STATIC_REMOTEKEY_REQ | LnFeatures.VAR_ONION_OPT | LnFeatures.VAR_ONION_REQ + | LnFeatures.PAYMENT_SECRET_OPT | LnFeatures.PAYMENT_SECRET_REQ ) @@ -894,6 +895,19 @@ def validate_features(features: int) -> None: raise UnknownEvenFeatureBits(fbit) +def derive_payment_secret_from_payment_preimage(payment_preimage: bytes) -> bytes: + """Returns secret to be put into invoice. + Derivation is deterministic, based on the preimage. + Crucially the payment_hash must be derived in an independent way from this. + """ + # Note that this could be random data too, but then we would need to store it. + # We derive it identically to clightning, so that we cannot be distinguished: + # https://github.com/ElementsProject/lightning/blob/faac4b28adee5221e83787d64cd5d30b16b62097/lightningd/invoice.c#L115 + modified = bytearray(payment_preimage) + modified[0] ^= 1 + return sha256(bytes(modified)) + + class LNPeerAddr: def __init__(self, host: str, port: int, pubkey: bytes): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 833f048cc..535169bcd 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -54,7 +54,7 @@ from .lnutil import (Outpoint, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, Direction, LnFeatures, ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails, - BarePaymentAttemptLog) + BarePaymentAttemptLog, derive_payment_secret_from_payment_preimage) from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket @@ -151,6 +151,7 @@ class LNWorker(Logger): self.features |= LnFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.features |= LnFeatures.OPTION_STATIC_REMOTEKEY_OPT self.features |= LnFeatures.VAR_ONION_OPT + self.features |= LnFeatures.PAYMENT_SECRET_OPT def channels_for_peer(self, node_id): return {} @@ -953,7 +954,12 @@ class LNWallet(LNWorker): if not peer: raise Exception('Dropped peer') await peer.initialized - htlc = peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) + htlc = peer.pay(route=route, + chan=chan, + amount_msat=int(lnaddr.amount * COIN * 1000), + payment_hash=lnaddr.paymenthash, + min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), + payment_secret=lnaddr.payment_secret) self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) payment_attempt = await self.await_payment(lnaddr.paymenthash) if payment_attempt.success: @@ -1141,6 +1147,7 @@ class LNWallet(LNWorker): "Other clients will likely not be able to send to us.") payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) + info = PaymentInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID) amount_btc = amount_sat/Decimal(COIN) if amount_sat else None if expiry == 0: @@ -1149,13 +1156,15 @@ class LNWallet(LNWorker): # Our higher level invoices code however uses 0 for "never". # Hence set some high expiration here expiry = 100 * 365 * 24 * 60 * 60 # 100 years - lnaddr = LnAddr(payment_hash, amount_btc, + lnaddr = LnAddr(paymenthash=payment_hash, + amount=amount_btc, tags=[('d', message), ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), ('x', expiry), ('9', self.features.for_invoice())] + routing_hints, - date = timestamp) + date=timestamp, + payment_secret=derive_payment_secret_from_payment_preimage(payment_preimage)) invoice = lnencode(lnaddr, self.node_keypair.privkey) key = bh2u(lnaddr.paymenthash) req = { diff --git a/electrum/tests/test_bolt11.py b/electrum/tests/test_bolt11.py index 88b187684..42f70ff36 100644 --- a/electrum/tests/test_bolt11.py +++ b/electrum/tests/test_bolt11.py @@ -6,7 +6,7 @@ import unittest from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5 from electrum.segwit_addr import bech32_encode, bech32_decode -from electrum.lnutil import UnknownEvenFeatureBits +from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage from . import ElectrumTestCase @@ -62,27 +62,28 @@ class TestBolt11(ElectrumTestCase): tests = [ - LnAddr(RHASH, tags=[('d', '')]), - LnAddr(RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60)]), - LnAddr(RHASH, amount=Decimal('1'), tags=[('h', longdescription)]), - LnAddr(RHASH, currency='tb', tags=[('f', 'mk2QpYatsKicvFVuTAQLBryyccRXMUaGHP'), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[ + LnAddr(paymenthash=RHASH, tags=[('d', '')]), + LnAddr(paymenthash=RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60)]), + LnAddr(paymenthash=RHASH, amount=Decimal('1'), tags=[('h', longdescription)]), + LnAddr(paymenthash=RHASH, currency='tb', tags=[('f', 'mk2QpYatsKicvFVuTAQLBryyccRXMUaGHP'), ('h', longdescription)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[ ('r', [(unhexlify('029e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('0102030405060708'), 1, 20, 3), (unhexlify('039e03a901b85534ff1e92c43c74431f7ce72046060fcf7a95c37e148f78c77255'), unhexlify('030405060708090a'), 2, 30, 4)]), ('f', '1RustyRX2oai4EYYDpQGWvEL62BBGqN9T'), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[('f', '3EktnHQD7RiAE6uzMj2ZifT9YgRrkSgzQX'), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[('f', 'bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4'), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[('f', 'bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3'), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[('n', PUBKEY), ('h', longdescription)]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 514)]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 8))]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 9))]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 11))]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 12))]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 13))]), - #LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 14))]), - LnAddr(RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 15))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('f', '3EktnHQD7RiAE6uzMj2ZifT9YgRrkSgzQX'), ('h', longdescription)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('f', 'bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4'), ('h', longdescription)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('f', 'bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3'), ('h', longdescription)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('n', PUBKEY), ('h', longdescription)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 514)]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 8))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 9))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 11))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 12))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 13))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 14))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 10 + (1 << 15))]), + LnAddr(paymenthash=RHASH, amount=24, tags=[('h', longdescription), ('9', 33282)], payment_secret=b"\x11" * 32), ] # Roundtrip @@ -93,14 +94,14 @@ class TestBolt11(ElectrumTestCase): def test_n_decoding(self): # We flip the signature recovery bit, which would normally give a different # pubkey. - hrp, data = bech32_decode(lnencode(LnAddr(RHASH, amount=24, tags=[('d', '')]), PRIVKEY), True) + hrp, data = bech32_decode(lnencode(LnAddr(paymenthash=RHASH, amount=24, tags=[('d', '')]), PRIVKEY), True) databits = u5_to_bitarray(data) databits.invert(-1) lnaddr = lndecode(bech32_encode(hrp, bitarray_to_u5(databits)), verbose=True) assert lnaddr.pubkey.serialize() != PUBKEY # But not if we supply expliciy `n` specifier! - hrp, data = bech32_decode(lnencode(LnAddr(RHASH, amount=24, + hrp, data = bech32_decode(lnencode(LnAddr(paymenthash=RHASH, amount=24, tags=[('d', ''), ('n', PUBKEY)]), PRIVKEY), True) @@ -115,7 +116,7 @@ class TestBolt11(ElectrumTestCase): self.assertEqual(144, lnaddr.get_min_final_cltv_expiry()) def test_min_final_cltv_expiry_roundtrip(self): - lnaddr = LnAddr(RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', 150)]) + lnaddr = LnAddr(paymenthash=RHASH, amount=Decimal('0.001'), tags=[('d', '1 cup coffee'), ('x', 60), ('c', 150)]) invoice = lnencode(lnaddr, PRIVKEY) self.assertEqual(150, lndecode(invoice).get_min_final_cltv_expiry()) @@ -125,3 +126,13 @@ class TestBolt11(ElectrumTestCase): with self.assertRaises(UnknownEvenFeatureBits): lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdees9q4pqqqqqqqqqqqqqqqqqqszk3ed62snp73037h4py4gry05eltlp0uezm2w9ajnerhmxzhzhsu40g9mgyx5v3ad4aqwkmvyftzk4k9zenz90mhjcy9hcevc7r3lx2sphzfxz7") + + def test_payment_secret(self): + lnaddr = lndecode("lnbc25m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5vdhkven9v5sxyetpdeessp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q5sqqqqqqqqqqqqqqqpqqq4u9s93jtgysm3mrwll70zr697y3mf902hvxwej0v7c62rsltw83ng0pu8w3j230sluc5gxkdmm9dvpy9y6ggtjd2w544mzdrcs42t7sqdkcy8h") + self.assertEqual((1 << 15) + (1 << 99) , lnaddr.get_tag('9')) + self.assertEqual(b"\x11" * 32, lnaddr.payment_secret) + + def test_derive_payment_secret_from_payment_preimage(self): + preimage = bytes.fromhex("cc3fc000bdeff545acee53ada12ff96060834be263f77d645abbebc3a8d53b92") + self.assertEqual("bfd660b559b3f452c6bb05b8d2906f520c151c107b733863ed0cc53fc77021a8", + derive_payment_secret_from_payment_preimage(preimage).hex()) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 8c2452b9a..c351da398 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -235,8 +235,8 @@ class TestPeer(ElectrumTestCase): w2.save_preimage(RHASH, payment_preimage) w2.save_payment_info(info) lnaddr = LnAddr( - RHASH, - amount_btc, + paymenthash=RHASH, + amount=amount_btc, tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), ('d', 'coffee') ]) @@ -355,7 +355,12 @@ class TestPeer(ElectrumTestCase): await asyncio.wait_for(p2.initialized, 1) # alice sends htlc route = w1._create_route_from_invoice(decoded_invoice=lnaddr) - htlc = p1.pay(route, alice_channel, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) + htlc = p1.pay(route=route, + chan=alice_channel, + amount_msat=int(lnaddr.amount * COIN * 1000), + payment_hash=lnaddr.paymenthash, + min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), + payment_secret=lnaddr.payment_secret) # alice closes await p1.close_channel(alice_channel.channel_id) gath.cancel()