# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from datetime import datetime
import time
import random
import queue
import os
import json
import threading
import concurrent
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64

from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_

from .sql_db import SqlDB, sql, SQLITE_LIMIT_VARIABLE_NUMBER
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
from .logging import Logger
from .storage import JsonDB
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
from .crypto import sha256d
from . import ecc
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr,
                     NUM_MAX_EDGES_IN_PAYMENT_PATH, NotFoundChanAnnouncementForUpdate)
from .lnmsg import encode_msg

if TYPE_CHECKING:
    from .lnchannel import Channel
    from .network import Network


class UnknownEvenFeatureBits(Exception):
    """An even (i.e. mandatory) feature bit is set that we do not know about."""
    pass


class NoChannelPolicy(Exception):
    def __init__(self, short_channel_id: bytes):
        super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')


def validate_features(features: int):
    """Raise UnknownEvenFeatureBits if an unknown *even* feature bit is set.

    Per the BOLT specs, even bits are mandatory ("it's OK to be odd"):
    unknown odd bits may be ignored, unknown even bits must be rejected.
    """
    enabled_features = list_enabled_bits(features)
    for fbit in enabled_features:
        if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
            raise UnknownEvenFeatureBits()


Base = declarative_base()

# bits of Policy.channel_flags (see BOLT-07 channel_update)
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0


class ChannelInfo(Base):
    """One row per known channel (from a channel_announcement message)."""
    __tablename__ = 'channel_info'
    short_channel_id = Column(String(64), primary_key=True)
    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    capacity_sat = Column(Integer)           # filled in after on-chain verification; None until then
    msg_payload_hex = Column(String(1024), nullable=False)  # raw announcement, kept for re-verification
    trusted = Column(Boolean, nullable=False)

    @staticmethod
    def from_msg(payload):
        """Build a ChannelInfo from a decoded channel_announcement payload.

        Raises UnknownEvenFeatureBits for unknown mandatory feature bits.
        """
        features = int.from_bytes(payload['features'], 'big')
        validate_features(features)
        channel_id = payload['short_channel_id'].hex()
        node_id_1 = payload['node_id_1'].hex()
        node_id_2 = payload['node_id_2'].hex()
        # BOLT-07 requires node ids in the announcement to be in ascending order
        assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
        msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
        capacity_sat = None
        return ChannelInfo(short_channel_id=channel_id, node1_id=node_id_1, node2_id=node_id_2,
                           capacity_sat=capacity_sat, msg_payload_hex=msg_payload_hex, trusted=False)

    @property
    def msg_payload(self):
        return bytes.fromhex(self.msg_payload_hex)


class Policy(Base):
    """Directed routing policy for one endpoint of a channel (channel_update)."""
    __tablename__ = 'policy'
    start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
    short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
    cltv_expiry_delta = Column(Integer, nullable=False)
    htlc_minimum_msat = Column(Integer, nullable=False)
    htlc_maximum_msat = Column(Integer)      # optional field in channel_update
    fee_base_msat = Column(Integer, nullable=False)
    fee_proportional_millionths = Column(Integer, nullable=False)
    channel_flags = Column(Integer, nullable=False)
    timestamp = Column(Integer, nullable=False)

    @staticmethod
    def from_msg(payload):
        """Build a Policy from a decoded channel_update payload (with 'start_node' added)."""
        cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
        htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
        htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
        fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
        fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
        channel_flags = int.from_bytes(payload['channel_flags'], "big")
        timestamp = int.from_bytes(payload['timestamp'], "big")
        start_node = payload['start_node'].hex()
        short_channel_id = payload['short_channel_id'].hex()
        return Policy(start_node=start_node,
                      short_channel_id=short_channel_id,
                      cltv_expiry_delta=cltv_expiry_delta,
                      htlc_minimum_msat=htlc_minimum_msat,
                      fee_base_msat=fee_base_msat,
                      fee_proportional_millionths=fee_proportional_millionths,
                      channel_flags=channel_flags,
                      timestamp=timestamp,
                      htlc_maximum_msat=htlc_maximum_msat)

    def is_disabled(self):
        return self.channel_flags & FLAG_DISABLE


class NodeInfo(Base):
    """One row per known node (from a node_announcement message)."""
    __tablename__ = 'node_info'
    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    features = Column(Integer, nullable=False)
    timestamp = Column(Integer, nullable=False)
    alias = Column(String(64), nullable=False)

    @staticmethod
    def from_msg(payload):
        """Return (NodeInfo, [Address, ...]) parsed from a node_announcement payload.

        Raises UnknownEvenFeatureBits for unknown mandatory feature bits.
        """
        node_id = payload['node_id'].hex()
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        alias = payload['alias'].rstrip(b'\x00').hex()
        timestamp = int.from_bytes(payload['timestamp'], "big")
        return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
            Address(host=host, port=port, node_id=node_id, last_connected_date=None)
            for host, port in addresses]

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Parse the BOLT-07 'addresses' TLV-like blob into a list of (host, port).

        Address types: 0=padding, 1=IPv4, 2=IPv6, 3=onion v2, 4=onion v3.
        Parsing stops at the first unknown type, since its length is unknown.
        """
        buf = addresses_field

        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                pass
            elif atype == 1:  # IPv4
                ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
                ipv6_addr = ipv6_addr.decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype == 3:  # onion v2
                host = base64.b32encode(read(10)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            elif atype == 4:  # onion v3
                host = base64.b32encode(read(35)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses


class Address(Base):
    """Network address of a node; last_connected_date tracks recent peers."""
    __tablename__ = 'address'
    node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
    host = Column(String(256), primary_key=True)
    port = Column(Integer, primary_key=True)
    last_connected_date = Column(Integer(), nullable=True)


class ChannelDB(SqlDB):
    """Persistent gossip database (channels, policies, nodes, addresses).

    All methods decorated with @sql run on the SqlDB thread's DBSession.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'channel_db')
        super().__init__(network, path, Base)
        self.num_nodes = 0
        self.num_channels = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)
        self.update_counts()

    @sql
    def update_counts(self):
        self._update_counts()

    def _update_counts(self):
        # caller must hold the sql lock (i.e. be on the sql thread)
        self.num_channels = self.DBSession.query(ChannelInfo).count()
        self.num_policies = self.DBSession.query(Policy).count()
        self.num_nodes = self.DBSession.query(NodeInfo).count()

    @sql
    @profiler
    def compare_channels(self, channel_ids):
        """Return the subset of channel_ids (bytes) already present in the DB."""
        ids = [x.hex() for x in channel_ids]
        known = set()
        # I need to get the unknown, and also the channels that need refresh
        # chunked to stay under SQLite's bound-variable limit
        for ids_chunk in chunks(ids, SQLITE_LIMIT_VARIABLE_NUMBER):
            known_part = self.DBSession \
                .query(ChannelInfo) \
                .filter(ChannelInfo.short_channel_id.in_(ids_chunk)) \
                .all()
            known |= set(bfh(r.short_channel_id) for r in known_part)
        return known

    @sql
    def add_recent_peer(self, peer: LNPeerAddr):
        """Record (or refresh) the last-connected timestamp for a peer address."""
        now = int(time.time())
        node_id = peer.pubkey.hex()
        addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
        if addr:
            addr.last_connected_date = now
        else:
            addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
            self.DBSession.add(addr)
        self.DBSession.commit()

    @sql
    def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
        unshuffled = self.DBSession \
            .query(NodeInfo) \
            .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
            .limit(200) \
            .all()
        return random.sample(unshuffled, len(unshuffled))

    @sql
    def nodes_get(self, node_id):
        return self.DBSession \
            .query(NodeInfo) \
            .filter_by(node_id=node_id.hex()) \
            .one_or_none()

    @sql
    def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
        """Return the most recently connected address for node_id, if any."""
        r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
        if not r:
            return None
        addr = r[0]
        return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))

    @sql
    def get_recent_peers(self):
        r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]

    @sql
    def missing_channel_announcements(self) -> Set[str]:
        # fix: rows are hex strings, not ints; annotation corrected accordingly
        expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
        return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())

    @sql
    def missing_channel_updates(self) -> Set[str]:
        # fix: rows are hex strings, not ints; annotation corrected accordingly
        expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
        return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())

    @sql
    def add_verified_channel_info(self, short_id, capacity):
        # called from lnchannelverifier
        channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_id.hex()).one_or_none()
        channel_info.trusted = True
        # fix: the mapped column is 'capacity_sat'; assigning to the unmapped
        # attribute 'capacity' silently discarded the verified value.
        channel_info.capacity_sat = capacity
        self.DBSession.commit()

    @sql
    @profiler
    def on_channel_announcement(self, msg_payloads, trusted=True):
        """Ingest one or more channel_announcement payloads.

        Unknown chains and unknown mandatory feature bits are skipped.
        Untrusted announcements are queued for on-chain verification.
        """
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        new_channels = {}
        for msg in msg_payloads:
            short_channel_id = bh2u(msg['short_channel_id'])
            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except UnknownEvenFeatureBits:
                self.logger.info("unknown feature bits")
                continue
            channel_info.trusted = trusted
            new_channels[short_channel_id] = channel_info
            if not trusted:
                self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
        for channel_info in new_channels.values():
            self.DBSession.add(channel_info)
        self.DBSession.commit()
        self._update_counts()
        self.logger.debug('on_channel_announcement: %d/%d' % (len(new_channels), len(msg_payloads)))

    @sql
    def get_last_timestamp(self):
        return self._get_last_timestamp()

    def _get_last_timestamp(self):
        from sqlalchemy.sql import func
        r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
        return r.max_timestamp or 0

    def print_change(self, old_policy, new_policy):
        # print what changed between policies
        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
            self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
            self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
            self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
        if old_policy.fee_base_msat != new_policy.fee_base_msat:
            self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
            self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
        if old_policy.channel_flags != new_policy.channel_flags:
            self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')

    @sql
    def get_info_for_updates(self, payloads):
        """Map short_channel_id (bytes) -> ChannelInfo for the given update payloads."""
        short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
        channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
        channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
        return channel_infos

    @sql
    def get_policies_for_updates(self, payloads):
        """Map (scid_hex + start_node_hex) -> existing Policy for the given payloads."""
        out = {}
        for payload in payloads:
            short_channel_id = payload['short_channel_id'].hex()
            start_node = payload['start_node'].hex()
            policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
            if policy:
                out[short_channel_id + start_node] = policy
        return out

    @profiler
    def filter_channel_updates(self, payloads, max_age=None):
        """Classify channel_update payloads against the database.

        Returns (orphaned, expired, deprecated, good, to_delete) where 'good'
        is the list of payloads to apply and 'to_delete' the superseded
        Policy rows.
        """
        orphaned = []      # no channel announcement for channel update
        expired = []       # update older than two weeks
        deprecated = []    # update older than database entry
        good = {}          # good updates
        to_delete = []     # database entries to delete
        # filter orphaned and expired first
        known = []
        now = int(time.time())
        channel_infos = self.get_info_for_updates(payloads)
        for payload in payloads:
            short_channel_id = payload['short_channel_id']
            timestamp = int.from_bytes(payload['timestamp'], "big")
            if max_age and now - timestamp > max_age:
                expired.append(short_channel_id)
                continue
            channel_info = channel_infos.get(short_channel_id)
            if not channel_info:
                orphaned.append(short_channel_id)
                continue
            flags = int.from_bytes(payload['channel_flags'], 'big')
            direction = flags & FLAG_DIRECTION
            # the direction bit selects which announced endpoint this update is from
            start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
            payload['start_node'] = bfh(start_node)
            known.append(payload)
        # compare updates to existing database entries
        old_policies = self.get_policies_for_updates(known)
        for payload in known:
            timestamp = int.from_bytes(payload['timestamp'], "big")
            start_node = payload['start_node']
            short_channel_id = payload['short_channel_id']
            key = (short_channel_id + start_node).hex()
            old_policy = old_policies.get(key)
            if old_policy:
                if timestamp <= old_policy.timestamp:
                    deprecated.append(short_channel_id)
                else:
                    good[key] = payload
                    to_delete.append(old_policy)
            else:
                good[key] = payload
        good = list(good.values())
        return orphaned, expired, deprecated, good, to_delete

    def add_channel_update(self, payload):
        orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
        assert len(good) == 1
        self.update_policies(good, to_delete)

    @sql
    @profiler
    def update_policies(self, to_add, to_delete):
        # note: deletions are committed before inserts so that replaced
        # (start_node, short_channel_id) primary keys do not collide
        for policy in to_delete:
            self.DBSession.delete(policy)
        self.DBSession.commit()
        for payload in to_add:
            policy = Policy.from_msg(payload)
            self.DBSession.add(policy)
        self.DBSession.commit()
        self._update_counts()

    @sql
    @profiler
    def on_node_announcement(self, msg_payloads):
        """Ingest one or more node_announcement payloads (newest timestamp wins)."""
        if type(msg_payloads) is dict:
            msg_payloads = [msg_payloads]
        old_addr = None
        new_nodes = {}
        new_addresses = {}
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except UnknownEvenFeatureBits:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            # FIXME this is slow
            expr = or_(ChannelInfo.node1_id == node_id, ChannelInfo.node2_id == node_id)
            if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
                # self.logger.info('ignoring orphan node_announcement')
                continue
            node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
            if node and node.timestamp >= node_info.timestamp:
                continue
            node = new_nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            new_nodes[node_id] = node_info
            for addr in node_addresses:
                new_addresses[(addr.node_id, addr.host, addr.port)] = addr
        self.logger.debug("on_node_announcement: %d/%d" % (len(new_nodes), len(msg_payloads)))
        for node_info in new_nodes.values():
            self.DBSession.add(node_info)
        for new_addr in new_addresses.values():
            old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
            if not old_addr:
                self.DBSession.add(new_addr)
        self.DBSession.commit()
        self._update_counts()

    def get_routing_policy_for_channel(self, start_node_id: bytes,
                                       short_channel_id: bytes) -> Optional['Policy']:
        # fix: annotation said Optional[bytes] but a Policy is returned
        if not start_node_id or not short_channel_id:
            return None
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:
            return self.get_policy_for_node(short_channel_id, start_node_id)
        msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
        if not msg:
            return None
        return Policy.from_msg(msg)  # won't actually be written to DB

    @sql
    @profiler
    def get_old_policies(self, delta):
        timestamp = int(time.time()) - delta
        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
        return old_policies.distinct().count()

    @sql
    @profiler
    def prune_old_policies(self, delta):
        # note: delete queries are order sensitive
        timestamp = int(time.time()) - delta
        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
        delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
        delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
        self.DBSession.execute(delete_old_channels)
        self.DBSession.execute(delete_old_policies)
        self.DBSession.commit()
        self._update_counts()

    @sql
    @profiler
    def get_orphaned_channels(self):
        subquery = self.DBSession.query(Policy.short_channel_id)
        orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
        return orphaned.count()

    @sql
    @profiler
    def prune_orphaned_channels(self):
        subquery = self.DBSession.query(Policy.short_channel_id)
        delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
        self.DBSession.execute(delete_orphaned)
        self.DBSession.commit()
        self._update_counts()

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
        """Cache a channel_update for a private channel (kept in memory only)."""
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return  # ignore
        short_channel_id = msg_payload['short_channel_id']
        msg_payload['start_node'] = start_node_id
        self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

    @sql
    def remove_channel(self, short_channel_id):
        r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id.hex()).one_or_none()
        if not r:
            return
        self.DBSession.delete(r)
        self.DBSession.commit()

    def print_graph(self, full_ids=False):
        # used for debugging.
        # FIXME there is a race here - iterables could change size from another thread
        def other_node_id(node_id, channel_id):
            channel_info = self.get_channel_info(channel_id)
            if node_id == channel_info.node1_id:
                other = channel_info.node2_id
            else:
                other = channel_info.node1_id
            return other if full_ids else other[-4:]

        print_msg('nodes')
        for node in self.DBSession.query(NodeInfo).all():
            print_msg(node)
        print_msg('channels')
        for channel_info in self.DBSession.query(ChannelInfo).all():
            # ORM attributes are hex strings; the policy cache is keyed by bytes
            short_channel_id = channel_info.short_channel_id
            node1 = channel_info.node1_id
            node2 = channel_info.node2_id
            # fix: get_policy_for_node takes (scid: bytes, node_id: bytes);
            # the old code passed the ChannelInfo object and hex strings.
            direction1 = self.get_policy_for_node(bfh(short_channel_id), bfh(node1)) is not None
            direction2 = self.get_policy_for_node(bfh(short_channel_id), bfh(node2)) is not None
            if direction1 and direction2:
                direction = 'both'
            elif direction1:
                direction = 'forward'
            elif direction2:
                direction = 'backward'
            else:
                direction = 'none'
            # fix: node1/node2 are already hex strings; bh2u() on them raised
            # TypeError. -8 hex chars == last 4 bytes, as before.
            print_msg('{}: {}, {}, {}'
                      .format(short_channel_id,
                              node1 if full_ids else node1[-8:],
                              node2 if full_ids else node2[-8:],
                              direction))

    @sql
    def get_node_addresses(self, node_info):
        return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id=node_info.node_id).all()

    @sql
    @profiler
    def load_data(self):
        """Load the whole graph into in-memory dicts used by the path finder."""
        r = self.DBSession.query(ChannelInfo).all()
        self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
        r = self.DBSession.query(Policy).filter_by().all()
        self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
        self._channels_for_node = defaultdict(set)
        for channel_info in self._channels.values():
            self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
            self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')

    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
        return self._policies.get((node_id, short_channel_id))

    def get_channel_info(self, channel_id: bytes):
        return self._channels.get(channel_id)

    def get_channels_for_node(self, node_id):
        """Returns the set of channels that have node_id as one of the endpoints."""
        return self._channels_for_node.get(node_id)


class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
                                         ('short_channel_id', bytes),
                                         ('fee_base_msat', int),
                                         ('fee_proportional_millionths', int),
                                         ('cltv_expiry_delta', int)])):
    """if you travel through short_channel_id, you will reach node_id"""

    def fee_for_edge(self, amount_msat: int) -> int:
        return self.fee_base_msat \
               + (amount_msat * self.fee_proportional_millionths // 1_000_000)

    @classmethod
    def from_channel_policy(cls, channel_policy: 'Policy', short_channel_id: bytes,
                            end_node: bytes) -> 'RouteEdge':
        assert type(short_channel_id) is bytes
        assert type(end_node) is bytes
        return RouteEdge(end_node,
                         short_channel_id,
                         channel_policy.fee_base_msat,
                         channel_policy.fee_proportional_millionths,
                         channel_policy.cltv_expiry_delta)

    def is_sane_to_use(self, amount_msat: int) -> bool:
        # TODO revise ad-hoc heuristics
        # cltv cannot be more than 2 weeks
        if self.cltv_expiry_delta > 14 * 144:
            return False
        total_fee = self.fee_for_edge(amount_msat)
        # fees below 50 sat are fine
        if total_fee > 50_000:
            # fee cannot be higher than amt
            if total_fee > amount_msat:
                return False
            # fee cannot be higher than 5000 sat
            if total_fee > 5_000_000:
                return False
            # unless amt is tiny, fee cannot be more than 10%
            if amount_msat > 1_000_000 and total_fee > amount_msat / 10:
                return False
        return True


def is_route_sane_to_use(route: List[RouteEdge], invoice_amount_msat: int, min_final_cltv_expiry: int) -> bool:
    """Run some sanity checks on the whole route, before attempting to use it.
    called when we are paying; so e.g. lower cltv is better
    """
    if len(route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
        return False
    amt = invoice_amount_msat
    cltv = min_final_cltv_expiry
    # the first edge is our own channel: its fee/cltv are not charged to us
    for route_edge in reversed(route[1:]):
        if not route_edge.is_sane_to_use(amt):
            return False
        amt += route_edge.fee_for_edge(amt)
        cltv += route_edge.cltv_expiry_delta
    total_fee = amt - invoice_amount_msat
    # TODO revise ad-hoc heuristics
    # cltv cannot be more than 2 months
    if cltv > 60 * 144:
        return False
    # fees below 50 sat are fine
    if total_fee > 50_000:
        # fee cannot be higher than amt
        if total_fee > invoice_amount_msat:
            return False
        # fee cannot be higher than 5000 sat
        if total_fee > 5_000_000:
            return False
        # unless amt is tiny, fee cannot be more than 10%
        if invoice_amount_msat > 1_000_000 and total_fee > invoice_amount_msat / 10:
            return False
    return True


class LNPathFinder(Logger):
    """Dijkstra-based payment path finder over the gossip graph."""

    def __init__(self, channel_db: ChannelDB):
        Logger.__init__(self)
        self.channel_db = channel_db
        self.blacklist = set()

    def _edge_cost(self, short_channel_id: bytes, start_node: bytes, end_node: bytes,
                   payment_amt_msat: int, ignore_costs=False) -> Tuple[float, int]:
        """Heuristic cost of going through a channel.
        Returns (heuristic_cost, fee_for_edge_msat).
        """
        channel_info = self.channel_db.get_channel_info(short_channel_id)  # type: ChannelInfo
        if channel_info is None:
            return float('inf'), 0
        channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
        if channel_policy is None:
            return float('inf'), 0
        if channel_policy.is_disabled():
            return float('inf'), 0
        route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
        if payment_amt_msat < channel_policy.htlc_minimum_msat:
            return float('inf'), 0  # payment amount too little
        if channel_info.capacity_sat is not None and \
                payment_amt_msat // 1000 > channel_info.capacity_sat:
            return float('inf'), 0  # payment amount too large
        if channel_policy.htlc_maximum_msat is not None and \
                payment_amt_msat > channel_policy.htlc_maximum_msat:
            return float('inf'), 0  # payment amount too large
        if not route_edge.is_sane_to_use(payment_amt_msat):
            return float('inf'), 0  # thanks but no thanks
        fee_msat = route_edge.fee_for_edge(payment_amt_msat) if not ignore_costs else 0
        # TODO revise
        # paying 10 more satoshis ~ waiting one more block
        fee_cost = fee_msat / 1000 / 10
        cltv_cost = route_edge.cltv_expiry_delta if not ignore_costs else 0
        return cltv_cost + fee_cost + 1, fee_msat

    @profiler
    def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
                              invoice_amount_msat: int,
                              my_channels: List['Channel'] = None) -> Sequence[Tuple[bytes, bytes]]:
        """Return a path from nodeA to nodeB.

        Returns a list of (node_id, short_channel_id) representing a path.
        To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
        i.e. an element reads as, "to get to node_id, travel through short_channel_id"
        """
        self.channel_db.load_data()
        assert type(nodeA) is bytes
        assert type(nodeB) is bytes
        assert type(invoice_amount_msat) is int
        if my_channels is None:
            my_channels = []
        my_channels = {chan.short_channel_id: chan for chan in my_channels}

        # FIXME paths cannot be longer than 21 edges (onion packet)...

        # run Dijkstra
        # The search is run in the REVERSE direction, from nodeB to nodeA,
        # to properly calculate compound routing fees.
        distance_from_start = defaultdict(lambda: float('inf'))
        distance_from_start[nodeB] = 0
        prev_node = {}
        nodes_to_explore = queue.PriorityQueue()
        nodes_to_explore.put((0, invoice_amount_msat, nodeB))  # order of fields (in tuple) matters!

        def inspect_edge():
            if edge_channel_id in my_channels:
                if edge_startnode == nodeA:  # payment outgoing, on our channel
                    if not my_channels[edge_channel_id].can_pay(amount_msat):
                        return
                else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                    assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                    pass  # TODO?
            edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
                                                           start_node=edge_startnode,
                                                           end_node=edge_endnode,
                                                           payment_amt_msat=amount_msat,
                                                           ignore_costs=(edge_startnode == nodeA))
            alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
            if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                distance_from_start[edge_startnode] = alt_dist_to_neighbour
                prev_node[edge_startnode] = edge_endnode, edge_channel_id
                amount_to_forward_msat = amount_msat + fee_for_edge_msat
                nodes_to_explore.put((alt_dist_to_neighbour, amount_to_forward_msat, edge_startnode))

        # main loop of search
        while nodes_to_explore.qsize() > 0:
            dist_to_edge_endnode, amount_msat, edge_endnode = nodes_to_explore.get()
            if edge_endnode == nodeA:
                break
            if dist_to_edge_endnode != distance_from_start[edge_endnode]:
                # queue.PriorityQueue does not implement decrease_priority,
                # so instead of decreasing priorities, we add items again into the queue.
                # so there are duplicates in the queue, that we discard now:
                continue
            for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
                assert type(edge_channel_id) is bytes
                if edge_channel_id in self.blacklist:
                    continue
                channel_info = self.channel_db.get_channel_info(edge_channel_id)
                edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
                inspect_edge()
        else:
            return None  # no path found

        # backtrack from search_end (nodeA) to search_start (nodeB)
        edge_startnode = nodeA
        path = []
        while edge_startnode != nodeB:
            edge_endnode, edge_taken = prev_node[edge_startnode]
            path += [(edge_endnode, edge_taken)]
            edge_startnode = edge_endnode
        return path

    def create_route_from_path(self, path, from_node_id: bytes) -> List[RouteEdge]:
        """Resolve a (node_id, short_channel_id) path into RouteEdges with policies.

        Raises NoChannelPolicy if a hop's policy is unknown.
        """
        assert type(from_node_id) is bytes
        if path is None:
            raise Exception('cannot create route from None path')
        route = []
        prev_node_id = from_node_id
        for node_id, short_channel_id in path:
            channel_policy = self.channel_db.get_routing_policy_for_channel(prev_node_id, short_channel_id)
            if channel_policy is None:
                raise NoChannelPolicy(short_channel_id)
            route.append(RouteEdge.from_channel_policy(channel_policy, short_channel_id, node_id))
            prev_node_id = node_id
        return route