mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-28 16:01:30 +00:00
lnworker._pay: allow specifying path as argument
not exposed to CLI/etc yet but will be used in tests soon
This commit is contained in:
parent
63b18dc30f
commit
7153e753d1
4 changed files with 50 additions and 13 deletions
|
@ -119,7 +119,7 @@ class Policy(NamedTuple):
|
||||||
return ShortChannelID.normalize(self.key[0:8])
|
return ShortChannelID.normalize(self.key[0:8])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_node(self):
|
def start_node(self) -> bytes:
|
||||||
return self.key[8:]
|
return self.key[8:]
|
||||||
|
|
||||||
|
|
||||||
|
@ -732,5 +732,18 @@ class ChannelDB(SqlDB):
|
||||||
relevant_channels.add(chan.short_channel_id)
|
relevant_channels.add(chan.short_channel_id)
|
||||||
return relevant_channels
|
return relevant_channels
|
||||||
|
|
||||||
|
def get_endnodes_for_chan(self, short_channel_id: ShortChannelID, *,
|
||||||
|
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[Tuple[bytes, bytes]]:
|
||||||
|
channel_info = self.get_channel_info(short_channel_id)
|
||||||
|
if channel_info is not None: # publicly announced channel
|
||||||
|
return channel_info.node1_id, channel_info.node2_id
|
||||||
|
# check if it's one of our own channels
|
||||||
|
if not my_channels:
|
||||||
|
return
|
||||||
|
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
|
||||||
|
if not chan:
|
||||||
|
return
|
||||||
|
return chan.get_local_pubkey(), chan.node_id
|
||||||
|
|
||||||
def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
|
def get_node_info_for_node_id(self, node_id: bytes) -> Optional['NodeInfo']:
|
||||||
return self._nodes.get(node_id)
|
return self._nodes.get(node_id)
|
||||||
|
|
|
@ -108,7 +108,8 @@ class Peer(Logger):
|
||||||
return
|
return
|
||||||
assert channel_id
|
assert channel_id
|
||||||
chan = self.get_channel_by_id(channel_id)
|
chan = self.get_channel_by_id(channel_id)
|
||||||
assert chan
|
if not chan:
|
||||||
|
raise Exception(f"channel {channel_id.hex()} not found for peer {self.pubkey.hex()}")
|
||||||
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
|
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
|
||||||
if is_commitment_signed:
|
if is_commitment_signed:
|
||||||
# saving now, to ensure replaying updates works (in case of channel reestablishment)
|
# saving now, to ensure replaying updates works (in case of channel reestablishment)
|
||||||
|
|
|
@ -45,6 +45,9 @@ class NoChannelPolicy(Exception):
|
||||||
super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')
|
super().__init__(f'cannot find channel policy for short_channel_id: {short_channel_id}')
|
||||||
|
|
||||||
|
|
||||||
|
class LNPathInconsistent(Exception): pass
|
||||||
|
|
||||||
|
|
||||||
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
|
def fee_for_edge_msat(forwarded_amount_msat: int, fee_base_msat: int, fee_proportional_millionths: int) -> int:
|
||||||
return fee_base_msat \
|
return fee_base_msat \
|
||||||
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
|
+ (forwarded_amount_msat * fee_proportional_millionths // 1_000_000)
|
||||||
|
@ -286,6 +289,9 @@ class LNPathFinder(Logger):
|
||||||
for edge in path:
|
for edge in path:
|
||||||
node_id = edge.node_id
|
node_id = edge.node_id
|
||||||
short_channel_id = edge.short_channel_id
|
short_channel_id = edge.short_channel_id
|
||||||
|
_endnodes = self.channel_db.get_endnodes_for_chan(short_channel_id, my_channels=my_channels)
|
||||||
|
if _endnodes and sorted(_endnodes) != sorted([prev_node_id, node_id]):
|
||||||
|
raise LNPathInconsistent("edges do not chain together")
|
||||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
||||||
node_id=prev_node_id,
|
node_id=prev_node_id,
|
||||||
my_channels=my_channels)
|
my_channels=my_channels)
|
||||||
|
|
|
@ -60,7 +60,8 @@ from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
|
||||||
from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket
|
from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket
|
||||||
from .lnmsg import decode_msg
|
from .lnmsg import decode_msg
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use
|
from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_to_use,
|
||||||
|
NoChannelPolicy, LNPathInconsistent)
|
||||||
from .address_synchronizer import TX_HEIGHT_LOCAL
|
from .address_synchronizer import TX_HEIGHT_LOCAL
|
||||||
from . import lnsweep
|
from . import lnsweep
|
||||||
from .lnwatcher import LNWalletWatcher
|
from .lnwatcher import LNWalletWatcher
|
||||||
|
@ -815,7 +816,8 @@ class LNWallet(LNWorker):
|
||||||
return chan
|
return chan
|
||||||
|
|
||||||
async def _pay(self, invoice: str, amount_sat: int = None, *,
|
async def _pay(self, invoice: str, amount_sat: int = None, *,
|
||||||
attempts: int = 1) -> Tuple[bool, List[PaymentAttemptLog]]:
|
attempts: int = 1,
|
||||||
|
full_path: LNPaymentPath = None) -> Tuple[bool, List[PaymentAttemptLog]]:
|
||||||
lnaddr = self._check_invoice(invoice, amount_sat)
|
lnaddr = self._check_invoice(invoice, amount_sat)
|
||||||
payment_hash = lnaddr.paymenthash
|
payment_hash = lnaddr.paymenthash
|
||||||
key = payment_hash.hex()
|
key = payment_hash.hex()
|
||||||
|
@ -837,7 +839,7 @@ class LNWallet(LNWorker):
|
||||||
# graph updates might occur during the computation
|
# graph updates might occur during the computation
|
||||||
self.set_invoice_status(key, PR_ROUTING)
|
self.set_invoice_status(key, PR_ROUTING)
|
||||||
util.trigger_callback('invoice_status', key)
|
util.trigger_callback('invoice_status', key)
|
||||||
route = await run_in_thread(self._create_route_from_invoice, lnaddr)
|
route = await run_in_thread(partial(self._create_route_from_invoice, lnaddr, full_path=full_path))
|
||||||
self.set_invoice_status(key, PR_INFLIGHT)
|
self.set_invoice_status(key, PR_INFLIGHT)
|
||||||
util.trigger_callback('invoice_status', key)
|
util.trigger_callback('invoice_status', key)
|
||||||
payment_attempt_log = await self._pay_to_route(route, lnaddr)
|
payment_attempt_log = await self._pay_to_route(route, lnaddr)
|
||||||
|
@ -974,7 +976,8 @@ class LNWallet(LNWorker):
|
||||||
return addr
|
return addr
|
||||||
|
|
||||||
@profiler
|
@profiler
|
||||||
def _create_route_from_invoice(self, decoded_invoice: 'LnAddr') -> LNPaymentRoute:
|
def _create_route_from_invoice(self, decoded_invoice: 'LnAddr',
|
||||||
|
*, full_path: LNPaymentPath = None) -> LNPaymentRoute:
|
||||||
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
amount_msat = int(decoded_invoice.amount * COIN * 1000)
|
||||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||||
# use 'r' field from invoice
|
# use 'r' field from invoice
|
||||||
|
@ -995,12 +998,22 @@ class LNWallet(LNWorker):
|
||||||
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
|
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
|
||||||
continue
|
continue
|
||||||
border_node_pubkey = private_route[0][0]
|
border_node_pubkey = private_route[0][0]
|
||||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
|
if full_path:
|
||||||
my_channels=scid_to_my_channels)
|
# user pre-selected path. check that end of given path coincides with private_route:
|
||||||
|
if [edge.short_channel_id for edge in full_path[-len(private_route):]] != [edge[1] for edge in private_route]:
|
||||||
|
continue
|
||||||
|
path = full_path[:-len(private_route)]
|
||||||
|
else:
|
||||||
|
# find path now on public graph, to border node
|
||||||
|
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
|
||||||
|
my_channels=scid_to_my_channels)
|
||||||
if not path:
|
if not path:
|
||||||
continue
|
continue
|
||||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
try:
|
||||||
my_channels=scid_to_my_channels)
|
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
||||||
|
my_channels=scid_to_my_channels)
|
||||||
|
except NoChannelPolicy:
|
||||||
|
continue
|
||||||
# we need to shift the node pubkey by one towards the destination:
|
# we need to shift the node pubkey by one towards the destination:
|
||||||
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
|
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
|
||||||
private_route_rest = [edge[1:] for edge in private_route]
|
private_route_rest = [edge[1:] for edge in private_route]
|
||||||
|
@ -1033,8 +1046,11 @@ class LNWallet(LNWorker):
|
||||||
break
|
break
|
||||||
# if could not find route using any hint; try without hint now
|
# if could not find route using any hint; try without hint now
|
||||||
if route is None:
|
if route is None:
|
||||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
|
if full_path: # user pre-selected path
|
||||||
my_channels=scid_to_my_channels)
|
path = full_path
|
||||||
|
else: # find path now
|
||||||
|
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
|
||||||
|
my_channels=scid_to_my_channels)
|
||||||
if not path:
|
if not path:
|
||||||
raise NoPathFound()
|
raise NoPathFound()
|
||||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
||||||
|
@ -1043,7 +1059,8 @@ class LNWallet(LNWorker):
|
||||||
self.logger.info(f"rejecting insane route {route}")
|
self.logger.info(f"rejecting insane route {route}")
|
||||||
raise NoPathFound()
|
raise NoPathFound()
|
||||||
assert len(route) > 0
|
assert len(route) > 0
|
||||||
assert route[-1].node_id == invoice_pubkey
|
if route[-1].node_id != invoice_pubkey:
|
||||||
|
raise LNPathInconsistent("last node_id != invoice pubkey")
|
||||||
# add features from invoice
|
# add features from invoice
|
||||||
invoice_features = decoded_invoice.get_tag('9') or 0
|
invoice_features = decoded_invoice.get_tag('9') or 0
|
||||||
route[-1].node_features |= invoice_features
|
route[-1].node_features |= invoice_features
|
||||||
|
|
Loading…
Add table
Reference in a new issue