ln gossip: don't put own channels into db; always pass them to fn calls

Previously we would put fake chan announcement and fake outgoing chan upd
for own channels into db (to make path finding work). See Peer.add_own_channel().
Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
This commit is contained in:
SomberNight 2020-02-17 20:38:41 +01:00
parent 7d65fe1ba3
commit 46d8080c76
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
6 changed files with 188 additions and 149 deletions

View file

@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg
if TYPE_CHECKING:
from .network import Network
from .lnchannel import Channel
class UnknownEvenFeatureBits(Exception): pass
@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
capacity_sat: Optional[int]
@staticmethod
def from_msg(payload):
def from_msg(payload: dict) -> 'ChannelInfo':
features = int.from_bytes(payload['features'], 'big')
validate_features(features)
channel_id = payload['short_channel_id']
@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
capacity_sat = capacity_sat
)
@staticmethod
def from_raw_msg(raw: bytes) -> 'ChannelInfo':
payload_dict = decode_msg(raw)[1]
return ChannelInfo.from_msg(payload_dict)
class Policy(NamedTuple):
key: bytes
@ -91,7 +98,7 @@ class Policy(NamedTuple):
timestamp: int
@staticmethod
def from_msg(payload):
def from_msg(payload: dict) -> 'Policy':
return Policy(
key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data
self._channels = {} # type: Dict[bytes, ChannelInfo]
self._policies = {}
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
self._nodes = {}
# node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
self._channels_for_node = defaultdict(set)
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self.data_loaded = asyncio.Event()
self.network = network # only for callback
@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts()
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[Policy]:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
return self.get_policy_for_node(short_channel_id, start_node_id)
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg:
return None
return Policy.from_msg(msg) # won't actually be written to DB
def get_old_policies(self, delta):
now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
out.add(short_channel_id)
self.logger.info(f'semi-orphaned: {len(out)}')
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
return self._policies.get((node_id, short_channel_id))
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id))
if policy:
return policy
else: # private channel
chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
if chan_upd_dict:
return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels
if not my_channels:
return
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
if not chan:
return
if node_id == chan.node_id: # incoming direction (to us)
remote_update_raw = chan.get_remote_update()
if not remote_update_raw:
return
now = int(time.time())
remote_update_decoded = decode_msg(remote_update_raw)[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
remote_update_decoded['start_node'] = node_id
return Policy.from_msg(remote_update_decoded)
elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
local_update_decoded['start_node'] = node_id
return Policy.from_msg(local_update_decoded)
def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
return self._channels.get(channel_id)
def get_channel_info(self, short_channel_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
ret = self._channels.get(short_channel_id)
if ret:
return ret
# check if it's one of our own channels
if not my_channels:
return
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
return ci._replace(capacity_sat=chan.constraints.capacity)
def get_channels_for_node(self, node_id) -> Set[bytes]:
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node.get(node_id) or set()
def get_channels_for_node(self, node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow?
for chan in (my_channels.values() or []):
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
return relevant_channels

View file

@ -32,13 +32,14 @@ import time
import threading
from . import ecc
from . import constants
from .util import bfh, bh2u
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
from .transaction import Transaction, PartialTransaction
from .logging import Logger
from .lnonion import decode_onion_error
from . import lnutil
from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
ShortChannelID, map_htlcs_to_ctx_output_idxs)
from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
from .lnhtlc import HTLCManager
from .lnmsg import encode_msg, decode_msg
if TYPE_CHECKING:
from .lnworker import LNWallet
@ -136,7 +137,6 @@ class Channel(Logger):
self.funding_outpoint = state["funding_outpoint"]
self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = state['onion_keys']
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
@ -144,6 +144,7 @@ class Channel(Logger):
self.peer_state = peer_states.DISCONNECTED
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._outgoing_channel_update = None # type: Optional[bytes]
self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"])
def set_onion_key(self, key, value):
@ -158,12 +159,77 @@ class Channel(Logger):
def get_data_loss_protect_remote_pcp(self, key):
return self.data_loss_protect_remote_pcp.get(key)
def set_remote_update(self, raw):
def get_local_pubkey(self) -> bytes:
if not self.lnworker:
raise Exception('lnworker not set for channel!')
return self.lnworker.node_keypair.pubkey
def set_remote_update(self, raw: bytes) -> None:
self.storage['remote_update'] = raw.hex()
def get_remote_update(self):
def get_remote_update(self) -> Optional[bytes]:
return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
def get_outgoing_gossip_channel_update(self) -> bytes:
if self._outgoing_channel_update is not None:
return self._outgoing_channel_update
if not self.lnworker:
raise Exception('lnworker not set for channel!')
sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()]))
channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01'
now = int(time.time())
htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity)
chan_upd = encode_msg(
"channel_update",
short_channel_id=self.short_channel_id,
channel_flags=channel_flags,
message_flags=b'\x01',
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
chain_hash=constants.net.rev_genesis_bytes(),
timestamp=now.to_bytes(4, byteorder="big"),
)
sighash = sha256d(chan_upd[2 + 64:])
sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s)
message_type, payload = decode_msg(chan_upd)
payload['signature'] = sig
chan_upd = encode_msg(message_type, **payload)
self._outgoing_channel_update = chan_upd
return chan_upd
def construct_channel_announcement_without_sigs(self) -> bytes:
if self._chan_ann_without_sigs is not None:
return self._chan_ann_without_sigs
if not self.lnworker:
raise Exception('lnworker not set for channel!')
bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey,
self.config[LOCAL].multisig_key.pubkey]
node_ids = [self.node_id, self.get_local_pubkey()]
sorted_node_ids = list(sorted(node_ids))
if sorted_node_ids != node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
chan_ann = encode_msg("channel_announcement",
len=0,
features=b'',
chain_hash=constants.net.rev_genesis_bytes(),
short_channel_id=self.short_channel_id,
node_id_1=node_ids[0],
node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0],
bitcoin_key_2=bitcoin_keys[1]
)
self._chan_ann_without_sigs = chan_ann
return chan_ann
def set_short_channel_id(self, short_id):
self.short_channel_id = short_id
self.storage["short_channel_id"] = short_id

View file

@ -953,112 +953,25 @@ class Peer(Logger):
assert chan.config[LOCAL].funding_locked_received
chan.set_state(channel_states.OPEN)
self.network.trigger_callback('channel', chan)
self.add_own_channel(chan)
# peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
chan.set_remote_update(pending_channel_update['raw'])
self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}")
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
if forwarding_enabled:
# send channel_update of outgoing edge to peer,
# so that channel can be used to to receive payments
self.logger.info(f"sending channel update for outgoing edge of {scid}")
chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan)
chan_upd = chan.get_outgoing_gossip_channel_update()
self.transport.send_bytes(chan_upd)
def add_own_channel(self, chan):
# add channel to database
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != self.node_ids:
bitcoin_keys.reverse()
# note: we inject a channel announcement, and a channel update (for outgoing direction)
# This is atm needed for
# - finding routes
# - the ChanAnn is needed so that we can anchor to it a future ChanUpd
# that the remote sends, even if the channel was not announced
# (from BOLT-07: "MAY create a channel_update to communicate the channel
# parameters to the final node, even though the channel has not yet been announced")
self.channel_db.add_channel_announcement(
{
"short_channel_id": chan.short_channel_id,
"node_id_1": sorted_node_ids[0],
"node_id_2": sorted_node_ids[1],
'chain_hash': constants.net.rev_genesis_bytes(),
'len': b'\x00\x00',
'features': b'',
'bitcoin_key_1': bitcoin_keys[0],
'bitcoin_key_2': bitcoin_keys[1]
},
trusted=True)
# only inject outgoing direction:
chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan)
chan_upd_payload = decode_msg(chan_upd_bytes)[1]
self.channel_db.add_channel_update(chan_upd_payload)
# peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
chan.set_remote_update(pending_channel_update['raw'])
# add remote update with a fresh timestamp
if chan.get_remote_update():
now = int(time.time())
remote_update_decoded = decode_msg(chan.get_remote_update())[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
self.channel_db.add_channel_update(remote_update_decoded)
def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes:
if chan._outgoing_channel_update is not None:
return chan._outgoing_channel_update
sorted_node_ids = list(sorted(self.node_ids))
channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01'
now = int(time.time())
htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity)
chan_upd = encode_msg(
"channel_update",
short_channel_id=chan.short_channel_id,
channel_flags=channel_flags,
message_flags=b'\x01',
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
chain_hash=constants.net.rev_genesis_bytes(),
timestamp=now.to_bytes(4, byteorder="big"),
)
sighash = sha256d(chan_upd[2 + 64:])
sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s)
message_type, payload = decode_msg(chan_upd)
payload['signature'] = sig
chan_upd = encode_msg(message_type, **payload)
chan._outgoing_channel_update = chan_upd
return chan_upd
def send_announcement_signatures(self, chan: Channel):
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
chan.config[LOCAL].multisig_key.pubkey]
sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != self.node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
else:
node_ids = self.node_ids
chan_ann = encode_msg("channel_announcement",
len=0,
#features not set (defaults to zeros)
chain_hash=constants.net.rev_genesis_bytes(),
short_channel_id=chan.short_channel_id,
node_id_1=node_ids[0],
node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0],
bitcoin_key_2=bitcoin_keys[1]
)
to_hash = chan_ann[256+2:]
h = sha256d(to_hash)
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s)
node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s)
chan_ann = chan.construct_channel_announcement_without_sigs()
preimage = chan_ann[256+2:]
msg_hash = sha256d(preimage)
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s)
node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s)
self.send_message("announcement_signatures",
channel_id=chan.channel_id,
short_channel_id=chan.short_channel_id,
@ -1066,7 +979,7 @@ class Peer(Logger):
bitcoin_signature=bitcoin_signature
)
return h, node_signature, bitcoin_signature
return msg_hash, node_signature, bitcoin_signature
def on_update_fail_htlc(self, payload):
channel_id = payload["channel_id"]
@ -1255,7 +1168,7 @@ class Peer(Logger):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return
outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:]
outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
if next_chan.get_state() != channel_states.OPEN:
self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}")

View file

@ -129,18 +129,20 @@ class LNPathFinder(Logger):
self.blacklist.add(short_channel_id)
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]:
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
"""Heuristic cost of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat).
"""
channel_info = self.channel_db.get_channel_info(short_channel_id)
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
if channel_info is None:
return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
if channel_policy is None:
return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure
if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine:
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
and not is_mine:
return float('inf'), 0
if channel_policy.is_disabled():
return float('inf'), 0
@ -164,8 +166,9 @@ class LNPathFinder(Logger):
@profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int,
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
invoice_amount_msat: int, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
-> Optional[Sequence[Tuple[bytes, bytes]]]:
"""Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path.
@ -175,8 +178,7 @@ class LNPathFinder(Logger):
assert type(nodeA) is bytes
assert type(nodeB) is bytes
assert type(invoice_amount_msat) is int
if my_channels is None: my_channels = []
my_channels = {chan.short_channel_id: chan for chan in my_channels}
if my_channels is None: my_channels = {}
# FIXME paths cannot be longer than 20 edges (onion packet)...
@ -204,7 +206,8 @@ class LNPathFinder(Logger):
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA),
is_mine=is_mine)
is_mine=is_mine,
my_channels=my_channels)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
@ -222,11 +225,11 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now:
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
assert isinstance(edge_channel_id, bytes)
if edge_channel_id in self.blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id)
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
inspect_edge()
else:
@ -241,14 +244,17 @@ class LNPathFinder(Logger):
edge_startnode = edge_endnode
return path
def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute:
def create_route_from_path(self, path, from_node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
assert isinstance(from_node_id, bytes)
if path is None:
raise Exception('cannot create route from None path')
route = []
prev_node_id = from_node_id
for node_id, short_channel_id in path:
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=my_channels)
if channel_policy is None:
raise NoChannelPolicy(short_channel_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))

View file

@ -942,16 +942,20 @@ class LNWallet(LNWorker):
random.shuffle(r_tags)
with self.lock:
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
for private_route in r_tags:
if len(private_route) == 0:
continue
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue
border_node_pubkey = private_route[0][0]
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
my_channels=scid_to_my_channels)
if not path:
continue
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
my_channels=scid_to_my_channels)
# we need to shift the node pubkey by one towards the destination:
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
private_route_rest = [edge[1:] for edge in private_route]
@ -961,7 +965,9 @@ class LNWallet(LNWorker):
short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=scid_to_my_channels)
if channel_policy:
fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths
@ -977,10 +983,12 @@ class LNWallet(LNWorker):
break
# if could not find route using any hint; try without hint now
if route is None:
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels)
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
my_channels=scid_to_my_channels)
if not path:
raise NoPathFound()
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
my_channels=scid_to_my_channels)
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
self.logger.info(f"rejecting insane route {route}")
raise NoPathFound()
@ -1099,6 +1107,8 @@ class LNWallet(LNWorker):
routing_hints = []
with self.lock:
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
# note: currently we add *all* our channels; but this might be a privacy leak?
for chan in channels:
# check channel is open
@ -1110,7 +1120,7 @@ class LNWallet(LNWorker):
continue
chan_id = chan.short_channel_id
assert isinstance(chan_id, bytes), chan_id
channel_info = self.channel_db.get_channel_info(chan_id)
channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels)
# note: as a fallback, if we don't have a channel update for the
# incoming direction of our private channel, we fill the invoice with garbage.
# the sender should still be able to pay us, but will incur an extra round trip
@ -1120,7 +1130,8 @@ class LNWallet(LNWorker):
cltv_expiry_delta = 1 # lnd won't even try with zero
missing_info = True
if channel_info:
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id,
my_channels=scid_to_my_channels)
if policy:
fee_base_msat = policy.fee_base_msat
fee_proportional_millionths = policy.fee_proportional_millionths

View file

@ -18,7 +18,7 @@ from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnLocalFeatures
from electrum.lnchannel import channel_states, peer_states
from electrum.lnchannel import channel_states, peer_states, Channel
from electrum.lnrouter import LNPathFinder
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound
@ -77,7 +77,7 @@ class MockWallet:
return False
class MockLNWallet:
def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
@ -88,6 +88,8 @@ class MockLNWallet:
self.localfeatures = LnLocalFeatures(0)
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.pending_payments = defaultdict(asyncio.Future)
chan.lnworker = self
chan.node_id = remote_keypair.pubkey
def get_invoice_status(self, key):
pass
@ -127,6 +129,7 @@ class MockLNWallet:
_pay_to_route = LNWallet._pay_to_route
force_close_channel = LNWallet.force_close_channel
get_first_timestamp = lambda self: 0
payment_completed = LNWallet.payment_completed
class MockTransport:
def __init__(self, name):
@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase):
pay_req = self.prepare_invoice(w2)
async def pay():
result = await LNWallet._pay(w1, pay_req)
self.assertEqual(result, True)
self.assertTrue(result)
gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
async def f():