mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
use 'r' field in invoice when making payments (routing hints)
This commit is contained in:
parent
029ec5a5ab
commit
97393d05aa
3 changed files with 49 additions and 22 deletions
|
@ -1018,19 +1018,18 @@ class Peer(PrintError):
|
|||
await self.receive_commitment(chan)
|
||||
self.revoke(chan)
|
||||
|
||||
async def pay(self, path, chan, amount_msat, payment_hash, pubkey_in_invoice, min_final_cltv_expiry):
|
||||
async def pay(self, route: List[RouteEdge], chan, amount_msat, payment_hash, min_final_cltv_expiry):
|
||||
assert chan.get_state() == "OPEN", chan.get_state()
|
||||
assert amount_msat > 0, "amount_msat is not greater zero"
|
||||
height = self.network.get_local_height()
|
||||
route = self.network.path_finder.create_route_from_path(path, self.lnworker.node_keypair.pubkey)
|
||||
hops_data = []
|
||||
sum_of_deltas = sum(route_edge.channel_policy.cltv_expiry_delta for route_edge in route[1:])
|
||||
sum_of_deltas = sum(route_edge.cltv_expiry_delta for route_edge in route[1:])
|
||||
total_fee = 0
|
||||
final_cltv_expiry_without_deltas = (height + min_final_cltv_expiry)
|
||||
final_cltv_expiry_with_deltas = final_cltv_expiry_without_deltas + sum_of_deltas
|
||||
for idx, route_edge in enumerate(route[1:]):
|
||||
hops_data += [OnionHopsDataSingle(OnionPerHop(route_edge.short_channel_id, amount_msat.to_bytes(8, "big"), final_cltv_expiry_without_deltas.to_bytes(4, "big")))]
|
||||
total_fee += route_edge.channel_policy.fee_base_msat + ( amount_msat * route_edge.channel_policy.fee_proportional_millionths // 1000000 )
|
||||
total_fee += route_edge.fee_base_msat + ( amount_msat * route_edge.fee_proportional_millionths // 1000000 )
|
||||
associated_data = payment_hash
|
||||
secret_key = os.urandom(32)
|
||||
hops_data += [OnionHopsDataSingle(OnionPerHop(b"\x00"*8, amount_msat.to_bytes(8, "big"), (final_cltv_expiry_without_deltas).to_bytes(4, "big")))]
|
||||
|
|
|
@ -28,7 +28,7 @@ import os
|
|||
import json
|
||||
import threading
|
||||
from collections import namedtuple, defaultdict
|
||||
from typing import Sequence, Union, Tuple, Optional
|
||||
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
|
||||
import binascii
|
||||
import base64
|
||||
import asyncio
|
||||
|
@ -478,14 +478,13 @@ class ChannelDB(JsonDB):
|
|||
direction))
|
||||
|
||||
|
||||
class RouteEdge:
|
||||
|
||||
def __init__(self, node_id: bytes, short_channel_id: bytes,
|
||||
channel_policy: ChannelInfoDirectedPolicy):
|
||||
# "if you travel through short_channel_id, you will reach node_id"
|
||||
self.node_id = node_id
|
||||
self.short_channel_id = short_channel_id
|
||||
self.channel_policy = channel_policy
|
||||
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
|
||||
('short_channel_id', bytes),
|
||||
('fee_base_msat', int),
|
||||
('fee_proportional_millionths', int),
|
||||
('cltv_expiry_delta', int)])):
|
||||
"""if you travel through short_channel_id, you will reach node_id"""
|
||||
pass
|
||||
|
||||
|
||||
class LNPathFinder(PrintError):
|
||||
|
@ -578,7 +577,7 @@ class LNPathFinder(PrintError):
|
|||
path.reverse()
|
||||
return path
|
||||
|
||||
def create_route_from_path(self, path, from_node_id: bytes) -> Sequence[RouteEdge]:
|
||||
def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
|
||||
assert type(from_node_id) is bytes
|
||||
if path is None:
|
||||
raise Exception('cannot create route from None path')
|
||||
|
@ -591,6 +590,10 @@ class LNPathFinder(PrintError):
|
|||
channel_policy = channel_info.get_policy_for_node(prev_node_id)
|
||||
if channel_policy is None:
|
||||
raise Exception('cannot find channel policy for short_channel_id: {}'.format(bh2u(short_channel_id)))
|
||||
route.append(RouteEdge(node_id, short_channel_id, channel_policy))
|
||||
route.append(RouteEdge(node_id,
|
||||
short_channel_id,
|
||||
channel_policy.fee_base_msat,
|
||||
channel_policy.fee_proportional_millionths,
|
||||
channel_policy.cltv_expiry_delta))
|
||||
prev_node_id = node_id
|
||||
return route
|
||||
|
|
|
@ -27,6 +27,7 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
|
|||
from .lnutil import LOCAL, REMOTE
|
||||
from .lnaddr import lndecode
|
||||
from .i18n import _
|
||||
from .lnrouter import RouteEdge
|
||||
|
||||
|
||||
NUM_PEERS_TARGET = 4
|
||||
|
@ -237,16 +238,12 @@ class LNWorker(PrintError):
|
|||
def pay(self, invoice, amount_sat=None):
|
||||
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
||||
payment_hash = addr.paymenthash
|
||||
invoice_pubkey = addr.pubkey.serialize()
|
||||
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
|
||||
if amount_sat is None:
|
||||
raise InvoiceError(_("Missing amount"))
|
||||
amount_msat = int(amount_sat * 1000)
|
||||
# TODO use 'r' field from invoice
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat)
|
||||
if path is None:
|
||||
raise PaymentFailure(_("No path found"))
|
||||
node_id, short_channel_id = path[0]
|
||||
route = self._create_route_from_invoice(decoded_invoice=addr, amount_msat=amount_msat)
|
||||
node_id, short_channel_id = route[0].node_id, route[0].short_channel_id
|
||||
peer = self.peers[node_id]
|
||||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
|
@ -255,9 +252,37 @@ class LNWorker(PrintError):
|
|||
break
|
||||
else:
|
||||
raise Exception("ChannelDB returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
||||
coro = peer.pay(path, chan, amount_msat, payment_hash, invoice_pubkey, addr.min_final_cltv_expiry)
|
||||
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.min_final_cltv_expiry)
|
||||
return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||
|
||||
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
|
||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||
# use 'r' field from invoice
|
||||
route = None # type: List[RouteEdge]
|
||||
for tag_type, data in decoded_invoice.tags:
|
||||
if tag_type != 'r': continue
|
||||
private_route = data
|
||||
if len(private_route) == 0: continue
|
||||
border_node_pubkey = private_route[0][0]
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat)
|
||||
if path is None: continue
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
|
||||
# we need to shift the node pubkey by one towards the destination:
|
||||
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
|
||||
private_route_rest = [edge[1:] for edge in private_route]
|
||||
for node_pubkey, edge_rest in zip(private_route_nodes, private_route_rest):
|
||||
short_channel_id, fee_base_msat, fee_proportional_millionths, cltv_expiry_delta = edge_rest
|
||||
route.append(RouteEdge(node_pubkey, short_channel_id, fee_base_msat, fee_proportional_millionths,
|
||||
cltv_expiry_delta))
|
||||
break
|
||||
# if could not find route using any hint; try without hint now
|
||||
if route is None:
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat)
|
||||
if path is None:
|
||||
raise PaymentFailure(_("No path found"))
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
|
||||
return route
|
||||
|
||||
def add_invoice(self, amount_sat, message):
|
||||
payment_preimage = os.urandom(32)
|
||||
RHASH = sha256(payment_preimage)
|
||||
|
|
Loading…
Add table
Reference in a new issue