mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-02 18:25:21 +00:00
ln gossip: don't put own channels into db; always pass them to fn calls
Previously we would put fake chan announcement and fake outgoing chan upd for own channels into db (to make path finding work). See Peer.add_own_channel(). Now, instead of above, we pass a "my_channels" param to the relevant ChannelDB methods.
This commit is contained in:
parent
7d65fe1ba3
commit
46d8080c76
6 changed files with 188 additions and 149 deletions
|
@ -39,9 +39,11 @@ from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enab
|
|||
from .logging import Logger
|
||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
|
||||
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
||||
from .lnmsg import decode_msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
from .lnchannel import Channel
|
||||
|
||||
|
||||
class UnknownEvenFeatureBits(Exception): pass
|
||||
|
@ -63,7 +65,7 @@ class ChannelInfo(NamedTuple):
|
|||
capacity_sat: Optional[int]
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
def from_msg(payload: dict) -> 'ChannelInfo':
|
||||
features = int.from_bytes(payload['features'], 'big')
|
||||
validate_features(features)
|
||||
channel_id = payload['short_channel_id']
|
||||
|
@ -78,6 +80,11 @@ class ChannelInfo(NamedTuple):
|
|||
capacity_sat = capacity_sat
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_raw_msg(raw: bytes) -> 'ChannelInfo':
|
||||
payload_dict = decode_msg(raw)[1]
|
||||
return ChannelInfo.from_msg(payload_dict)
|
||||
|
||||
|
||||
class Policy(NamedTuple):
|
||||
key: bytes
|
||||
|
@ -91,7 +98,7 @@ class Policy(NamedTuple):
|
|||
timestamp: int
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
def from_msg(payload: dict) -> 'Policy':
|
||||
return Policy(
|
||||
key = payload['short_channel_id'] + payload['start_node'],
|
||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
|
||||
|
@ -248,11 +255,11 @@ class ChannelDB(SqlDB):
|
|||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
# initialized in load_data
|
||||
self._channels = {} # type: Dict[bytes, ChannelInfo]
|
||||
self._policies = {}
|
||||
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
|
||||
self._nodes = {}
|
||||
# node_id -> (host, port, ts)
|
||||
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
|
||||
self._channels_for_node = defaultdict(set)
|
||||
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
||||
self.data_loaded = asyncio.Event()
|
||||
self.network = network # only for callback
|
||||
|
||||
|
@ -495,17 +502,6 @@ class ChannelDB(SqlDB):
|
|||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
self.update_counts()
|
||||
|
||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||
short_channel_id: bytes) -> Optional[Policy]:
|
||||
if not start_node_id or not short_channel_id: return None
|
||||
channel_info = self.get_channel_info(short_channel_id)
|
||||
if channel_info is not None:
|
||||
return self.get_policy_for_node(short_channel_id, start_node_id)
|
||||
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||
if not msg:
|
||||
return None
|
||||
return Policy.from_msg(msg) # won't actually be written to DB
|
||||
|
||||
def get_old_policies(self, delta):
|
||||
now = int(time.time())
|
||||
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
|
||||
|
@ -587,12 +583,56 @@ class ChannelDB(SqlDB):
|
|||
out.add(short_channel_id)
|
||||
self.logger.info(f'semi-orphaned: {len(out)}')
|
||||
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
||||
return self._policies.get((node_id, short_channel_id))
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
|
||||
channel_info = self.get_channel_info(short_channel_id)
|
||||
if channel_info is not None: # publicly announced channel
|
||||
policy = self._policies.get((node_id, short_channel_id))
|
||||
if policy:
|
||||
return policy
|
||||
else: # private channel
|
||||
chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id))
|
||||
if chan_upd_dict:
|
||||
return Policy.from_msg(chan_upd_dict)
|
||||
# check if it's one of our own channels
|
||||
if not my_channels:
|
||||
return
|
||||
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
|
||||
if not chan:
|
||||
return
|
||||
if node_id == chan.node_id: # incoming direction (to us)
|
||||
remote_update_raw = chan.get_remote_update()
|
||||
if not remote_update_raw:
|
||||
return
|
||||
now = int(time.time())
|
||||
remote_update_decoded = decode_msg(remote_update_raw)[1]
|
||||
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
|
||||
remote_update_decoded['start_node'] = node_id
|
||||
return Policy.from_msg(remote_update_decoded)
|
||||
elif node_id == chan.get_local_pubkey(): # outgoing direction (from us)
|
||||
local_update_decoded = decode_msg(chan.get_outgoing_gossip_channel_update())[1]
|
||||
local_update_decoded['start_node'] = node_id
|
||||
return Policy.from_msg(local_update_decoded)
|
||||
|
||||
def get_channel_info(self, channel_id: bytes) -> ChannelInfo:
|
||||
return self._channels.get(channel_id)
|
||||
def get_channel_info(self, short_channel_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional[ChannelInfo]:
|
||||
ret = self._channels.get(short_channel_id)
|
||||
if ret:
|
||||
return ret
|
||||
# check if it's one of our own channels
|
||||
if not my_channels:
|
||||
return
|
||||
chan = my_channels.get(short_channel_id) # type: Optional[Channel]
|
||||
ci = ChannelInfo.from_raw_msg(chan.construct_channel_announcement_without_sigs())
|
||||
return ci._replace(capacity_sat=chan.constraints.capacity)
|
||||
|
||||
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||
return self._channels_for_node.get(node_id) or set()
|
||||
def get_channels_for_node(self, node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Set[bytes]:
|
||||
"""Returns the set of short channel IDs where node_id is one of the channel participants."""
|
||||
relevant_channels = self._channels_for_node.get(node_id) or set()
|
||||
relevant_channels = set(relevant_channels) # copy
|
||||
# add our own channels # TODO maybe slow?
|
||||
for chan in (my_channels.values() or []):
|
||||
if node_id in (chan.node_id, chan.get_local_pubkey()):
|
||||
relevant_channels.add(chan.short_channel_id)
|
||||
return relevant_channels
|
||||
|
|
|
@ -32,13 +32,14 @@ import time
|
|||
import threading
|
||||
|
||||
from . import ecc
|
||||
from . import constants
|
||||
from .util import bfh, bh2u
|
||||
from .bitcoin import redeem_script_to_address
|
||||
from .crypto import sha256, sha256d
|
||||
from .transaction import Transaction, PartialTransaction
|
||||
from .logging import Logger
|
||||
|
||||
from .lnonion import decode_onion_error
|
||||
from . import lnutil
|
||||
from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKeypair, ChannelConstraints,
|
||||
get_per_commitment_secret_from_seed, secret_to_pubkey, derive_privkey, make_closing_tx,
|
||||
sign_and_get_sig_string, RevocationStore, derive_blinded_pubkey, Direction, derive_pubkey,
|
||||
|
@ -47,10 +48,10 @@ from .lnutil import (Outpoint, LocalConfig, RemoteConfig, Keypair, OnlyPubkeyKey
|
|||
funding_output_script, SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, make_commitment_outputs,
|
||||
ScriptHtlc, PaymentFailure, calc_onchain_fees, RemoteMisbehaving, make_htlc_output_witness_script,
|
||||
ShortChannelID, map_htlcs_to_ctx_output_idxs)
|
||||
from .lnutil import FeeUpdate
|
||||
from .lnsweep import create_sweeptxs_for_our_ctx, create_sweeptxs_for_their_ctx
|
||||
from .lnsweep import create_sweeptx_for_their_revoked_htlc, SweepInfo
|
||||
from .lnhtlc import HTLCManager
|
||||
from .lnmsg import encode_msg, decode_msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lnworker import LNWallet
|
||||
|
@ -136,7 +137,6 @@ class Channel(Logger):
|
|||
self.funding_outpoint = state["funding_outpoint"]
|
||||
self.node_id = bfh(state["node_id"])
|
||||
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
|
||||
self.short_channel_id_predicted = self.short_channel_id
|
||||
self.onion_keys = state['onion_keys']
|
||||
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
|
||||
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
|
||||
|
@ -144,6 +144,7 @@ class Channel(Logger):
|
|||
self.peer_state = peer_states.DISCONNECTED
|
||||
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
|
||||
self._outgoing_channel_update = None # type: Optional[bytes]
|
||||
self._chan_ann_without_sigs = None # type: Optional[bytes]
|
||||
self.revocation_store = RevocationStore(state["revocation_store"])
|
||||
|
||||
def set_onion_key(self, key, value):
|
||||
|
@ -158,12 +159,77 @@ class Channel(Logger):
|
|||
def get_data_loss_protect_remote_pcp(self, key):
|
||||
return self.data_loss_protect_remote_pcp.get(key)
|
||||
|
||||
def set_remote_update(self, raw):
|
||||
def get_local_pubkey(self) -> bytes:
|
||||
if not self.lnworker:
|
||||
raise Exception('lnworker not set for channel!')
|
||||
return self.lnworker.node_keypair.pubkey
|
||||
|
||||
def set_remote_update(self, raw: bytes) -> None:
|
||||
self.storage['remote_update'] = raw.hex()
|
||||
|
||||
def get_remote_update(self):
|
||||
def get_remote_update(self) -> Optional[bytes]:
|
||||
return bfh(self.storage.get('remote_update')) if self.storage.get('remote_update') else None
|
||||
|
||||
def get_outgoing_gossip_channel_update(self) -> bytes:
|
||||
if self._outgoing_channel_update is not None:
|
||||
return self._outgoing_channel_update
|
||||
if not self.lnworker:
|
||||
raise Exception('lnworker not set for channel!')
|
||||
sorted_node_ids = list(sorted([self.node_id, self.get_local_pubkey()]))
|
||||
channel_flags = b'\x00' if sorted_node_ids[0] == self.get_local_pubkey() else b'\x01'
|
||||
now = int(time.time())
|
||||
htlc_maximum_msat = min(self.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * self.constraints.capacity)
|
||||
|
||||
chan_upd = encode_msg(
|
||||
"channel_update",
|
||||
short_channel_id=self.short_channel_id,
|
||||
channel_flags=channel_flags,
|
||||
message_flags=b'\x01',
|
||||
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
|
||||
htlc_minimum_msat=self.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
|
||||
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
|
||||
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
|
||||
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
timestamp=now.to_bytes(4, byteorder="big"),
|
||||
)
|
||||
sighash = sha256d(chan_upd[2 + 64:])
|
||||
sig = ecc.ECPrivkey(self.lnworker.node_keypair.privkey).sign(sighash, ecc.sig_string_from_r_and_s)
|
||||
message_type, payload = decode_msg(chan_upd)
|
||||
payload['signature'] = sig
|
||||
chan_upd = encode_msg(message_type, **payload)
|
||||
|
||||
self._outgoing_channel_update = chan_upd
|
||||
return chan_upd
|
||||
|
||||
def construct_channel_announcement_without_sigs(self) -> bytes:
|
||||
if self._chan_ann_without_sigs is not None:
|
||||
return self._chan_ann_without_sigs
|
||||
if not self.lnworker:
|
||||
raise Exception('lnworker not set for channel!')
|
||||
|
||||
bitcoin_keys = [self.config[REMOTE].multisig_key.pubkey,
|
||||
self.config[LOCAL].multisig_key.pubkey]
|
||||
node_ids = [self.node_id, self.get_local_pubkey()]
|
||||
sorted_node_ids = list(sorted(node_ids))
|
||||
if sorted_node_ids != node_ids:
|
||||
node_ids = sorted_node_ids
|
||||
bitcoin_keys.reverse()
|
||||
|
||||
chan_ann = encode_msg("channel_announcement",
|
||||
len=0,
|
||||
features=b'',
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
short_channel_id=self.short_channel_id,
|
||||
node_id_1=node_ids[0],
|
||||
node_id_2=node_ids[1],
|
||||
bitcoin_key_1=bitcoin_keys[0],
|
||||
bitcoin_key_2=bitcoin_keys[1]
|
||||
)
|
||||
|
||||
self._chan_ann_without_sigs = chan_ann
|
||||
return chan_ann
|
||||
|
||||
def set_short_channel_id(self, short_id):
|
||||
self.short_channel_id = short_id
|
||||
self.storage["short_channel_id"] = short_id
|
||||
|
|
|
@ -953,112 +953,25 @@ class Peer(Logger):
|
|||
assert chan.config[LOCAL].funding_locked_received
|
||||
chan.set_state(channel_states.OPEN)
|
||||
self.network.trigger_callback('channel', chan)
|
||||
self.add_own_channel(chan)
|
||||
# peer may have sent us a channel update for the incoming direction previously
|
||||
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
|
||||
if pending_channel_update:
|
||||
chan.set_remote_update(pending_channel_update['raw'])
|
||||
self.logger.info(f"CHANNEL OPENING COMPLETED for {scid}")
|
||||
forwarding_enabled = self.network.config.get('lightning_forward_payments', False)
|
||||
if forwarding_enabled:
|
||||
# send channel_update of outgoing edge to peer,
|
||||
# so that channel can be used to to receive payments
|
||||
self.logger.info(f"sending channel update for outgoing edge of {scid}")
|
||||
chan_upd = self.get_outgoing_gossip_channel_update_for_chan(chan)
|
||||
chan_upd = chan.get_outgoing_gossip_channel_update()
|
||||
self.transport.send_bytes(chan_upd)
|
||||
|
||||
def add_own_channel(self, chan):
|
||||
# add channel to database
|
||||
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
if sorted_node_ids != self.node_ids:
|
||||
bitcoin_keys.reverse()
|
||||
# note: we inject a channel announcement, and a channel update (for outgoing direction)
|
||||
# This is atm needed for
|
||||
# - finding routes
|
||||
# - the ChanAnn is needed so that we can anchor to it a future ChanUpd
|
||||
# that the remote sends, even if the channel was not announced
|
||||
# (from BOLT-07: "MAY create a channel_update to communicate the channel
|
||||
# parameters to the final node, even though the channel has not yet been announced")
|
||||
self.channel_db.add_channel_announcement(
|
||||
{
|
||||
"short_channel_id": chan.short_channel_id,
|
||||
"node_id_1": sorted_node_ids[0],
|
||||
"node_id_2": sorted_node_ids[1],
|
||||
'chain_hash': constants.net.rev_genesis_bytes(),
|
||||
'len': b'\x00\x00',
|
||||
'features': b'',
|
||||
'bitcoin_key_1': bitcoin_keys[0],
|
||||
'bitcoin_key_2': bitcoin_keys[1]
|
||||
},
|
||||
trusted=True)
|
||||
# only inject outgoing direction:
|
||||
chan_upd_bytes = self.get_outgoing_gossip_channel_update_for_chan(chan)
|
||||
chan_upd_payload = decode_msg(chan_upd_bytes)[1]
|
||||
self.channel_db.add_channel_update(chan_upd_payload)
|
||||
# peer may have sent us a channel update for the incoming direction previously
|
||||
pending_channel_update = self.orphan_channel_updates.get(chan.short_channel_id)
|
||||
if pending_channel_update:
|
||||
chan.set_remote_update(pending_channel_update['raw'])
|
||||
# add remote update with a fresh timestamp
|
||||
if chan.get_remote_update():
|
||||
now = int(time.time())
|
||||
remote_update_decoded = decode_msg(chan.get_remote_update())[1]
|
||||
remote_update_decoded['timestamp'] = now.to_bytes(4, byteorder="big")
|
||||
self.channel_db.add_channel_update(remote_update_decoded)
|
||||
|
||||
def get_outgoing_gossip_channel_update_for_chan(self, chan: Channel) -> bytes:
|
||||
if chan._outgoing_channel_update is not None:
|
||||
return chan._outgoing_channel_update
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
channel_flags = b'\x00' if sorted_node_ids[0] == privkey_to_pubkey(self.privkey) else b'\x01'
|
||||
now = int(time.time())
|
||||
htlc_maximum_msat = min(chan.config[REMOTE].max_htlc_value_in_flight_msat, 1000 * chan.constraints.capacity)
|
||||
|
||||
chan_upd = encode_msg(
|
||||
"channel_update",
|
||||
short_channel_id=chan.short_channel_id,
|
||||
channel_flags=channel_flags,
|
||||
message_flags=b'\x01',
|
||||
cltv_expiry_delta=lnutil.NBLOCK_OUR_CLTV_EXPIRY_DELTA.to_bytes(2, byteorder="big"),
|
||||
htlc_minimum_msat=chan.config[REMOTE].htlc_minimum_msat.to_bytes(8, byteorder="big"),
|
||||
htlc_maximum_msat=htlc_maximum_msat.to_bytes(8, byteorder="big"),
|
||||
fee_base_msat=lnutil.OUR_FEE_BASE_MSAT.to_bytes(4, byteorder="big"),
|
||||
fee_proportional_millionths=lnutil.OUR_FEE_PROPORTIONAL_MILLIONTHS.to_bytes(4, byteorder="big"),
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
timestamp=now.to_bytes(4, byteorder="big"),
|
||||
)
|
||||
sighash = sha256d(chan_upd[2 + 64:])
|
||||
sig = ecc.ECPrivkey(self.privkey).sign(sighash, sig_string_from_r_and_s)
|
||||
message_type, payload = decode_msg(chan_upd)
|
||||
payload['signature'] = sig
|
||||
chan_upd = encode_msg(message_type, **payload)
|
||||
|
||||
chan._outgoing_channel_update = chan_upd
|
||||
return chan_upd
|
||||
|
||||
def send_announcement_signatures(self, chan: Channel):
|
||||
|
||||
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
|
||||
chan.config[LOCAL].multisig_key.pubkey]
|
||||
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
if sorted_node_ids != self.node_ids:
|
||||
node_ids = sorted_node_ids
|
||||
bitcoin_keys.reverse()
|
||||
else:
|
||||
node_ids = self.node_ids
|
||||
|
||||
chan_ann = encode_msg("channel_announcement",
|
||||
len=0,
|
||||
#features not set (defaults to zeros)
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
short_channel_id=chan.short_channel_id,
|
||||
node_id_1=node_ids[0],
|
||||
node_id_2=node_ids[1],
|
||||
bitcoin_key_1=bitcoin_keys[0],
|
||||
bitcoin_key_2=bitcoin_keys[1]
|
||||
)
|
||||
to_hash = chan_ann[256+2:]
|
||||
h = sha256d(to_hash)
|
||||
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(h, sig_string_from_r_and_s)
|
||||
node_signature = ecc.ECPrivkey(self.privkey).sign(h, sig_string_from_r_and_s)
|
||||
chan_ann = chan.construct_channel_announcement_without_sigs()
|
||||
preimage = chan_ann[256+2:]
|
||||
msg_hash = sha256d(preimage)
|
||||
bitcoin_signature = ecc.ECPrivkey(chan.config[LOCAL].multisig_key.privkey).sign(msg_hash, sig_string_from_r_and_s)
|
||||
node_signature = ecc.ECPrivkey(self.privkey).sign(msg_hash, sig_string_from_r_and_s)
|
||||
self.send_message("announcement_signatures",
|
||||
channel_id=chan.channel_id,
|
||||
short_channel_id=chan.short_channel_id,
|
||||
|
@ -1066,7 +979,7 @@ class Peer(Logger):
|
|||
bitcoin_signature=bitcoin_signature
|
||||
)
|
||||
|
||||
return h, node_signature, bitcoin_signature
|
||||
return msg_hash, node_signature, bitcoin_signature
|
||||
|
||||
def on_update_fail_htlc(self, payload):
|
||||
channel_id = payload["channel_id"]
|
||||
|
@ -1255,7 +1168,7 @@ class Peer(Logger):
|
|||
reason = OnionRoutingFailureMessage(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
|
||||
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
|
||||
return
|
||||
outgoing_chan_upd = self.get_outgoing_gossip_channel_update_for_chan(next_chan)[2:]
|
||||
outgoing_chan_upd = next_chan.get_outgoing_gossip_channel_update()[2:]
|
||||
outgoing_chan_upd_len = len(outgoing_chan_upd).to_bytes(2, byteorder="big")
|
||||
if next_chan.get_state() != channel_states.OPEN:
|
||||
self.logger.info(f"cannot forward htlc. next_chan not OPEN: {next_chan_scid} in state {next_chan.get_state()}")
|
||||
|
|
|
@ -129,18 +129,20 @@ class LNPathFinder(Logger):
|
|||
self.blacklist.add(short_channel_id)
|
||||
|
||||
def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
|
||||
payment_amt_msat: int, ignore_costs=False, is_mine=False) -> Tuple[float, int]:
|
||||
payment_amt_msat: int, ignore_costs=False, is_mine=False, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Tuple[float, int]:
|
||||
"""Heuristic cost of going through a channel.
|
||||
Returns (heuristic_cost, fee_for_edge_msat).
|
||||
"""
|
||||
channel_info = self.channel_db.get_channel_info(short_channel_id)
|
||||
channel_info = self.channel_db.get_channel_info(short_channel_id, my_channels=my_channels)
|
||||
if channel_info is None:
|
||||
return float('inf'), 0
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node, my_channels=my_channels)
|
||||
if channel_policy is None:
|
||||
return float('inf'), 0
|
||||
# channels that did not publish both policies often return temporary channel failure
|
||||
if self.channel_db.get_policy_for_node(short_channel_id, end_node) is None and not is_mine:
|
||||
if self.channel_db.get_policy_for_node(short_channel_id, end_node, my_channels=my_channels) is None \
|
||||
and not is_mine:
|
||||
return float('inf'), 0
|
||||
if channel_policy.is_disabled():
|
||||
return float('inf'), 0
|
||||
|
@ -164,8 +166,9 @@ class LNPathFinder(Logger):
|
|||
|
||||
@profiler
|
||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||
invoice_amount_msat: int,
|
||||
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
|
||||
invoice_amount_msat: int, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) \
|
||||
-> Optional[Sequence[Tuple[bytes, bytes]]]:
|
||||
"""Return a path from nodeA to nodeB.
|
||||
|
||||
Returns a list of (node_id, short_channel_id) representing a path.
|
||||
|
@ -175,8 +178,7 @@ class LNPathFinder(Logger):
|
|||
assert type(nodeA) is bytes
|
||||
assert type(nodeB) is bytes
|
||||
assert type(invoice_amount_msat) is int
|
||||
if my_channels is None: my_channels = []
|
||||
my_channels = {chan.short_channel_id: chan for chan in my_channels}
|
||||
if my_channels is None: my_channels = {}
|
||||
|
||||
# FIXME paths cannot be longer than 20 edges (onion packet)...
|
||||
|
||||
|
@ -204,7 +206,8 @@ class LNPathFinder(Logger):
|
|||
end_node=edge_endnode,
|
||||
payment_amt_msat=amount_msat,
|
||||
ignore_costs=(edge_startnode == nodeA),
|
||||
is_mine=is_mine)
|
||||
is_mine=is_mine,
|
||||
my_channels=my_channels)
|
||||
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
|
||||
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
|
||||
distance_from_start[edge_startnode] = alt_dist_to_neighbour
|
||||
|
@ -222,11 +225,11 @@ class LNPathFinder(Logger):
|
|||
# so instead of decreasing priorities, we add items again into the queue.
|
||||
# so there are duplicates in the queue, that we discard now:
|
||||
continue
|
||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
|
||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode, my_channels=my_channels):
|
||||
assert isinstance(edge_channel_id, bytes)
|
||||
if edge_channel_id in self.blacklist:
|
||||
continue
|
||||
channel_info = self.channel_db.get_channel_info(edge_channel_id)
|
||||
channel_info = self.channel_db.get_channel_info(edge_channel_id, my_channels=my_channels)
|
||||
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
|
||||
inspect_edge()
|
||||
else:
|
||||
|
@ -241,14 +244,17 @@ class LNPathFinder(Logger):
|
|||
edge_startnode = edge_endnode
|
||||
return path
|
||||
|
||||
def create_route_from_path(self, path, from_node_id: bytes) -> LNPaymentRoute:
|
||||
def create_route_from_path(self, path, from_node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> LNPaymentRoute:
|
||||
assert isinstance(from_node_id, bytes)
|
||||
if path is None:
|
||||
raise Exception('cannot create route from None path')
|
||||
route = []
|
||||
prev_node_id = from_node_id
|
||||
for node_id, short_channel_id in path:
|
||||
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
||||
node_id=prev_node_id,
|
||||
my_channels=my_channels)
|
||||
if channel_policy is None:
|
||||
raise NoChannelPolicy(short_channel_id)
|
||||
route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
|
||||
|
|
|
@ -942,16 +942,20 @@ class LNWallet(LNWorker):
|
|||
random.shuffle(r_tags)
|
||||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
|
||||
if chan.short_channel_id is not None}
|
||||
for private_route in r_tags:
|
||||
if len(private_route) == 0:
|
||||
continue
|
||||
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
|
||||
continue
|
||||
border_node_pubkey = private_route[0][0]
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat,
|
||||
my_channels=scid_to_my_channels)
|
||||
if not path:
|
||||
continue
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
||||
my_channels=scid_to_my_channels)
|
||||
# we need to shift the node pubkey by one towards the destination:
|
||||
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
|
||||
private_route_rest = [edge[1:] for edge in private_route]
|
||||
|
@ -961,7 +965,9 @@ class LNWallet(LNWorker):
|
|||
short_channel_id = ShortChannelID(short_channel_id)
|
||||
# if we have a routing policy for this edge in the db, that takes precedence,
|
||||
# as it is likely from a previous failure
|
||||
channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
|
||||
channel_policy = self.channel_db.get_policy_for_node(short_channel_id=short_channel_id,
|
||||
node_id=prev_node_id,
|
||||
my_channels=scid_to_my_channels)
|
||||
if channel_policy:
|
||||
fee_base_msat = channel_policy.fee_base_msat
|
||||
fee_proportional_millionths = channel_policy.fee_proportional_millionths
|
||||
|
@ -977,10 +983,12 @@ class LNWallet(LNWorker):
|
|||
break
|
||||
# if could not find route using any hint; try without hint now
|
||||
if route is None:
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat, channels)
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, invoice_pubkey, amount_msat,
|
||||
my_channels=scid_to_my_channels)
|
||||
if not path:
|
||||
raise NoPathFound()
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey,
|
||||
my_channels=scid_to_my_channels)
|
||||
if not is_route_sane_to_use(route, amount_msat, decoded_invoice.get_min_final_cltv_expiry()):
|
||||
self.logger.info(f"rejecting insane route {route}")
|
||||
raise NoPathFound()
|
||||
|
@ -1099,6 +1107,8 @@ class LNWallet(LNWorker):
|
|||
routing_hints = []
|
||||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
|
||||
if chan.short_channel_id is not None}
|
||||
# note: currently we add *all* our channels; but this might be a privacy leak?
|
||||
for chan in channels:
|
||||
# check channel is open
|
||||
|
@ -1110,7 +1120,7 @@ class LNWallet(LNWorker):
|
|||
continue
|
||||
chan_id = chan.short_channel_id
|
||||
assert isinstance(chan_id, bytes), chan_id
|
||||
channel_info = self.channel_db.get_channel_info(chan_id)
|
||||
channel_info = self.channel_db.get_channel_info(chan_id, my_channels=scid_to_my_channels)
|
||||
# note: as a fallback, if we don't have a channel update for the
|
||||
# incoming direction of our private channel, we fill the invoice with garbage.
|
||||
# the sender should still be able to pay us, but will incur an extra round trip
|
||||
|
@ -1120,7 +1130,8 @@ class LNWallet(LNWorker):
|
|||
cltv_expiry_delta = 1 # lnd won't even try with zero
|
||||
missing_info = True
|
||||
if channel_info:
|
||||
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id)
|
||||
policy = self.channel_db.get_policy_for_node(channel_info.short_channel_id, chan.node_id,
|
||||
my_channels=scid_to_my_channels)
|
||||
if policy:
|
||||
fee_base_msat = policy.fee_base_msat
|
||||
fee_proportional_millionths = policy.fee_proportional_millionths
|
||||
|
|
|
@ -18,7 +18,7 @@ from electrum.lnpeer import Peer
|
|||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||
from electrum.lnutil import PaymentFailure, LnLocalFeatures
|
||||
from electrum.lnchannel import channel_states, peer_states
|
||||
from electrum.lnchannel import channel_states, peer_states, Channel
|
||||
from electrum.lnrouter import LNPathFinder
|
||||
from electrum.channel_db import ChannelDB
|
||||
from electrum.lnworker import LNWallet, NoPathFound
|
||||
|
@ -77,7 +77,7 @@ class MockWallet:
|
|||
return False
|
||||
|
||||
class MockLNWallet:
|
||||
def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
|
||||
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue):
|
||||
self.remote_keypair = remote_keypair
|
||||
self.node_keypair = local_keypair
|
||||
self.network = MockNetwork(tx_queue)
|
||||
|
@ -88,6 +88,8 @@ class MockLNWallet:
|
|||
self.localfeatures = LnLocalFeatures(0)
|
||||
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
||||
self.pending_payments = defaultdict(asyncio.Future)
|
||||
chan.lnworker = self
|
||||
chan.node_id = remote_keypair.pubkey
|
||||
|
||||
def get_invoice_status(self, key):
|
||||
pass
|
||||
|
@ -127,6 +129,7 @@ class MockLNWallet:
|
|||
_pay_to_route = LNWallet._pay_to_route
|
||||
force_close_channel = LNWallet.force_close_channel
|
||||
get_first_timestamp = lambda self: 0
|
||||
payment_completed = LNWallet.payment_completed
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self, name):
|
||||
|
@ -264,7 +267,7 @@ class TestPeer(ElectrumTestCase):
|
|||
pay_req = self.prepare_invoice(w2)
|
||||
async def pay():
|
||||
result = await LNWallet._pay(w1, pay_req)
|
||||
self.assertEqual(result, True)
|
||||
self.assertTrue(result)
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
|
||||
async def f():
|
||||
|
|
Loading…
Add table
Reference in a new issue