ln gossip: don't put own channels into db; always pass them to fn calls

Previously we would put fake chan announcement and fake outgoing chan upd
for own channels into db (to make path finding work). See Peer.add_own_channel().
Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
This commit is contained in:
SomberNight 2020-02-17 20:38:41 +01:00
parent 7d65fe1ba3
commit 46d8080c76
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
6 changed files with 188 additions and 149 deletions

View file

@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
from .logging import Logger from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .lnmsg import decode_msg
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
from .lnchannel import Channel
class UnknownEvenFeatureBits(Exception): pass class UnknownEvenFeatureBits(Exception): pass
@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
capacity_sat: Optional[int] capacity_sat: Optional[int]
@staticmethod @staticmethod
def from_msg(payload): def from_msg(payload: dict) -> 'ChannelInfo':
features = int.from_bytes(payload['features'], 'big') features = int.from_bytes(payload['features'], 'big')
validate_features(features) validate_features(features)
channel_id = payload['short_channel_id'] channel_id = payload['short_channel_id']
@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
capacity_sat = capacity_sat capacity_sat = capacity_sat
) )
@staticmethod
def from_raw_msg(raw: bytes) -> 'ChannelInfo':
payload_dict = decode_msg(raw)[1]
return ChannelInfo.from_msg(payload_dict)
class Policy(NamedTuple): class Policy(NamedTuple):
key: bytes key: bytes
@ -91,7 +98,7 @@ class Policy(NamedTuple):
timestamp: int timestamp: int
@staticmethod @staticmethod
def from_msg(payload): def from_msg(payload: dict) -> 'Policy':
return Policy( return Policy(
key = payload['short_channel_id'] + payload['start_node'], key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"), cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data # initialized in load_data
self._channels = {} # type: Dict[bytes, ChannelInfo] self._channels = {} # type: Dict[bytes, ChannelInfo]
self._policies = {} self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
self._nodes = {} self._nodes = {}
# node_id -> (host, port, ts) # node_id -> (host, port, ts)
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]] self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
self._channels_for_node = defaultdict(set) self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
self.data_loaded = asyncio.Event() self.data_loaded = asyncio.Event()
self.network = network # only for callback self.network = network # only for callback
@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
self.update_counts() self.update_counts()
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[Policy]:
if not start_node_id or not short_channel_id: return None
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None:
return self.get_policy_for_node(short_channel_id, start_node_id)
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg:
return None
return Policy.from_msg(msg) # won't actually be written to DB
def get_old_policies(self, delta): def get_old_policies(self, delta):
now = int(time.time()) now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
out.add(short_channel_id) out.add(short_channel_id)
self.logger.info(f'semi-orphaned: {len(out)}') self.logger.info(f'semi-orphaned: {len(out)}')
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']: def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
return self._policies.get((node_id, short_channel_id)) my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id))
if policy:
return policy
else: # private channel
chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
if chan_upd_dict:
return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels
if not my_channels:
return
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
if not chan:
return
if node_id == chan.node_id: # incoming direction (to us)
remote_update_raw = chan.get_remote_update()
if not remote_update_raw:
return
now = int(time.time())
remote_update_decoded = decode_msg(remote_update_raw)[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
remote_update_decoded['start_node'] = node_id
return Policy.from_msg(remote_update_decoded)
elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
local_update_decoded['start_node'] = node_id
return Policy.from_msg(local_update_decoded)
def get_channel_info(self, channel_id: bytes) -> ChannelInfo: def get_channel_info(self, short_channel_id: bytes, *,
return self._channels.get(channel_id) my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
ret = self._channels.get(short_channel_id)
if ret:
return ret
# check if it's one of our own channels
if not my_channels:
return
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
return ci._replace(capacity_sat=chan.constraints.capacity)
def get_channels_for_node(self, node_id) -> Set[bytes]: def get_channels_for_node(self, node_id: bytes, *,
"""Returns the set of channels that have node_id as one of the endpoints.""" my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
return self._channels_for_node.get(node_id) or set() """Returns the set of short channel IDs where node_id is one of the channel participants."""
relevant_channels = self._channels_for_node.get(node_id) or set()
relevant_channels = set(relevant_channels) # copy
# add our own channels # TODO maybe slow?
for chan in (my_channels.values() or []):
if node_id in (chan.node_id, chan.get_local_pubkey()):
relevant_channels.add(chan.short_channel_id)
return relevant_channels

View file

@ -32,13 +32,14 @@ import time
import threading import threading
from . import ecc from . import ecc
from . import constants
from .util import bfh, bh2u from .util import bfh, bh2u
from .bitcoin import redeem_script_to_address from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d from .crypto import sha256, sha256d
from .transaction import Transaction, PartialTransaction from .transaction import Transaction, PartialTransaction
from .logging import Logger from .logging import Logger
from .lnonion import decode_onion_error from .lnonion import decode_onion_error
from . import lnutil
from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints, from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx, get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey, sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs, funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script, ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
ShortChannelID, map_htlcs_to_ctx_output_idxs) ShortChannelID, map_htlcs_to_ctx_output_idxs)
from .lnutil import FeeUpdate
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
from .lnhtlc import HTLCManager from .lnhtlc import HTLCManager
from .lnmsg import encode_msg, decode_msg
if TYPE_CHECKING: if TYPE_CHECKING:
from .lnworker import LNWallet from .lnworker import LNWallet
@ -136,7 +137,6 @@ class Channel(Logger):
self.funding_outpoint = state["funding_outpoint"] self.funding_outpoint = state["funding_outpoint"]
self.node_id = bfh(state["node_id"]) self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"]) self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.short_channel_id_predicted = self.short_channel_id
self.onion_keys = state['onion_keys'] self.onion_keys = state['onion_keys']
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
@ -144,6 +144,7 @@ class Channel(Logger):
self.peer_state = peer_states.DISCONNECTED self.peer_state = peer_states.DISCONNECTED
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]] self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._outgoing_channel_update = None # type: Optional[bytes] self._outgoing_channel_update = None # type: Optional[bytes]
self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"]) self.revocation_store = RevocationStore(state["revocation_store"])
def set_onion_key(self, key, value): def set_onion_key(self, key, value):
@ -158,12 +159,77 @@ class Channel(Logger):
def get_data_loss_protect_remote_pcp(self, key): def get_data_loss_protect_remote_pcp(self, key):
return self.data_loss_protect_remote_pcp.get(key) return self.data_loss_protect_remote_pcp.get(key)
def set_remote_update(self, raw): def get_local_pubkey(self) -> bytes:
if not self.lnworker:
raise Exception('lnworker not set for channel!')
return self.lnworker.node_keypair.pubkey
def set_remote_update(self, raw: bytes) -> None:
self.storage['remote_update'] = raw.hex() self.storage['remote_update'] = raw.hex()
def get_remote_update(self): def get_remote_update(self) -> Optional[bytes]:
return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
def get_outgoing_gossip_channel_update(self) -> bytes:
if self._outgoing_channel_update is not None:
return self._outgoing_channel_update
if not self.lnworker:
raise Exception('lnworker not set for channel!')
sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()]))
channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01'
now = int(time.time())
htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity)
chan_upd = encode_msg(
"channel_update",
short_channel_id=self.short_channel_id,
channel_flags=channel_flags,
message_flags=b'\x01',
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
chain_hash=constants.net.rev_genesis_bytes(),
timestamp=now.to_bytes(4, byteorder="big"),
)
sighash = sha256d(chan_upd[2 + 64:])
sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s)
message_type, payload = decode_msg(chan_upd)
payload['signature'] = sig
chan_upd = encode_msg(message_type, **payload)
self._outgoing_channel_update = chan_upd
return chan_upd
def construct_channel_announcement_without_sigs(self) -> bytes:
if self._chan_ann_without_sigs is not None:
return self._chan_ann_without_sigs
if not self.lnworker:
raise Exception('lnworker not set for channel!')
bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey,
self.config[LOCAL].multisig_key.pubkey]
node_ids = [self.node_id, self.get_local_pubkey()]
sorted_node_ids = list(sorted(node_ids))
if sorted_node_ids != node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
chan_ann = encode_msg("channel_announcement",
len=0,
features=b'',
chain_hash=constants.net.rev_genesis_bytes(),
short_channel_id=self.short_channel_id,
node_id_1=node_ids[0],
node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0],
bitcoin_key_2=bitcoin_keys[1]
)
self._chan_ann_without_sigs = chan_ann
return chan_ann
def set_short_channel_id(self, short_id): def set_short_channel_id(self, short_id):
self.short_channel_id = short_id self.short_channel_id = short_id
self.storage["short_channel_id"] = short_id self.storage["short_channel_id"] = short_id

View file

@ -953,112 +953,25 @@ class Peer(Logger):
assert chan.config[LOCAL].funding_locked_received assert chan.config[LOCAL].funding_locked_received
chan.set_state(channel_states.OPEN) chan.set_state(channel_states.OPEN)
self.network.trigger_callback('channel', chan) self.network.trigger_callback('channel', chan)
self.add_own_channel(chan) # peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
chan.set_remote_update(pending_channel_update['raw'])
self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}") self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}")
forwarding_enabled = self.network.config.get('lightning_forward_payments', False) forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
if forwarding_enabled: if forwarding_enabled:
# send channel_update of outgoing edge to peer, # send channel_update of outgoing edge to peer,
# so that channel can be used to to receive payments # so that channel can be used to to receive payments
self.logger.info(f"sending channel update for outgoing edge of {scid}") self.logger.info(f"sending channel update for outgoing edge of {scid}")
chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan) chan_upd = chan.get_outgoing_gossip_channel_update()
self.transport.send_bytes(chan_upd) self.transport.send_bytes(chan_upd)
def add_own_channel(self, chan):
# add channel to database
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
sorted_node_ids = list(sorted(self.node_ids))
if sorted_node_ids != self.node_ids:
bitcoin_keys.reverse()
# note: we inject a channel announcement, and a channel update (for outgoing direction)
# This is atm needed for
# - finding routes
# - the ChanAnn is needed so that we can anchor to it a future ChanUpd
# that the remote sends, even if the channel was not announced
# (from BOLT-07: "MAY create a channel_update to communicate the channel
# parameters to the final node, even though the channel has not yet been announced")
self.channel_db.add_channel_announcement(
{
"short_channel_id": chan.short_channel_id,
"node_id_1": sorted_node_ids[0],
"node_id_2": sorted_node_ids[1],
'chain_hash': constants.net.rev_genesis_bytes(),
'len': b'\x00\x00',
'features': b'',
'bitcoin_key_1': bitcoin_keys[0],
'bitcoin_key_2': bitcoin_keys[1]
},
trusted=True)
# only inject outgoing direction:
chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan)
chan_upd_payload = decode_msg(chan_upd_bytes)[1]
self.channel_db.add_channel_update(chan_upd_payload)
# peer may have sent us a channel update for the incoming direction previously
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
if pending_channel_update:
chan.set_remote_update(pending_channel_update['raw'])
# add remote update with a fresh timestamp
if chan.get_remote_update():
now = int(time.time())
remote_update_decoded = decode_msg(chan.get_remote_update())[1]
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
self.channel_db.add_channel_update(remote_update_decoded)
def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes:
if chan._outgoing_channel_update is not None:
return chan._outgoing_channel_update
sorted_node_ids = list(sorted(self.node_ids))
channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01'
now = int(time.time())
htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity)
chan_upd = encode_msg(
"channel_update",
short_channel_id=chan.short_channel_id,
channel_flags=channel_flags,
message_flags=b'\x01',
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
chain_hash=constants.net.rev_genesis_bytes(),
timestamp=now.to_bytes(4, byteorder="big"),
)
sighash = sha256d(chan_upd[2 + 64:])
sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s)
message_type, payload = decode_msg(chan_upd)
payload['signature'] = sig
chan_upd = encode_msg(message_type, **payload)
chan._outgoing_channel_update = chan_upd
return chan_upd
def send_announcement_signatures(self, chan: Channel): def send_announcement_signatures(self, chan: Channel):
chan_ann = chan.construct_channel_announcement_without_sigs()
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, preimage = chan_ann[256+2:]
chan.config[LOCAL].multisig_key.pubkey] msg_hash = sha256d(preimage)
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s)
sorted_node_ids = list(sorted(self.node_ids)) node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s)
if sorted_node_ids != self.node_ids:
node_ids = sorted_node_ids
bitcoin_keys.reverse()
else:
node_ids = self.node_ids
chan_ann = encode_msg("channel_announcement",
len=0,
#features not set (defaults to zeros)
chain_hash=constants.net.rev_genesis_bytes(),
short_channel_id=chan.short_channel_id,
node_id_1=node_ids[0],
node_id_2=node_ids[1],
bitcoin_key_1=bitcoin_keys[0],
bitcoin_key_2=bitcoin_keys[1]
)
to_hash = chan_ann[256+2:]
h = sha256d(to_hash)
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s)
node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s)
self.send_message("announcement_signatures", self.send_message("announcement_signatures",
channel_id=chan.channel_id, channel_id=chan.channel_id,
short_channel_id=chan.short_channel_id, short_channel_id=chan.short_channel_id,
@ -1066,7 +979,7 @@ class Peer(Logger):
bitcoin_signature=bitcoin_signature bitcoin_signature=bitcoin_signature
) )
return h, node_signature, bitcoin_signature return msg_hash, node_signature, bitcoin_signature
def on_update_fail_htlc(self, payload): def on_update_fail_htlc(self, payload):
channel_id = payload["channel_id"] channel_id = payload["channel_id"]
@ -1255,7 +1168,7 @@ class Peer(Logger):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
return return
outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:] outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big") outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
if next_chan.get_state() != channel_states.OPEN: if next_chan.get_state() != channel_states.OPEN:
self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}") self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}")

View file

@ -129,18 +129,20 @@ class LNPathFinder(Logger):
self.blacklist.add(short_channel_id) self.blacklist.add(short_channel_id)
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes, def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]: payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
"""Heuristic cost of going through a channel. """Heuristic cost of going through a channel.
Returns (heuristic_cost, fee_for_edge_msat). Returns (heuristic_cost, fee_for_edge_msat).
""" """
channel_info = self.channel_db.get_channel_info(short_channel_id) channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
if channel_info is None: if channel_info is None:
return float('inf'), 0 return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node) channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
if channel_policy is None: if channel_policy is None:
return float('inf'), 0 return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure # channels that did not publish both policies often return temporary channel failure
if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine: if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
and not is_mine:
return float('inf'), 0 return float('inf'), 0
if channel_policy.is_disabled(): if channel_policy.is_disabled():
return float('inf'), 0 return float('inf'), 0
@ -164,8 +166,9 @@ class LNPathFinder(Logger):
@profiler @profiler
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes, def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
invoice_amount_msat: int, invoice_amount_msat: int, *,
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]: my_channels: Dict[ShortChannelID, 'Channel'] = None) \
-> Optional[Sequence[Tuple[bytes, bytes]]]:
"""Return a path from nodeA to nodeB. """Return a path from nodeA to nodeB.
Returns a list of (node_id, short_channel_id) representing a path. Returns a list of (node_id, short_channel_id) representing a path.
@ -175,8 +178,7 @@ class LNPathFinder(Logger):
assert type(nodeA) is bytes assert type(nodeA) is bytes
assert type(nodeB) is bytes assert type(nodeB) is bytes
assert type(invoice_amount_msat) is int assert type(invoice_amount_msat) is int
if my_channels is None: my_channels = [] if my_channels is None: my_channels = {}
my_channels = {chan.short_channel_id: chan for chan in my_channels}
# FIXME paths cannot be longer than 20 edges (onion packet)... # FIXME paths cannot be longer than 20 edges (onion packet)...
@ -204,7 +206,8 @@ class LNPathFinder(Logger):
end_node=edge_endnode, end_node=edge_endnode,
payment_amt_msat=amount_msat, payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA), ignore_costs=(edge_startnode == nodeA),
is_mine=is_mine) is_mine=is_mine,
my_channels=my_channels)
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]: if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour distance_from_start[edge_startnode] = alt_dist_to_neighbour
@ -222,11 +225,11 @@ class LNPathFinder(Logger):
# so instead of decreasing priorities, we add items again into the queue. # so instead of decreasing priorities, we add items again into the queue.
# so there are duplicates in the queue, that we discard now: # so there are duplicates in the queue, that we discard now:
continue continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode): for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
assert isinstance(edge_channel_id, bytes) assert isinstance(edge_channel_id, bytes)
if edge_channel_id in self.blacklist: if edge_channel_id in self.blacklist:
continue continue
channel_info = self.channel_db.get_channel_info(edge_channel_id) channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
inspect_edge() inspect_edge()
else: else:
@ -241,14 +244,17 @@ class LNPathFinder(Logger):
edge_startnode = edge_endnode edge_startnode = edge_endnode
return path return path
def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute: def create_route_from_path(self, path, from_node_id: bytes, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
assert isinstance(from_node_id, bytes) assert isinstance(from_node_id, bytes)
if path is None: if path is None:
raise Exception('cannot create route from None path') raise Exception('cannot create route from None path')
route = [] route = []
prev_node_id = from_node_id prev_node_id = from_node_id
for node_id, short_channel_id in path: for node_id, short_channel_id in path:
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=my_channels)
if channel_policy is None: if channel_policy is None:
raise NoChannelPolicy(short_channel_id) raise NoChannelPolicy(short_channel_id)
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id)) route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))

View file

@ -942,16 +942,20 @@ class LNWallet(LNWorker):
random.shuffle(r_tags) random.shuffle(r_tags)
with self.lock: with self.lock:
channels = list(self.channels.values()) channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
for private_route in r_tags: for private_route in r_tags:
if len(private_route) == 0: if len(private_route) == 0:
continue continue
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue continue
border_node_pubkey = private_route[0][0] border_node_pubkey = private_route[0][0]
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels) path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
my_channels=scid_to_my_channels)
if not path: if not path:
continue continue
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
my_channels=scid_to_my_channels)
# we need to shift the node pubkey by one towards the destination: # we need to shift the node pubkey by one towards the destination:
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey] private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
private_route_rest = [edge[1:] for edge in private_route] private_route_rest = [edge[1:] for edge in private_route]
@ -961,7 +965,9 @@ class LNWallet(LNWorker):
short_channel_id = ShortChannelID(short_channel_id) short_channel_id = ShortChannelID(short_channel_id)
# if we have a routing policy for this edge in the db, that takes precedence, # if we have a routing policy for this edge in the db, that takes precedence,
# as it is likely from a previous failure # as it is likely from a previous failure
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id) channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
node_id=prev_node_id,
my_channels=scid_to_my_channels)
if channel_policy: if channel_policy:
fee_base_msat = channel_policy.fee_base_msat fee_base_msat = channel_policy.fee_base_msat
fee_proportional_millionths = channel_policy.fee_proportional_millionths fee_proportional_millionths = channel_policy.fee_proportional_millionths
@ -977,10 +983,12 @@ class LNWallet(LNWorker):
break break
# if could not find route using any hint; try without hint now # if could not find route using any hint; try without hint now
if route is None: if route is None:
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels) path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
my_channels=scid_to_my_channels)
if not path: if not path:
raise NoPathFound() raise NoPathFound()
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey) route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
my_channels=scid_to_my_channels)
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()): if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
self.logger.info(f"rejecting insane route {route}") self.logger.info(f"rejecting insane route {route}")
raise NoPathFound() raise NoPathFound()
@ -1099,6 +1107,8 @@ class LNWallet(LNWorker):
routing_hints = [] routing_hints = []
with self.lock: with self.lock:
channels = list(self.channels.values()) channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
# note: currently we add *all* our channels; but this might be a privacy leak? # note: currently we add *all* our channels; but this might be a privacy leak?
for chan in channels: for chan in channels:
# check channel is open # check channel is open
@ -1110,7 +1120,7 @@ class LNWallet(LNWorker):
continue continue
chan_id = chan.short_channel_id chan_id = chan.short_channel_id
assert isinstance(chan_id, bytes), chan_id assert isinstance(chan_id, bytes), chan_id
channel_info = self.channel_db.get_channel_info(chan_id) channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels)
# note: as a fallback, if we don't have a channel update for the # note: as a fallback, if we don't have a channel update for the
# incoming direction of our private channel, we fill the invoice with garbage. # incoming direction of our private channel, we fill the invoice with garbage.
# the sender should still be able to pay us, but will incur an extra round trip # the sender should still be able to pay us, but will incur an extra round trip
@ -1120,7 +1130,8 @@ class LNWallet(LNWorker):
cltv_expiry_delta = 1 # lnd won't even try with zero cltv_expiry_delta = 1 # lnd won't even try with zero
missing_info = True missing_info = True
if channel_info: if channel_info:
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id) policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id,
my_channels=scid_to_my_channels)
if policy: if policy:
fee_base_msat = policy.fee_base_msat fee_base_msat = policy.fee_base_msat
fee_proportional_millionths = policy.fee_proportional_millionths fee_proportional_millionths = policy.fee_proportional_millionths

View file

@ -18,7 +18,7 @@ from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnLocalFeatures from electrum.lnutil import PaymentFailure, LnLocalFeatures
from electrum.lnchannel import channel_states, peer_states from electrum.lnchannel import channel_states, peer_states, Channel
from electrum.lnrouter import LNPathFinder from electrum.lnrouter import LNPathFinder
from electrum.channel_db import ChannelDB from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound from electrum.lnworker import LNWallet, NoPathFound
@ -77,7 +77,7 @@ class MockWallet:
return False return False
class MockLNWallet: class MockLNWallet:
def __init__(self, remote_keypair, local_keypair, chan, tx_queue): def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
@ -88,6 +88,8 @@ class MockLNWallet:
self.localfeatures = LnLocalFeatures(0) self.localfeatures = LnLocalFeatures(0)
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
self.pending_payments = defaultdict(asyncio.Future) self.pending_payments = defaultdict(asyncio.Future)
chan.lnworker = self
chan.node_id = remote_keypair.pubkey
def get_invoice_status(self, key): def get_invoice_status(self, key):
pass pass
@ -127,6 +129,7 @@ class MockLNWallet:
_pay_to_route = LNWallet._pay_to_route _pay_to_route = LNWallet._pay_to_route
force_close_channel = LNWallet.force_close_channel force_close_channel = LNWallet.force_close_channel
get_first_timestamp = lambda self: 0 get_first_timestamp = lambda self: 0
payment_completed = LNWallet.payment_completed
class MockTransport: class MockTransport:
def __init__(self, name): def __init__(self, name):
@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase):
pay_req = self.prepare_invoice(w2) pay_req = self.prepare_invoice(w2)
async def pay(): async def pay():
result = await LNWallet._pay(w1, pay_req) result = await LNWallet._pay(w1, pay_req)
self.assertEqual(result, True) self.assertTrue(result)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
async def f(): async def f():