lnworker: rename 'invoices' to 'payments' when they can be in both directions

This commit is contained in:
ThomasV 2019-10-09 20:16:11 +02:00
parent 788d54f9a6
commit 638de63f13
4 changed files with 40 additions and 43 deletions

View file

@ -1267,7 +1267,7 @@ class Peer(Logger):
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.receive_htlc_settle(preimage, htlc_id) chan.receive_htlc_settle(preimage, htlc_id)
self.lnworker.save_preimage(payment_hash, preimage) self.lnworker.save_preimage(payment_hash, preimage)
self.lnworker.set_invoice_status(payment_hash, PR_PAID) self.lnworker.set_payment_status(payment_hash, PR_PAID)
local_ctn = chan.get_latest_ctn(LOCAL) local_ctn = chan.get_latest_ctn(LOCAL)
asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn)) asyncio.ensure_future(self._on_update_fulfill_htlc(chan, htlc_id, preimage, local_ctn))
@ -1406,7 +1406,7 @@ class Peer(Logger):
await self.await_local(chan, local_ctn) await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn) await self.await_remote(chan, remote_ctn)
try: try:
info = self.lnworker.get_invoice_info(htlc.payment_hash) info = self.lnworker.get_payment_info(htlc.payment_hash)
preimage = self.lnworker.get_preimage(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash)
except UnknownPaymentHash: except UnknownPaymentHash:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
@ -1443,7 +1443,7 @@ class Peer(Logger):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
chan.settle_htlc(preimage, htlc_id) chan.settle_htlc(preimage, htlc_id)
payment_hash = sha256(preimage) payment_hash = sha256(preimage)
self.lnworker.set_invoice_status(payment_hash, PR_PAID) self.lnworker.set_payment_status(payment_hash, PR_PAID)
remote_ctn = chan.get_latest_ctn(REMOTE) remote_ctn = chan.get_latest_ctn(REMOTE)
self.send_message("update_fulfill_htlc", self.send_message("update_fulfill_htlc",
channel_id=chan.channel_id, channel_id=chan.channel_id,

View file

@ -85,7 +85,7 @@ encoder = ChannelJsonEncoder()
from typing import NamedTuple from typing import NamedTuple
class InvoiceInfo(NamedTuple): class PaymentInfo(NamedTuple):
payment_hash: bytes payment_hash: bytes
amount: int amount: int
direction: int direction: int
@ -324,7 +324,7 @@ class LNWallet(LNWorker):
LNWorker.__init__(self, xprv) LNWorker.__init__(self, xprv)
self.ln_keystore = keystore.from_xprv(xprv) self.ln_keystore = keystore.from_xprv(xprv)
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
self.invoices = self.storage.get('lightning_invoices2', {}) # RHASH -> amount, direction, is_paid self.payments = self.storage.get('lightning_payments', {}) # RHASH -> amount, direction, is_paid
self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage self.preimages = self.storage.get('lightning_preimages', {}) # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address() self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock() self.lock = threading.RLock()
@ -481,7 +481,7 @@ class LNWallet(LNWorker):
label = self.wallet.get_label(key) label = self.wallet.get_label(key)
if _direction == SENT: if _direction == SENT:
try: try:
inv = self.get_invoice_info(bfh(key)) inv = self.get_payment_info(bfh(key))
fee_msat = inv.amount*1000 - amount_msat if inv.amount else None fee_msat = inv.amount*1000 - amount_msat if inv.amount else None
except UnknownPaymentHash: except UnknownPaymentHash:
fee_msat = None fee_msat = None
@ -850,11 +850,11 @@ class LNWallet(LNWorker):
lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP) lnaddr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
key = bh2u(lnaddr.paymenthash) key = bh2u(lnaddr.paymenthash)
amount = int(lnaddr.amount * COIN) if lnaddr.amount else None amount = int(lnaddr.amount * COIN) if lnaddr.amount else None
status = self.get_invoice_status(lnaddr.paymenthash) status = self.get_payment_status(lnaddr.paymenthash)
if status == PR_PAID: if status == PR_PAID:
raise PaymentFailure(_("This invoice has been paid already")) raise PaymentFailure(_("This invoice has been paid already"))
info = InvoiceInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID) info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID)
self.save_invoice_info(info) self.save_payment_info(info)
self._check_invoice(invoice, amount_sat) self._check_invoice(invoice, amount_sat)
self.wallet.set_label(key, lnaddr.get_description()) self.wallet.set_label(key, lnaddr.get_description())
for i in range(attempts): for i in range(attempts):
@ -870,12 +870,12 @@ class LNWallet(LNWorker):
if not chan: if not chan:
raise Exception(f"PathFinder returned path with short_channel_id " raise Exception(f"PathFinder returned path with short_channel_id "
f"{short_channel_id} that is not in channel list") f"{short_channel_id} that is not in channel list")
self.set_invoice_status(lnaddr.paymenthash, PR_INFLIGHT) self.set_payment_status(lnaddr.paymenthash, PR_INFLIGHT)
peer = self.peers[route[0].node_id] peer = self.peers[route[0].node_id]
htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry()) htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT) self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT)
success = await self.pending_payments[(short_channel_id, htlc.htlc_id)] success = await self.pending_payments[(short_channel_id, htlc.htlc_id)]
self.set_invoice_status(lnaddr.paymenthash, (PR_PAID if success else PR_UNPAID)) self.set_payment_status(lnaddr.paymenthash, (PR_PAID if success else PR_UNPAID))
return success return success
@staticmethod @staticmethod
@ -968,7 +968,7 @@ class LNWallet(LNWorker):
"Other clients will likely not be able to send to us.") "Other clients will likely not be able to send to us.")
payment_preimage = os.urandom(32) payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage) payment_hash = sha256(payment_preimage)
info = InvoiceInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID) info = PaymentInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID)
amount_btc = amount_sat/Decimal(COIN) if amount_sat else None amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
lnaddr = LnAddr(payment_hash, amount_btc, lnaddr = LnAddr(payment_hash, amount_btc,
tags=[('d', message), tags=[('d', message),
@ -988,7 +988,7 @@ class LNWallet(LNWorker):
'invoice': invoice 'invoice': invoice
} }
self.save_preimage(payment_hash, payment_preimage) self.save_preimage(payment_hash, payment_preimage)
self.save_invoice_info(info) self.save_payment_info(info)
self.wallet.add_payment_request(req) self.wallet.add_payment_request(req)
self.wallet.set_label(key, message) self.wallet.set_label(key, message)
return key return key
@ -1002,36 +1002,36 @@ class LNWallet(LNWorker):
def get_preimage(self, payment_hash: bytes) -> bytes: def get_preimage(self, payment_hash: bytes) -> bytes:
return bfh(self.preimages.get(bh2u(payment_hash))) return bfh(self.preimages.get(bh2u(payment_hash)))
def get_invoice_info(self, payment_hash: bytes) -> bytes: def get_payment_info(self, payment_hash: bytes) -> bytes:
key = payment_hash.hex() key = payment_hash.hex()
with self.lock: with self.lock:
if key not in self.invoices: if key not in self.payments:
raise UnknownPaymentHash(payment_hash) raise UnknownPaymentHash(payment_hash)
amount, direction, status = self.invoices[key] amount, direction, status = self.payments[key]
return InvoiceInfo(payment_hash, amount, direction, status) return PaymentInfo(payment_hash, amount, direction, status)
def save_invoice_info(self, info): def save_payment_info(self, info):
key = info.payment_hash.hex() key = info.payment_hash.hex()
with self.lock: with self.lock:
self.invoices[key] = info.amount, info.direction, info.status self.payments[key] = info.amount, info.direction, info.status
self.storage.put('lightning_invoices2', self.invoices) self.storage.put('lightning_payments', self.payments)
self.storage.write() self.storage.write()
def get_invoice_status(self, payment_hash): def get_payment_status(self, payment_hash):
try: try:
info = self.get_invoice_info(payment_hash) info = self.get_payment_info(payment_hash)
return info.status return info.status
except UnknownPaymentHash: except UnknownPaymentHash:
return PR_UNKNOWN return PR_UNKNOWN
def set_invoice_status(self, payment_hash: bytes, status): def set_payment_status(self, payment_hash: bytes, status):
try: try:
info = self.get_invoice_info(payment_hash) info = self.get_payment_info(payment_hash)
except UnknownPaymentHash: except UnknownPaymentHash:
# if we are forwarding # if we are forwarding
return return
info = info._replace(status=status) info = info._replace(status=status)
self.save_invoice_info(info) self.save_payment_info(info)
if info.direction == RECEIVED and info.status == PR_PAID: if info.direction == RECEIVED and info.status == PR_PAID:
self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID) self.network.trigger_callback('payment_received', self.wallet, bh2u(payment_hash), PR_PAID)
@ -1077,13 +1077,13 @@ class LNWallet(LNWorker):
cltv_expiry_delta)])) cltv_expiry_delta)]))
return routing_hints return routing_hints
def delete_invoice(self, payment_hash_hex: str): def delete_payment(self, payment_hash_hex: str):
try: try:
with self.lock: with self.lock:
del self.invoices[payment_hash_hex] del self.payments[payment_hash_hex]
except KeyError: except KeyError:
return return
self.storage.put('lightning_invoices', self.invoices) self.storage.put('lightning_payments', self.payments)
self.storage.write() self.storage.write()
def get_balance(self): def get_balance(self):

View file

@ -21,7 +21,7 @@ from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet from electrum.lnworker import LNWallet
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler from electrum.logging import console_stderr_handler
from electrum.lnworker import InvoiceInfo, RECEIVED, PR_UNPAID from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
from . import ElectrumTestCase from . import ElectrumTestCase
@ -88,8 +88,7 @@ class MockLNWallet:
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
self.channels = {self.chan.channel_id: self.chan} self.channels = {self.chan.channel_id: self.chan}
self.invoices = {} self.payments = {}
self.inflight = {}
self.wallet = MockWallet() self.wallet = MockWallet()
self.localfeatures = LnLocalFeatures(0) self.localfeatures = LnLocalFeatures(0)
self.pending_payments = defaultdict(asyncio.Future) self.pending_payments = defaultdict(asyncio.Future)
@ -120,13 +119,11 @@ class MockLNWallet:
def on_channels_updated(self): def on_channels_updated(self):
pass pass
def save_invoice(*args, is_paid=False):
pass
preimages = {} preimages = {}
get_invoice_info = LNWallet.get_invoice_info get_payment_info = LNWallet.get_payment_info
save_invoice_info = LNWallet.save_invoice_info save_payment_info = LNWallet.save_payment_info
set_invoice_status = LNWallet.set_invoice_status set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
save_preimage = LNWallet.save_preimage save_preimage = LNWallet.save_preimage
get_preimage = LNWallet.get_preimage get_preimage = LNWallet.get_preimage
_create_route_from_invoice = LNWallet._create_route_from_invoice _create_route_from_invoice = LNWallet._create_route_from_invoice
@ -216,9 +213,9 @@ class TestPeer(ElectrumTestCase):
amount_btc = amount_sat/Decimal(COIN) amount_btc = amount_sat/Decimal(COIN)
payment_preimage = os.urandom(32) payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage) RHASH = sha256(payment_preimage)
info = InvoiceInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID)
w2.save_preimage(RHASH, payment_preimage) w2.save_preimage(RHASH, payment_preimage)
w2.save_invoice_info(info) w2.save_payment_info(info)
lnaddr = LnAddr( lnaddr = LnAddr(
RHASH, RHASH,
amount_btc, amount_btc,

View file

@ -574,7 +574,7 @@ class Abstract_Wallet(AddressSynchronizer):
if request_type == PR_TYPE_ONCHAIN: if request_type == PR_TYPE_ONCHAIN:
item['status'] = PR_PAID if item.get('txid') is not None else PR_UNPAID item['status'] = PR_PAID if item.get('txid') is not None else PR_UNPAID
elif self.lnworker and request_type == PR_TYPE_LN: elif self.lnworker and request_type == PR_TYPE_LN:
item['status'] = self.lnworker.get_invoice_status(bfh(item['rhash'])) item['status'] = self.lnworker.get_payment_status(bfh(item['rhash']))
else: else:
return return
self.check_if_expired(item) self.check_if_expired(item)
@ -1367,7 +1367,7 @@ class Abstract_Wallet(AddressSynchronizer):
if conf is not None: if conf is not None:
req['confirmations'] = conf req['confirmations'] = conf
elif self.lnworker and _type == PR_TYPE_LN: elif self.lnworker and _type == PR_TYPE_LN:
req['status'] = self.lnworker.get_invoice_status(bfh(key)) req['status'] = self.lnworker.get_payment_status(bfh(key))
else: else:
return return
self.check_if_expired(req) self.check_if_expired(req)
@ -1443,7 +1443,7 @@ class Abstract_Wallet(AddressSynchronizer):
if key in self.receive_requests: if key in self.receive_requests:
self.remove_payment_request(key) self.remove_payment_request(key)
elif self.lnworker: elif self.lnworker:
self.lnworker.delete_invoice(key) self.lnworker.delete_payment(key)
def delete_invoice(self, key): def delete_invoice(self, key):
""" lightning or on-chain """ """ lightning or on-chain """
@ -1451,7 +1451,7 @@ class Abstract_Wallet(AddressSynchronizer):
self.invoices.pop(key) self.invoices.pop(key)
self.storage.put('invoices', self.invoices) self.storage.put('invoices', self.invoices)
elif self.lnworker: elif self.lnworker:
self.lnworker.delete_invoice(key) self.lnworker.delete_payment(key)
def remove_payment_request(self, addr): def remove_payment_request(self, addr):
if addr not in self.receive_requests: if addr not in self.receive_requests: