From f08e5541aeda9524eb456477f93732de7357c792 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Fri, 20 Sep 2019 17:15:49 +0200 Subject: [PATCH] Refactor invoices in lnworker. - use InvoiceInfo (NamedTuple) for normal operations, because lndecode operations can be very slow. - all invoices/requests are stored in wallet - invoice expiration detection is performed in wallet - CLI commands: list_invoices, add_request, add_lightning_request - revert 0062c6d69561991d5918163946344c1b10ed9588 because it forbids self-payments --- electrum/commands.py | 36 ++--- electrum/gui/kivy/main_window.py | 18 +-- electrum/gui/kivy/uix/screens.py | 7 +- electrum/gui/qt/main_window.py | 12 +- electrum/lnpeer.py | 6 +- electrum/lnworker.py | 219 ++++++++++++++---------------- electrum/tests/regtest/regtest.sh | 18 +-- electrum/tests/test_lnpeer.py | 22 +-- electrum/wallet.py | 120 ++++++++-------- 9 files changed, 214 insertions(+), 244 deletions(-) diff --git a/electrum/commands.py b/electrum/commands.py index 2bfcb7161..bd6fa1495 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -716,7 +716,7 @@ class Commands: def _format_request(self, out): from .util import get_request_status out['amount_BTC'] = format_satoshis(out.get('amount')) - out['status'] = get_request_status(out) + out['status_str'] = get_request_status(out) return out @command('w') @@ -733,7 +733,7 @@ class Commands: # pass @command('w') - async def listrequests(self, pending=False, expired=False, paid=False, wallet: Abstract_Wallet = None): + async def list_requests(self, pending=False, expired=False, paid=False, wallet: Abstract_Wallet = None): """List the payment requests you made.""" out = wallet.get_sorted_requests() if pending: @@ -760,7 +760,7 @@ class Commands: return wallet.get_unused_address() @command('w') - async def addrequest(self, amount, memo='', expiration=None, force=False, wallet: Abstract_Wallet = None): + async def add_request(self, amount, memo='', expiration=3600, force=False, wallet: Abstract_Wallet = None): """Create a payment request, using the first unused address of the wallet. The address will be considered as used after this operation. If no payment is received, the address will be considered as unused if the payment request is deleted from the wallet.""" @@ -777,6 +777,12 @@ class Commands: out = wallet.get_request(addr) return self._format_request(out) + @command('wn') + async def add_lightning_request(self, amount, memo='', expiration=3600, wallet: Abstract_Wallet = None): + amount_sat = int(satoshis(amount)) + key = await wallet.lnworker._add_request_coro(amount_sat, memo, expiration) + return wallet.get_request(key)['invoice'] + @command('w') async def addtransaction(self, tx, wallet: Abstract_Wallet = None): """ Add a transaction to the wallet history """ @@ -894,13 +900,6 @@ class Commands: async def lnpay(self, invoice, attempts=1, timeout=10, wallet: Abstract_Wallet = None): return await wallet.lnworker._pay(invoice, attempts=attempts) - @command('wn') - async def addinvoice(self, requested_amount, message, expiration=3600, wallet: Abstract_Wallet = None): - # using requested_amount because it is documented in param_descriptions - payment_hash = await wallet.lnworker._add_invoice_coro(satoshis(requested_amount), message, expiration) - invoice, direction, is_paid = wallet.lnworker.invoices[bh2u(payment_hash)] - return invoice - @command('w') async def nodeid(self, wallet: Abstract_Wallet = None): listen_addr = self.config.get('lightning_listen') @@ -925,21 +924,8 @@ class Commands: self.network.path_finder.blacklist.clear() @command('w') - async def lightning_invoices(self, wallet: Abstract_Wallet = None): - from .util import pr_tooltips - out = [] - for payment_hash, (preimage, invoice, is_received, timestamp) in wallet.lnworker.invoices.items(): - status = wallet.lnworker.get_invoice_status(payment_hash) - item = { - 'date':timestamp_to_datetime(timestamp), - 'direction': 'received' if is_received else 'sent', - 'payment_hash':payment_hash, - 'invoice':invoice, - 'preimage':preimage, - 'status':pr_tooltips[status] - } - out.append(item) - return out + async def list_invoices(self, wallet: Abstract_Wallet = None): + return wallet.get_invoices() @command('w') async def lightning_history(self, wallet: Abstract_Wallet = None): diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py index 37a701550..3e7c389d9 100644 --- a/electrum/gui/kivy/main_window.py +++ b/electrum/gui/kivy/main_window.py @@ -428,14 +428,11 @@ class ElectrumWindow(App): def show_request(self, is_lightning, key): from .uix.dialogs.request_dialog import RequestDialog - if is_lightning: - request, direction, is_paid = self.wallet.lnworker.invoices.get(key) or (None, None, None) - status = self.wallet.lnworker.get_invoice_status(key) - else: - request = self.wallet.get_request_URI(key) - status, conf = self.wallet.get_request_status(key) - self.request_popup = RequestDialog('Request', request, key) - self.request_popup.set_status(status) + request = self.wallet.get_request(key) + status = request['status'] + data = request['invoice'] if is_lightning else request['URI'] + self.request_popup = RequestDialog('Request', data, key) + self.request_popup.set_status(request['status']) self.request_popup.open() def show_invoice(self, is_lightning, key): @@ -444,10 +441,7 @@ class ElectrumWindow(App): if not invoice: return status = invoice['status'] - if is_lightning: - data = invoice['invoice'] - else: - data = key + data = invoice['invoice'] if is_lightning else key self.invoice_popup = InvoiceDialog('Invoice', data, key) self.invoice_popup.open() diff --git a/electrum/gui/kivy/uix/screens.py b/electrum/gui/kivy/uix/screens.py index 556f80b50..507d1e713 100644 --- a/electrum/gui/kivy/uix/screens.py +++ b/electrum/gui/kivy/uix/screens.py @@ -304,12 +304,7 @@ class SendScreen(CScreen): return message = self.screen.message if self.screen.is_lightning: - return { - 'type': PR_TYPE_LN, - 'invoice': address, - 'amount': amount, - 'message': message, - } + return self.app.wallet.lnworker.parse_bech32_invoice(address) else: if not bitcoin.is_address(address): self.app.show_error(_('Invalid Bitcoin Address') + ':\n' + address) diff --git a/electrum/gui/qt/main_window.py b/electrum/gui/qt/main_window.py index 49e16a8ee..25bd5180b 100644 --- a/electrum/gui/qt/main_window.py +++ b/electrum/gui/qt/main_window.py @@ -1073,8 +1073,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): message = self.receive_message_e.text() expiry = self.config.get('request_expiry', 3600) if is_lightning: - payment_hash = self.wallet.lnworker.add_invoice(amount, message, expiry) - key = bh2u(payment_hash) + key = self.wallet.lnworker.add_request(amount, message, expiry) else: key = self.create_bitcoin_request(amount, message, expiry) self.address_list.update() @@ -1698,12 +1697,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): message = self.message_e.text() amount = self.amount_e.get_amount() if not self.is_onchain: - return { - 'type': PR_TYPE_LN, - 'invoice': self.payto_e.lightning_invoice, - 'amount': amount, - 'message': message, - } + return self.wallet.lnworker.parse_bech32_invoice(self.payto_e.lightning_invoice) else: outputs = self.read_outputs() if self.check_send_tab_outputs_and_show_errors(outputs): @@ -1733,7 +1727,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger): def do_pay_invoice(self, invoice, preview=False): if invoice['type'] == PR_TYPE_LN: - self.pay_lightning_invoice(self.payto_e.lightning_invoice) + self.pay_lightning_invoice(invoice['invoice']) return elif invoice['type'] == PR_TYPE_ONCHAIN: message = invoice['message'] diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 4e1c7ac41..41e05b3f8 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1402,13 +1402,13 @@ class Peer(Logger): await self.await_local(chan, local_ctn) await self.await_remote(chan, remote_ctn) try: - invoice = self.lnworker.get_invoice(htlc.payment_hash) + info = self.lnworker.get_invoice_info(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash) except UnknownPaymentHash: reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - expected_received_msat = int(invoice.amount * bitcoin.COIN * 1000) if invoice.amount is not None else None + expected_received_msat = int(info.amount * 1000) if info.amount is not None else None if expected_received_msat is not None and \ not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat): reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') @@ -1431,7 +1431,7 @@ class Peer(Logger): data=htlc.amount_msat.to_bytes(8, byteorder="big")) await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) return - self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) + #self.network.trigger_callback('htlc_added', htlc, invoice, RECEIVED) await asyncio.sleep(self.network.config.lightning_settle_delay) await self._fulfill_htlc(chan, htlc.htlc_id, preimage) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 2be0f9b35..c7134f69f 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -82,6 +82,16 @@ FALLBACK_NODE_LIST_MAINNET = [ encoder = ChannelJsonEncoder() + +from typing import NamedTuple + +class InvoiceInfo(NamedTuple): + payment_hash: bytes + amount: int + direction: int + status: int + + class LNWorker(Logger): def __init__(self, xprv): @@ -313,7 +323,7 @@ class LNWallet(LNWorker): LNWorker.__init__(self, xprv) self.ln_keystore = keystore.from_xprv(xprv) self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ - self.invoices = self.storage.get('lightning_invoices', {}) # RHASH -> (invoice, direction, is_paid) + self.invoices = self.storage.get('lightning_invoices2', {}) # RHASH -> amount, direction, is_paid self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage self.sweep_address = wallet.get_receiving_address() self.lock = threading.RLock() @@ -409,16 +419,6 @@ class LNWallet(LNWorker): timestamp = int(time.time()) self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id) - def get_invoice_status(self, key): - if key not in self.invoices: - return PR_UNKNOWN - invoice, direction, status = self.invoices[key] - lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - if status == PR_UNPAID and lnaddr.is_expired(): - return PR_EXPIRED - else: - return status - def get_payments(self): # return one item per payment_hash # note: with AMP we will have several channels per payment @@ -431,6 +431,20 @@ class LNWallet(LNWorker): out[k].append(v) return out + def parse_bech32_invoice(self, invoice): + lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + amount = int(lnaddr.amount * COIN) if lnaddr.amount else None + return { + 'type': PR_TYPE_LN, + 'invoice': invoice, + 'amount': amount, + 'message': lnaddr.get_description(), + 'time': lnaddr.date, + 'exp': lnaddr.get_expiry(), + 'pubkey': bh2u(lnaddr.pubkey.serialize()), + 'rhash': lnaddr.paymenthash.hex(), + } + def get_unsettled_payments(self): out = [] for payment_hash, plist in self.get_payments().items(): @@ -455,7 +469,7 @@ class LNWallet(LNWorker): def get_history(self): out = [] - for payment_hash, plist in self.get_payments().items(): + for key, plist in self.get_payments().items(): plist = list(filter(lambda x: x[3] == 'settled', plist)) if len(plist) == 0: continue @@ -464,11 +478,13 @@ class LNWallet(LNWorker): direction = 'sent' if _direction == SENT else 'received' amount_msat = int(_direction) * htlc.amount_msat timestamp = htlc.timestamp - label = self.wallet.get_label(payment_hash) - req = self.get_request(payment_hash) - if req and _direction == SENT: - req_amount_msat = -req['amount']*1000 - fee_msat = req_amount_msat - amount_msat + label = self.wallet.get_label(key) + if _direction == SENT: + try: + inv = self.get_invoice_info(bfh(key)) + fee_msat = inv.amount*1000 - amount_msat if inv.amount else None + except UnknownPaymentHash: + fee_msat = None else: fee_msat = None else: @@ -489,7 +505,7 @@ class LNWallet(LNWorker): 'status': status, 'amount_msat': amount_msat, 'fee_msat': fee_msat, - 'payment_hash': payment_hash + 'payment_hash': key, } out.append(item) # add funding events @@ -831,20 +847,23 @@ class LNWallet(LNWorker): @log_exceptions async def _pay(self, invoice, amount_sat=None, attempts=1): - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - key = bh2u(addr.paymenthash) - if key in self.preimages: + lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) + key = bh2u(lnaddr.paymenthash) + amount = int(lnaddr.amount * COIN) if lnaddr.amount else None + status = self.get_invoice_status(lnaddr.paymenthash) + if status == PR_PAID: raise PaymentFailure(_("This invoice has been paid already")) + info = InvoiceInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) + self.save_invoice_info(info) self._check_invoice(invoice, amount_sat) - self.save_invoice(addr.paymenthash, invoice, SENT, PR_INFLIGHT) - self.wallet.set_label(key, addr.get_description()) + self.wallet.set_label(key, lnaddr.get_description()) for i in range(attempts): - route = await self._create_route_from_invoice(decoded_invoice=addr) + route = await self._create_route_from_invoice(decoded_invoice=lnaddr) if not self.get_channel_by_short_id(route[0].short_channel_id): scid = route[0].short_channel_id raise Exception(f"Got route with unknown first channel: {scid}") self.network.trigger_callback('payment_status', key, 'progress', i) - if await self._pay_to_route(route, addr, invoice): + if await self._pay_to_route(route, lnaddr, invoice): return True return False @@ -854,10 +873,12 @@ class LNWallet(LNWorker): if not chan: raise Exception(f"PathFinder returned path with short_channel_id " f"{short_channel_id} that is not in channel list") + self.set_invoice_status(addr.paymenthash, PR_INFLIGHT) peer = self.peers[route[0].node_id] htlc = await peer.pay(route, chan, int(addr.amount * COIN * 1000), addr.paymenthash, addr.get_min_final_cltv_expiry()) self.network.trigger_callback('htlc_added', htlc, addr, SENT) success = await self.pending_payments[(short_channel_id, htlc.htlc_id)] + self.set_invoice_status(addr.paymenthash, (PR_PAID if success else PR_UNPAID)) return success @staticmethod @@ -933,119 +954,89 @@ class LNWallet(LNWorker): raise PaymentFailure(_("No path found")) return route - def add_invoice(self, amount_sat, message, expiry): - coro = self._add_invoice_coro(amount_sat, message, expiry) + def add_request(self, amount_sat, message, expiry): + coro = self._add_request_coro(amount_sat, message, expiry) fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop) try: return fut.result(timeout=5) except concurrent.futures.TimeoutError: - raise Exception(_("add_invoice timed out")) + raise Exception(_("add invoice timed out")) @log_exceptions - async def _add_invoice_coro(self, amount_sat, message, expiry): - payment_preimage = os.urandom(32) - payment_hash = sha256(payment_preimage) - amount_btc = amount_sat/Decimal(COIN) if amount_sat else None + async def _add_request_coro(self, amount_sat, message, expiry): + timestamp = int(time.time()) routing_hints = await self._calc_routing_hints_for_invoice(amount_sat) if not routing_hints: self.logger.info("Warning. No routing hints added to invoice. " "Other clients will likely not be able to send to us.") - invoice = lnencode(LnAddr(payment_hash, amount_btc, - tags=[('d', message), - ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), - ('x', expiry)] - + routing_hints), - self.node_keypair.privkey) - self.save_invoice(payment_hash, invoice, RECEIVED, PR_UNPAID) + payment_preimage = os.urandom(32) + payment_hash = sha256(payment_preimage) + info = InvoiceInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID) + amount_btc = amount_sat/Decimal(COIN) if amount_sat else None + lnaddr = LnAddr(payment_hash, amount_btc, + tags=[('d', message), + ('c', MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), + ('x', expiry)] + + routing_hints, + date = timestamp) + invoice = lnencode(lnaddr, self.node_keypair.privkey) + key = bh2u(lnaddr.paymenthash) + req = { + 'type': PR_TYPE_LN, + 'amount': amount_sat, + 'time': lnaddr.date, + 'exp': expiry, + 'message': message, + 'rhash': key, + 'invoice': invoice + } self.save_preimage(payment_hash, payment_preimage) - self.wallet.set_label(bh2u(payment_hash), message) - return payment_hash + self.save_invoice_info(info) + self.wallet.add_payment_request(req) + self.wallet.set_label(key, message) + return key def save_preimage(self, payment_hash: bytes, preimage: bytes): assert sha256(preimage) == payment_hash - key = bh2u(payment_hash) - self.preimages[key] = bh2u(preimage) + self.preimages[bh2u(payment_hash)] = bh2u(preimage) self.storage.put('lightning_preimages', self.preimages) self.storage.write() def get_preimage(self, payment_hash: bytes) -> bytes: - try: - preimage = bfh(self.preimages[bh2u(payment_hash)]) - assert sha256(preimage) == payment_hash - return preimage - except KeyError as e: - raise UnknownPaymentHash(payment_hash) from e + return bfh(self.preimages.get(bh2u(payment_hash))) - def save_new_invoice(self, invoice): - addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - self.save_invoice(addr.paymenthash, invoice, SENT, PR_UNPAID) - - def save_invoice(self, payment_hash:bytes, invoice, direction, status): - key = bh2u(payment_hash) + def get_invoice_info(self, payment_hash: bytes) -> bytes: + key = payment_hash.hex() with self.lock: - self.invoices[key] = invoice, direction, status - self.storage.put('lightning_invoices', self.invoices) + if key not in self.invoices: + raise UnknownPaymentHash(payment_hash) + amount, direction, status = self.invoices[key] + return InvoiceInfo(payment_hash, amount, direction, status) + + def save_invoice_info(self, info): + key = info.payment_hash.hex() + with self.lock: + self.invoices[key] = info.amount, info.direction, info.status + self.storage.put('lightning_invoices2', self.invoices) self.storage.write() - def set_invoice_status(self, payment_hash, status): - key = bh2u(payment_hash) - if key not in self.invoices: + def get_invoice_status(self, payment_hash): + try: + info = self.get_invoice_info(payment_hash) + return info.status + except UnknownPaymentHash: + return PR_UNKNOWN + + def set_invoice_status(self, payment_hash: bytes, status): + try: + info = self.get_invoice_info(payment_hash) + except UnknownPaymentHash: # if we are forwarding return - invoice, direction, _ = self.invoices[key] - self.save_invoice(payment_hash, invoice, direction, status) - if direction == RECEIVED and status == PR_PAID: - self.network.trigger_callback('payment_received', self.wallet, key, PR_PAID) - - def get_invoice(self, payment_hash: bytes) -> LnAddr: - try: - invoice, direction, is_paid = self.invoices[bh2u(payment_hash)] - return lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - except KeyError as e: - raise UnknownPaymentHash(payment_hash) from e - - def get_request(self, key): - if key not in self.invoices: - return - # todo: parse invoices when saving - invoice, direction, is_paid = self.invoices[key] - status = self.get_invoice_status(key) - lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) - amount_sat = int(lnaddr.amount*COIN) if lnaddr.amount else None - description = lnaddr.get_description() - timestamp = lnaddr.date - return { - 'type': PR_TYPE_LN, - 'status': status, - 'amount': amount_sat, - 'time': timestamp, - 'exp': lnaddr.get_expiry(), - 'message': description, - 'rhash': key, - 'invoice': invoice - } - - @profiler - def get_invoices(self): - # invoices = outgoing - out = [] - with self.lock: - invoice_items = list(self.invoices.items()) - for key, (invoice, direction, status) in invoice_items: - if direction == SENT and status != PR_PAID: - out.append(self.get_request(key)) - return out - - @profiler - def get_requests(self): - # requests = incoming - out = [] - with self.lock: - invoice_items = list(self.invoices.items()) - for key, (invoice, direction, status) in invoice_items: - if direction == RECEIVED and status != PR_PAID: - out.append(self.get_request(key)) - return out + info = info._replace(status=status) + self.save_invoice_info(info) + if info.direction == RECEIVED and info.status == PR_PAID: + self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID) async def _calc_routing_hints_for_invoice(self, amount_sat): """calculate routing hints (BOLT-11 'r' field)""" diff --git a/electrum/tests/regtest/regtest.sh b/electrum/tests/regtest/regtest.sh index 282835845..8a310cf69 100755 --- a/electrum/tests/regtest/regtest.sh +++ b/electrum/tests/regtest/regtest.sh @@ -114,7 +114,7 @@ if [[ $1 == "open" ]]; then fi if [[ $1 == "alice_pays_carol" ]]; then - request=$($carol addinvoice 0.0001 "blah") + request=$($carol add_lightning_request 0.0001 -m "blah") $alice lnpay $request carol_balance=$($carol list_channels | jq -r '.[0].local_balance') echo "carol balance: $carol_balance" @@ -140,12 +140,12 @@ if [[ $1 == "breach" ]]; then channel=$($alice open_channel $bob_node 0.15) new_blocks 3 wait_until_channel_open alice - request=$($bob addinvoice 0.01 "blah") + request=$($bob add_lightning_request 0.01 -m "blah") echo "alice pays" $alice lnpay $request sleep 2 ctx=$($alice get_channel_ctx $channel | jq '.hex' | tr -d '"') - request=$($bob addinvoice 0.01 "blah2") + request=$($bob add_lightning_request 0.01 -m "blah2") echo "alice pays again" $alice lnpay $request echo "alice broadcasts old ctx" @@ -168,7 +168,7 @@ if [[ $1 == "redeem_htlcs" ]]; then new_blocks 6 sleep 10 # alice pays bob - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true sleep 1 settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') @@ -214,7 +214,7 @@ if [[ $1 == "breach_with_unspent_htlc" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') if [[ "$settled" != "0" ]]; then @@ -246,7 +246,7 @@ if [[ $1 == "breach_with_spent_htlc" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice=$($bob addinvoice 0.05 "test") + invoice=$($bob add_lightning_request 0.05 -m "test") $alice lnpay $invoice --timeout=1 || true ctx=$($alice get_channel_ctx $channel | jq '.hex' | tr -d '"') settled=$($alice list_channels | jq '.[] | .local_htlcs | .settles | length') @@ -310,11 +310,11 @@ if [[ $1 == "watchtower" ]]; then new_blocks 3 wait_until_channel_open alice echo "alice pays bob" - invoice1=$($bob addinvoice 0.05 "invoice1") + invoice1=$($bob add_lightning_request 0.05 -m "invoice1") $alice lnpay $invoice1 - invoice2=$($bob addinvoice 0.05 "invoice2") + invoice2=$($bob add_lightning_request 0.05 -m "invoice2") $alice lnpay $invoice2 - invoice3=$($bob addinvoice 0.05 "invoice3") + invoice3=$($bob add_lightning_request 0.05 -m "invoice3") $alice lnpay $invoice3 fi diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 226901aad..cb355dcf7 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -21,6 +21,7 @@ from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet from electrum.lnmsg import encode_msg, decode_msg from electrum.logging import console_stderr_handler +from electrum.lnworker import InvoiceInfo, RECEIVED, PR_UNPAID from .test_lnchannel import create_test_channels from . import SequentialTestCase @@ -80,6 +81,7 @@ class MockWallet: pass class MockLNWallet: + storage = MockStorage() def __init__(self, remote_keypair, local_keypair, chan, tx_queue): self.chan = chan self.remote_keypair = remote_keypair @@ -87,7 +89,6 @@ class MockLNWallet: self.network = MockNetwork(tx_queue) self.channels = {self.chan.channel_id: self.chan} self.invoices = {} - self.preimages = {} self.inflight = {} self.wallet = MockWallet() self.localfeatures = LnLocalFeatures(0) @@ -122,7 +123,11 @@ class MockLNWallet: def save_invoice(*args, is_paid=False): pass - get_invoice = LNWallet.get_invoice + preimages = {} + get_invoice_info = LNWallet.get_invoice_info + save_invoice_info = LNWallet.save_invoice_info + set_invoice_status = LNWallet.set_invoice_status + save_preimage = LNWallet.save_preimage get_preimage = LNWallet.get_preimage _create_route_from_invoice = LNWallet._create_route_from_invoice _check_invoice = staticmethod(LNWallet._check_invoice) @@ -207,19 +212,20 @@ class TestPeer(SequentialTestCase): @staticmethod def prepare_invoice(w2 # receiver ): - amount_btc = 100000/Decimal(COIN) + amount_sat = 100000 + amount_btc = amount_sat/Decimal(COIN) payment_preimage = os.urandom(32) RHASH = sha256(payment_preimage) - addr = LnAddr( + info = InvoiceInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) + w2.save_preimage(RHASH, payment_preimage) + w2.save_invoice_info(info) + lnaddr = LnAddr( RHASH, amount_btc, tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), ('d', 'coffee') ]) - pay_req = lnencode(addr, w2.node_keypair.privkey) - w2.preimages[bh2u(RHASH)] = bh2u(payment_preimage) - w2.invoices[bh2u(RHASH)] = (pay_req, True, False) - return pay_req + return lnencode(lnaddr, w2.node_keypair.privkey) def test_payment(self): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers() diff --git a/electrum/wallet.py b/electrum/wallet.py index e9089dc64..c674da9a0 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -541,16 +541,16 @@ class Abstract_Wallet(AddressSynchronizer): def save_invoice(self, invoice): invoice_type = invoice['type'] if invoice_type == PR_TYPE_LN: - self.lnworker.save_new_invoice(invoice['invoice']) + key = invoice['rhash'] elif invoice_type == PR_TYPE_ONCHAIN: key = bh2u(sha256(repr(invoice))[0:16]) invoice['id'] = key invoice['txid'] = None - self.invoices[key] = invoice - self.storage.put('invoices', self.invoices) - self.storage.write() else: raise Exception('Unsupported invoice type') + self.invoices[key] = invoice + self.storage.put('invoices', self.invoices) + self.storage.write() def clear_invoices(self): self.invoices = {} @@ -560,29 +560,26 @@ class Abstract_Wallet(AddressSynchronizer): def get_invoices(self): out = [self.get_invoice(key) for key in self.invoices.keys()] out = [x for x in out if x and x.get('status') != PR_PAID] - if self.lnworker: - out += self.lnworker.get_invoices() out.sort(key=operator.itemgetter('time')) return out + def check_if_expired(self, item): + if item['status'] == PR_UNPAID and 'exp' in item and item['time'] + item['exp'] < time.time(): + item['status'] = PR_EXPIRED + def get_invoice(self, key): - if key in self.invoices: - item = copy.copy(self.invoices[key]) - request_type = item.get('type') - if request_type is None: - # todo: convert old bip70 invoices - return - # add status - if item.get('txid'): - status = PR_PAID - elif 'exp' in item and item['time'] + item['exp'] < time.time(): - status = PR_EXPIRED - else: - status = PR_UNPAID - item['status'] = status - return item - if self.lnworker: - return self.lnworker.get_request(key) + if key not in self.invoices: + return + item = copy.copy(self.invoices[key]) + request_type = item.get('type') + if request_type == PR_TYPE_ONCHAIN: + item['status'] = PR_PAID if item.get('txid') is not None else PR_UNPAID + elif request_type == PR_TYPE_LN: + item['status'] = self.lnworker.get_invoice_status(bfh(item['rhash'])) + else: + return + self.check_if_expired(item) + return item @profiler def get_full_history(self, fx=None, *, onchain_domain=None, include_lightning=True): @@ -1319,19 +1316,6 @@ class Abstract_Wallet(AddressSynchronizer): return True, conf return False, None - def get_payment_request(self, addr): - r = self.receive_requests.get(addr) - if not r: - return - out = copy.copy(r) - out['type'] = PR_TYPE_ONCHAIN - out['URI'] = self.get_request_URI(addr) - status, conf = self.get_request_status(addr) - out['status'] = status - if conf is not None: - out['confirmations'] = conf - return out - def get_request_URI(self, addr): req = self.receive_requests[addr] message = self.labels.get(addr, '') @@ -1349,11 +1333,10 @@ class Abstract_Wallet(AddressSynchronizer): uri = create_bip21_uri(addr, amount, message, extra_query_params=extra_query_params) return str(uri) - def get_request_status(self, key): - r = self.receive_requests.get(key) + def get_request_status(self, address): + r = self.receive_requests.get(address) if r is None: return PR_UNKNOWN - address = r['address'] amount = r.get('amount', 0) or 0 timestamp = r.get('time', 0) if timestamp and type(timestamp) != int: @@ -1372,14 +1355,23 @@ class Abstract_Wallet(AddressSynchronizer): return status, conf def get_request(self, key): - if key in self.receive_requests: - req = self.get_payment_request(key) - elif self.lnworker: - req = self.lnworker.get_request(key) - else: - req = None + req = self.receive_requests.get(key) if not req: return + req = copy.copy(req) + if req['type'] == PR_TYPE_ONCHAIN: + addr = req['address'] + req['URI'] = self.get_request_URI(addr) + status, conf = self.get_request_status(addr) + req['status'] = status + if conf is not None: + req['confirmations'] = conf + elif req['type'] == PR_TYPE_LN: + req['status'] = self.lnworker.get_invoice_status(bfh(key)) + else: + return + self.check_if_expired(req) + # add URL if we are running a payserver if self.config.get('payserver_port'): host = self.config.get('payserver_host', 'localhost') port = self.config.get('payserver_port') @@ -1405,8 +1397,16 @@ class Abstract_Wallet(AddressSynchronizer): from .bitcoin import TYPE_ADDRESS timestamp = int(time.time()) _id = bh2u(sha256d(addr + "%d"%timestamp))[0:10] - r = {'time':timestamp, 'amount':amount, 'exp':expiration, 'address':addr, 'memo':message, 'id':_id, 'outputs': [(TYPE_ADDRESS, addr, amount)]} - return r + return { + 'type': PR_TYPE_ONCHAIN, + 'time':timestamp, + 'amount':amount, + 'exp':expiration, + 'address':addr, + 'memo':message, + 'id':_id, + 'outputs': [(TYPE_ADDRESS, addr, amount)] + } def sign_payment_request(self, key, alias, alias_addr, password): req = self.receive_requests.get(key) @@ -1419,17 +1419,23 @@ class Abstract_Wallet(AddressSynchronizer): self.storage.put('payment_requests', self.receive_requests) def add_payment_request(self, req): - addr = req['address'] - if not bitcoin.is_address(addr): - raise Exception(_('Invalid Bitcoin address.')) - if not self.is_mine(addr): - raise Exception(_('Address not in wallet.')) - + if req['type'] == PR_TYPE_ONCHAIN: + addr = req['address'] + if not bitcoin.is_address(addr): + raise Exception(_('Invalid Bitcoin address.')) + if not self.is_mine(addr): + raise Exception(_('Address not in wallet.')) + key = addr + message = req['memo'] + elif req['type'] == PR_TYPE_LN: + key = req['rhash'] + message = req['message'] + else: + raise Exception('Unknown request type') amount = req.get('amount') - message = req.get('memo') - self.receive_requests[addr] = req + self.receive_requests[key] = req self.storage.put('payment_requests', self.receive_requests) - self.set_label(addr, message) # should be a default label + self.set_label(key, message) # should be a default label return req def delete_request(self, key): @@ -1457,8 +1463,6 @@ class Abstract_Wallet(AddressSynchronizer): def get_sorted_requests(self): """ sorted by timestamp """ out = [self.get_request(x) for x in self.receive_requests.keys()] - if self.lnworker: - out += self.lnworker.get_requests() out.sort(key=operator.itemgetter('time')) return out