mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-13 05:59:51 +00:00
separate channel_db module
This commit is contained in:
parent
06b5299b0f
commit
180f6d34be
3 changed files with 596 additions and 538 deletions
589
electrum/channel_db.py
Normal file
589
electrum/channel_db.py
Normal file
|
@ -0,0 +1,589 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Electrum - lightweight Bitcoin client
|
||||||
|
# Copyright (C) 2018 The Electrum developers
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person
|
||||||
|
# obtaining a copy of this software and associated documentation files
|
||||||
|
# (the "Software"), to deal in the Software without restriction,
|
||||||
|
# including without limitation the rights to use, copy, modify, merge,
|
||||||
|
# publish, distribute, sublicense, and/or sell copies of the Software,
|
||||||
|
# and to permit persons to whom the Software is furnished to do so,
|
||||||
|
# subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be
|
||||||
|
# included in all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
||||||
|
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
||||||
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import queue
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import concurrent
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
|
||||||
|
import binascii
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.sql import not_, or_
|
||||||
|
|
||||||
|
from .sql_db import SqlDB, sql
|
||||||
|
from . import constants
|
||||||
|
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
|
||||||
|
from .logging import Logger
|
||||||
|
from .storage import JsonDB
|
||||||
|
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
||||||
|
from .crypto import sha256d
|
||||||
|
from . import ecc
|
||||||
|
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
||||||
|
NotFoundChanAnnouncementForUpdate)
|
||||||
|
from .lnmsg import encode_msg
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .lnchannel import Channel
|
||||||
|
from .network import Network
|
||||||
|
|
||||||
|
class UnknownEvenFeatureBits(Exception): pass
|
||||||
|
|
||||||
|
def validate_features(features : int):
|
||||||
|
enabled_features = list_enabled_bits(features)
|
||||||
|
for fbit in enabled_features:
|
||||||
|
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
||||||
|
raise UnknownEvenFeatureBits()
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
FLAG_DISABLE = 1 << 1
|
||||||
|
FLAG_DIRECTION = 1 << 0
|
||||||
|
|
||||||
|
class ChannelInfo(Base):
|
||||||
|
__tablename__ = 'channel_info'
|
||||||
|
short_channel_id = Column(String(64), primary_key=True)
|
||||||
|
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||||
|
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||||
|
capacity_sat = Column(Integer)
|
||||||
|
msg_payload_hex = Column(String(1024), nullable=False)
|
||||||
|
trusted = Column(Boolean, nullable=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_msg(payload):
|
||||||
|
features = int.from_bytes(payload['features'], 'big')
|
||||||
|
validate_features(features)
|
||||||
|
channel_id = payload['short_channel_id'].hex()
|
||||||
|
node_id_1 = payload['node_id_1'].hex()
|
||||||
|
node_id_2 = payload['node_id_2'].hex()
|
||||||
|
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
||||||
|
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
|
||||||
|
capacity_sat = None
|
||||||
|
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
|
||||||
|
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
|
||||||
|
trusted = False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def msg_payload(self):
|
||||||
|
return bytes.fromhex(self.msg_payload_hex)
|
||||||
|
|
||||||
|
|
||||||
|
class Policy(Base):
|
||||||
|
__tablename__ = 'policy'
|
||||||
|
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||||
|
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
|
||||||
|
cltv_expiry_delta = Column(Integer, nullable=False)
|
||||||
|
htlc_minimum_msat = Column(Integer, nullable=False)
|
||||||
|
htlc_maximum_msat = Column(Integer)
|
||||||
|
fee_base_msat = Column(Integer, nullable=False)
|
||||||
|
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||||
|
channel_flags = Column(Integer, nullable=False)
|
||||||
|
timestamp = Column(Integer, nullable=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_msg(payload):
|
||||||
|
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
|
||||||
|
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
|
||||||
|
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
|
||||||
|
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
|
||||||
|
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
|
||||||
|
channel_flags = int.from_bytes(payload['channel_flags'], "big")
|
||||||
|
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||||
|
start_node = payload['start_node'].hex()
|
||||||
|
short_channel_id = payload['short_channel_id'].hex()
|
||||||
|
|
||||||
|
return Policy(start_node=start_node,
|
||||||
|
short_channel_id=short_channel_id,
|
||||||
|
cltv_expiry_delta=cltv_expiry_delta,
|
||||||
|
htlc_minimum_msat=htlc_minimum_msat,
|
||||||
|
fee_base_msat=fee_base_msat,
|
||||||
|
fee_proportional_millionths=fee_proportional_millionths,
|
||||||
|
channel_flags=channel_flags,
|
||||||
|
timestamp=timestamp,
|
||||||
|
htlc_maximum_msat=htlc_maximum_msat)
|
||||||
|
|
||||||
|
def is_disabled(self):
|
||||||
|
return self.channel_flags & FLAG_DISABLE
|
||||||
|
|
||||||
|
class NodeInfo(Base):
|
||||||
|
__tablename__ = 'node_info'
|
||||||
|
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||||
|
features = Column(Integer, nullable=False)
|
||||||
|
timestamp = Column(Integer, nullable=False)
|
||||||
|
alias = Column(String(64), nullable=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_msg(payload):
|
||||||
|
node_id = payload['node_id'].hex()
|
||||||
|
features = int.from_bytes(payload['features'], "big")
|
||||||
|
validate_features(features)
|
||||||
|
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
|
||||||
|
alias = payload['alias'].rstrip(b'\x00').hex()
|
||||||
|
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||||
|
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
|
||||||
|
Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_addresses_field(addresses_field):
|
||||||
|
buf = addresses_field
|
||||||
|
def read(n):
|
||||||
|
nonlocal buf
|
||||||
|
data, buf = buf[0:n], buf[n:]
|
||||||
|
return data
|
||||||
|
addresses = []
|
||||||
|
while buf:
|
||||||
|
atype = ord(read(1))
|
||||||
|
if atype == 0:
|
||||||
|
pass
|
||||||
|
elif atype == 1: # IPv4
|
||||||
|
ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
|
||||||
|
port = int.from_bytes(read(2), 'big')
|
||||||
|
if is_ip_address(ipv4_addr) and port != 0:
|
||||||
|
addresses.append((ipv4_addr, port))
|
||||||
|
elif atype == 2: # IPv6
|
||||||
|
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
|
||||||
|
ipv6_addr = ipv6_addr.decode('ascii')
|
||||||
|
port = int.from_bytes(read(2), 'big')
|
||||||
|
if is_ip_address(ipv6_addr) and port != 0:
|
||||||
|
addresses.append((ipv6_addr, port))
|
||||||
|
elif atype == 3: # onion v2
|
||||||
|
host = base64.b32encode(read(10)) + b'.onion'
|
||||||
|
host = host.decode('ascii').lower()
|
||||||
|
port = int.from_bytes(read(2), 'big')
|
||||||
|
addresses.append((host, port))
|
||||||
|
elif atype == 4: # onion v3
|
||||||
|
host = base64.b32encode(read(35)) + b'.onion'
|
||||||
|
host = host.decode('ascii').lower()
|
||||||
|
port = int.from_bytes(read(2), 'big')
|
||||||
|
addresses.append((host, port))
|
||||||
|
else:
|
||||||
|
# unknown address type
|
||||||
|
# we don't know how long it is -> have to escape
|
||||||
|
# if there are other addresses we could have parsed later, they are lost.
|
||||||
|
break
|
||||||
|
return addresses
|
||||||
|
|
||||||
|
class Address(Base):
|
||||||
|
__tablename__ = 'address'
|
||||||
|
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||||
|
host = Column(String(256), primary_key=True)
|
||||||
|
port = Column(Integer, primary_key=True)
|
||||||
|
last_connected_date = Column(Integer(), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDB(SqlDB):
|
||||||
|
|
||||||
|
NUM_MAX_RECENT_PEERS = 20
|
||||||
|
|
||||||
|
def __init__(self, network: 'Network'):
|
||||||
|
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||||
|
super().__init__(network, path, Base)
|
||||||
|
self.num_nodes = 0
|
||||||
|
self.num_channels = 0
|
||||||
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||||
|
self.ca_verifier = LNChannelVerifier(network, self)
|
||||||
|
self.update_counts()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def update_counts(self):
|
||||||
|
self._update_counts()
|
||||||
|
|
||||||
|
def _update_counts(self):
|
||||||
|
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
||||||
|
self.num_policies = self.DBSession.query(Policy).count()
|
||||||
|
self.num_nodes = self.DBSession.query(NodeInfo).count()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def known_ids(self):
|
||||||
|
known = self.DBSession.query(ChannelInfo.short_channel_id).all()
|
||||||
|
return set(bfh(r.short_channel_id) for r in known)
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def add_recent_peer(self, peer: LNPeerAddr):
|
||||||
|
now = int(time.time())
|
||||||
|
node_id = peer.pubkey.hex()
|
||||||
|
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
||||||
|
if addr:
|
||||||
|
addr.last_connected_date = now
|
||||||
|
else:
|
||||||
|
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
||||||
|
self.DBSession.add(addr)
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
|
||||||
|
unshuffled = self.DBSession \
|
||||||
|
.query(NodeInfo) \
|
||||||
|
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
|
||||||
|
.limit(200) \
|
||||||
|
.all()
|
||||||
|
return random.sample(unshuffled, len(unshuffled))
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def nodes_get(self, node_id):
|
||||||
|
return self.DBSession \
|
||||||
|
.query(NodeInfo) \
|
||||||
|
.filter_by(node_id = node_id.hex()) \
|
||||||
|
.one_or_none()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
|
||||||
|
r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
|
||||||
|
if not r:
|
||||||
|
return None
|
||||||
|
addr = r[0]
|
||||||
|
return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_recent_peers(self):
|
||||||
|
r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
|
||||||
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def missing_channel_announcements(self) -> Set[int]:
|
||||||
|
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
|
||||||
|
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def missing_channel_updates(self) -> Set[int]:
|
||||||
|
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
|
||||||
|
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def add_verified_channel_info(self, short_id, capacity):
|
||||||
|
# called from lnchannelverifier
|
||||||
|
channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
|
||||||
|
channel_info.trusted = True
|
||||||
|
channel_info.capacity = capacity
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def on_channel_announcement(self, msg_payloads, trusted=True):
|
||||||
|
if type(msg_payloads) is dict:
|
||||||
|
msg_payloads = [msg_payloads]
|
||||||
|
new_channels = {}
|
||||||
|
for msg in msg_payloads:
|
||||||
|
short_channel_id = bh2u(msg['short_channel_id'])
|
||||||
|
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
|
||||||
|
continue
|
||||||
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
||||||
|
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
channel_info = ChannelInfo.from_msg(msg)
|
||||||
|
except UnknownEvenFeatureBits:
|
||||||
|
self.logger.info("unknown feature bits")
|
||||||
|
continue
|
||||||
|
channel_info.trusted = trusted
|
||||||
|
new_channels[short_channel_id] = channel_info
|
||||||
|
if not trusted:
|
||||||
|
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
||||||
|
for channel_info in new_channels.values():
|
||||||
|
self.DBSession.add(channel_info)
|
||||||
|
self.DBSession.commit()
|
||||||
|
self._update_counts()
|
||||||
|
self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_last_timestamp(self):
|
||||||
|
return self._get_last_timestamp()
|
||||||
|
|
||||||
|
def _get_last_timestamp(self):
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
|
||||||
|
return r.max_timestamp or 0
|
||||||
|
|
||||||
|
def print_change(self, old_policy, new_policy):
|
||||||
|
# print what changed between policies
|
||||||
|
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
|
||||||
|
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
|
||||||
|
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
|
||||||
|
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
|
||||||
|
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
|
||||||
|
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
|
||||||
|
if old_policy.fee_base_msat != new_policy.fee_base_msat:
|
||||||
|
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
|
||||||
|
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
|
||||||
|
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
|
||||||
|
if old_policy.channel_flags != new_policy.channel_flags:
|
||||||
|
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_info_for_updates(self, payloads):
|
||||||
|
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
|
||||||
|
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
|
||||||
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
||||||
|
return channel_infos
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_policies_for_updates(self, payloads):
|
||||||
|
out = {}
|
||||||
|
for payload in payloads:
|
||||||
|
short_channel_id = payload['short_channel_id'].hex()
|
||||||
|
start_node = payload['start_node'].hex()
|
||||||
|
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
|
||||||
|
if policy:
|
||||||
|
out[short_channel_id+start_node] = policy
|
||||||
|
return out
|
||||||
|
|
||||||
|
@profiler
|
||||||
|
def filter_channel_updates(self, payloads, max_age=None):
|
||||||
|
orphaned = [] # no channel announcement for channel update
|
||||||
|
expired = [] # update older than two weeks
|
||||||
|
deprecated = [] # update older than database entry
|
||||||
|
good = {} # good updates
|
||||||
|
to_delete = [] # database entries to delete
|
||||||
|
# filter orphaned and expired first
|
||||||
|
known = []
|
||||||
|
now = int(time.time())
|
||||||
|
channel_infos = self.get_info_for_updates(payloads)
|
||||||
|
for payload in payloads:
|
||||||
|
short_channel_id = payload['short_channel_id']
|
||||||
|
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||||
|
if max_age and now - timestamp > max_age:
|
||||||
|
expired.append(short_channel_id)
|
||||||
|
continue
|
||||||
|
channel_info = channel_infos.get(short_channel_id)
|
||||||
|
if not channel_info:
|
||||||
|
orphaned.append(short_channel_id)
|
||||||
|
continue
|
||||||
|
flags = int.from_bytes(payload['channel_flags'], 'big')
|
||||||
|
direction = flags & FLAG_DIRECTION
|
||||||
|
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
|
||||||
|
payload['start_node'] = bfh(start_node)
|
||||||
|
known.append(payload)
|
||||||
|
# compare updates to existing database entries
|
||||||
|
old_policies = self.get_policies_for_updates(known)
|
||||||
|
for payload in known:
|
||||||
|
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||||
|
start_node = payload['start_node']
|
||||||
|
short_channel_id = payload['short_channel_id']
|
||||||
|
key = (short_channel_id+start_node).hex()
|
||||||
|
old_policy = old_policies.get(key)
|
||||||
|
if old_policy:
|
||||||
|
if timestamp <= old_policy.timestamp:
|
||||||
|
deprecated.append(short_channel_id)
|
||||||
|
else:
|
||||||
|
good[key] = payload
|
||||||
|
to_delete.append(old_policy)
|
||||||
|
else:
|
||||||
|
good[key] = payload
|
||||||
|
good = list(good.values())
|
||||||
|
return orphaned, expired, deprecated, good, to_delete
|
||||||
|
|
||||||
|
def add_channel_update(self, payload):
|
||||||
|
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
|
||||||
|
assert len(good) == 1
|
||||||
|
self.update_policies(good, to_delete)
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def update_policies(self, to_add, to_delete):
|
||||||
|
for policy in to_delete:
|
||||||
|
self.DBSession.delete(policy)
|
||||||
|
self.DBSession.commit()
|
||||||
|
for payload in to_add:
|
||||||
|
policy = Policy.from_msg(payload)
|
||||||
|
self.DBSession.add(policy)
|
||||||
|
self.DBSession.commit()
|
||||||
|
self._update_counts()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def on_node_announcement(self, msg_payloads):
|
||||||
|
if type(msg_payloads) is dict:
|
||||||
|
msg_payloads = [msg_payloads]
|
||||||
|
old_addr = None
|
||||||
|
new_nodes = {}
|
||||||
|
new_addresses = {}
|
||||||
|
for msg_payload in msg_payloads:
|
||||||
|
try:
|
||||||
|
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
|
||||||
|
except UnknownEvenFeatureBits:
|
||||||
|
continue
|
||||||
|
node_id = node_info.node_id
|
||||||
|
# Ignore node if it has no associated channel (DoS protection)
|
||||||
|
# FIXME this is slow
|
||||||
|
expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
|
||||||
|
if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
|
||||||
|
#self.logger.info('ignoring orphan node_announcement')
|
||||||
|
continue
|
||||||
|
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
|
||||||
|
if node and node.timestamp >= node_info.timestamp:
|
||||||
|
continue
|
||||||
|
node = new_nodes.get(node_id)
|
||||||
|
if node and node.timestamp >= node_info.timestamp:
|
||||||
|
continue
|
||||||
|
new_nodes[node_id] = node_info
|
||||||
|
for addr in node_addresses:
|
||||||
|
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
|
||||||
|
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||||
|
for node_info in new_nodes.values():
|
||||||
|
self.DBSession.add(node_info)
|
||||||
|
for new_addr in new_addresses.values():
|
||||||
|
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
||||||
|
if not old_addr:
|
||||||
|
self.DBSession.add(new_addr)
|
||||||
|
self.DBSession.commit()
|
||||||
|
self._update_counts()
|
||||||
|
|
||||||
|
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||||
|
short_channel_id: bytes) -> Optional[bytes]:
|
||||||
|
if not start_node_id or not short_channel_id: return None
|
||||||
|
channel_info = self.get_channel_info(short_channel_id)
|
||||||
|
if channel_info is not None:
|
||||||
|
return self.get_policy_for_node(short_channel_id, start_node_id)
|
||||||
|
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||||
|
if not msg:
|
||||||
|
return None
|
||||||
|
return Policy.from_msg(msg) # won't actually be written to DB
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def get_old_policies(self, delta):
|
||||||
|
timestamp = int(time.time()) - delta
|
||||||
|
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
||||||
|
return old_policies.distinct().count()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def prune_old_policies(self, delta):
|
||||||
|
# note: delete queries are order sensitive
|
||||||
|
timestamp = int(time.time()) - delta
|
||||||
|
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
||||||
|
delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
|
||||||
|
delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
|
||||||
|
self.DBSession.execute(delete_old_channels)
|
||||||
|
self.DBSession.execute(delete_old_policies)
|
||||||
|
self.DBSession.commit()
|
||||||
|
self._update_counts()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def get_orphaned_channels(self):
|
||||||
|
subquery = self.DBSession.query(Policy.short_channel_id)
|
||||||
|
orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
||||||
|
return orphaned.count()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def prune_orphaned_channels(self):
|
||||||
|
subquery = self.DBSession.query(Policy.short_channel_id)
|
||||||
|
delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
||||||
|
self.DBSession.execute(delete_orphaned)
|
||||||
|
self.DBSession.commit()
|
||||||
|
self._update_counts()
|
||||||
|
|
||||||
|
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
||||||
|
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||||
|
return # ignore
|
||||||
|
short_channel_id = msg_payload['short_channel_id']
|
||||||
|
msg_payload['start_node'] = start_node_id
|
||||||
|
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def remove_channel(self, short_channel_id):
|
||||||
|
r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
|
||||||
|
if not r:
|
||||||
|
return
|
||||||
|
self.DBSession.delete(r)
|
||||||
|
self.DBSession.commit()
|
||||||
|
|
||||||
|
def print_graph(self, full_ids=False):
|
||||||
|
# used for debugging.
|
||||||
|
# FIXME there is a race here - iterables could change size from another thread
|
||||||
|
def other_node_id(node_id, channel_id):
|
||||||
|
channel_info = self.get_channel_info(channel_id)
|
||||||
|
if node_id == channel_info.node1_id:
|
||||||
|
other = channel_info.node2_id
|
||||||
|
else:
|
||||||
|
other = channel_info.node1_id
|
||||||
|
return other if full_ids else other[-4:]
|
||||||
|
|
||||||
|
print_msg('nodes')
|
||||||
|
for node in self.DBSession.query(NodeInfo).all():
|
||||||
|
print_msg(node)
|
||||||
|
|
||||||
|
print_msg('channels')
|
||||||
|
for channel_info in self.DBSession.query(ChannelInfo).all():
|
||||||
|
short_channel_id = channel_info.short_channel_id
|
||||||
|
node1 = channel_info.node1_id
|
||||||
|
node2 = channel_info.node2_id
|
||||||
|
direction1 = self.get_policy_for_node(channel_info, node1) is not None
|
||||||
|
direction2 = self.get_policy_for_node(channel_info, node2) is not None
|
||||||
|
if direction1 and direction2:
|
||||||
|
direction = 'both'
|
||||||
|
elif direction1:
|
||||||
|
direction = 'forward'
|
||||||
|
elif direction2:
|
||||||
|
direction = 'backward'
|
||||||
|
else:
|
||||||
|
direction = 'none'
|
||||||
|
print_msg('{}: {}, {}, {}'
|
||||||
|
.format(bh2u(short_channel_id),
|
||||||
|
bh2u(node1) if full_ids else bh2u(node1[-4:]),
|
||||||
|
bh2u(node2) if full_ids else bh2u(node2[-4:]),
|
||||||
|
direction))
|
||||||
|
|
||||||
|
|
||||||
|
@sql
|
||||||
|
def get_node_addresses(self, node_info):
|
||||||
|
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
|
||||||
|
|
||||||
|
@sql
|
||||||
|
@profiler
|
||||||
|
def load_data(self):
|
||||||
|
r = self.DBSession.query(ChannelInfo).all()
|
||||||
|
self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
|
||||||
|
r = self.DBSession.query(Policy).filter_by().all()
|
||||||
|
self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
|
||||||
|
self._channels_for_node = defaultdict(set)
|
||||||
|
for channel_info in self._channels.values():
|
||||||
|
self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
|
||||||
|
self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
|
||||||
|
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
|
||||||
|
|
||||||
|
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
||||||
|
return self._policies.get((node_id, short_channel_id))
|
||||||
|
|
||||||
|
def get_channel_info(self, channel_id: bytes):
|
||||||
|
return self._channels.get(channel_id)
|
||||||
|
|
||||||
|
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
||||||
|
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||||
|
return self._channels_for_node.get(node_id) or set()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,12 +36,6 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
|
||||||
import binascii
|
import binascii
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from sqlalchemy import Column, ForeignKey, Integer, String, Boolean
|
|
||||||
from sqlalchemy.orm.query import Query
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.sql import not_, or_
|
|
||||||
|
|
||||||
from .sql_db import SqlDB, sql
|
|
||||||
from . import constants
|
from . import constants
|
||||||
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
|
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
|
||||||
from .logging import Logger
|
from .logging import Logger
|
||||||
|
@ -52,543 +46,16 @@ from . import ecc
|
||||||
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
||||||
NotFoundChanAnnouncementForUpdate)
|
NotFoundChanAnnouncementForUpdate)
|
||||||
from .lnmsg import encode_msg
|
from .lnmsg import encode_msg
|
||||||
|
from .channel_db import ChannelDB
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .lnchannel import Channel
|
from .lnchannel import Channel
|
||||||
from .network import Network
|
from .network import Network
|
||||||
|
|
||||||
class UnknownEvenFeatureBits(Exception): pass
|
|
||||||
class NoChannelPolicy(Exception):
|
class NoChannelPolicy(Exception):
|
||||||
def __init__(self, short_channel_id: bytes):
|
def __init__(self, short_channel_id: bytes):
|
||||||
super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
|
super().__init__(f'cannot find channel policy for short_channel_id: {bh2u(short_channel_id)}')
|
||||||
|
|
||||||
def validate_features(features : int):
|
|
||||||
enabled_features = list_enabled_bits(features)
|
|
||||||
for fbit in enabled_features:
|
|
||||||
if (1 << fbit) not in LN_GLOBAL_FEATURES_KNOWN_SET and fbit % 2 == 0:
|
|
||||||
raise UnknownEvenFeatureBits()
|
|
||||||
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
FLAG_DISABLE = 1 << 1
|
|
||||||
FLAG_DIRECTION = 1 << 0
|
|
||||||
|
|
||||||
class ChannelInfo(Base):
|
|
||||||
__tablename__ = 'channel_info'
|
|
||||||
short_channel_id = Column(String(64), primary_key=True)
|
|
||||||
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
|
||||||
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
|
||||||
capacity_sat = Column(Integer)
|
|
||||||
msg_payload_hex = Column(String(1024), nullable=False)
|
|
||||||
trusted = Column(Boolean, nullable=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_msg(payload):
|
|
||||||
features = int.from_bytes(payload['features'], 'big')
|
|
||||||
validate_features(features)
|
|
||||||
channel_id = payload['short_channel_id'].hex()
|
|
||||||
node_id_1 = payload['node_id_1'].hex()
|
|
||||||
node_id_2 = payload['node_id_2'].hex()
|
|
||||||
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
|
||||||
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
|
|
||||||
capacity_sat = None
|
|
||||||
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
|
|
||||||
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
|
|
||||||
trusted = False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def msg_payload(self):
|
|
||||||
return bytes.fromhex(self.msg_payload_hex)
|
|
||||||
|
|
||||||
|
|
||||||
class Policy(Base):
|
|
||||||
__tablename__ = 'policy'
|
|
||||||
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
|
||||||
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
|
|
||||||
cltv_expiry_delta = Column(Integer, nullable=False)
|
|
||||||
htlc_minimum_msat = Column(Integer, nullable=False)
|
|
||||||
htlc_maximum_msat = Column(Integer)
|
|
||||||
fee_base_msat = Column(Integer, nullable=False)
|
|
||||||
fee_proportional_millionths = Column(Integer, nullable=False)
|
|
||||||
channel_flags = Column(Integer, nullable=False)
|
|
||||||
timestamp = Column(Integer, nullable=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_msg(payload):
|
|
||||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
|
|
||||||
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
|
|
||||||
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
|
|
||||||
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
|
|
||||||
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
|
|
||||||
channel_flags = int.from_bytes(payload['channel_flags'], "big")
|
|
||||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
|
||||||
start_node = payload['start_node'].hex()
|
|
||||||
short_channel_id = payload['short_channel_id'].hex()
|
|
||||||
|
|
||||||
return Policy(start_node=start_node,
|
|
||||||
short_channel_id=short_channel_id,
|
|
||||||
cltv_expiry_delta=cltv_expiry_delta,
|
|
||||||
htlc_minimum_msat=htlc_minimum_msat,
|
|
||||||
fee_base_msat=fee_base_msat,
|
|
||||||
fee_proportional_millionths=fee_proportional_millionths,
|
|
||||||
channel_flags=channel_flags,
|
|
||||||
timestamp=timestamp,
|
|
||||||
htlc_maximum_msat=htlc_maximum_msat)
|
|
||||||
|
|
||||||
def is_disabled(self):
|
|
||||||
return self.channel_flags & FLAG_DISABLE
|
|
||||||
|
|
||||||
class NodeInfo(Base):
|
|
||||||
__tablename__ = 'node_info'
|
|
||||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
|
||||||
features = Column(Integer, nullable=False)
|
|
||||||
timestamp = Column(Integer, nullable=False)
|
|
||||||
alias = Column(String(64), nullable=False)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_msg(payload):
|
|
||||||
node_id = payload['node_id'].hex()
|
|
||||||
features = int.from_bytes(payload['features'], "big")
|
|
||||||
validate_features(features)
|
|
||||||
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
|
|
||||||
alias = payload['alias'].rstrip(b'\x00').hex()
|
|
||||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
|
||||||
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
|
|
||||||
Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_addresses_field(addresses_field):
|
|
||||||
buf = addresses_field
|
|
||||||
def read(n):
|
|
||||||
nonlocal buf
|
|
||||||
data, buf = buf[0:n], buf[n:]
|
|
||||||
return data
|
|
||||||
addresses = []
|
|
||||||
while buf:
|
|
||||||
atype = ord(read(1))
|
|
||||||
if atype == 0:
|
|
||||||
pass
|
|
||||||
elif atype == 1: # IPv4
|
|
||||||
ipv4_addr = '.'.join(map(lambda x: '%d' % x, read(4)))
|
|
||||||
port = int.from_bytes(read(2), 'big')
|
|
||||||
if is_ip_address(ipv4_addr) and port != 0:
|
|
||||||
addresses.append((ipv4_addr, port))
|
|
||||||
elif atype == 2: # IPv6
|
|
||||||
ipv6_addr = b':'.join([binascii.hexlify(read(2)) for i in range(8)])
|
|
||||||
ipv6_addr = ipv6_addr.decode('ascii')
|
|
||||||
port = int.from_bytes(read(2), 'big')
|
|
||||||
if is_ip_address(ipv6_addr) and port != 0:
|
|
||||||
addresses.append((ipv6_addr, port))
|
|
||||||
elif atype == 3: # onion v2
|
|
||||||
host = base64.b32encode(read(10)) + b'.onion'
|
|
||||||
host = host.decode('ascii').lower()
|
|
||||||
port = int.from_bytes(read(2), 'big')
|
|
||||||
addresses.append((host, port))
|
|
||||||
elif atype == 4: # onion v3
|
|
||||||
host = base64.b32encode(read(35)) + b'.onion'
|
|
||||||
host = host.decode('ascii').lower()
|
|
||||||
port = int.from_bytes(read(2), 'big')
|
|
||||||
addresses.append((host, port))
|
|
||||||
else:
|
|
||||||
# unknown address type
|
|
||||||
# we don't know how long it is -> have to escape
|
|
||||||
# if there are other addresses we could have parsed later, they are lost.
|
|
||||||
break
|
|
||||||
return addresses
|
|
||||||
|
|
||||||
class Address(Base):
|
|
||||||
__tablename__ = 'address'
|
|
||||||
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
|
||||||
host = Column(String(256), primary_key=True)
|
|
||||||
port = Column(Integer, primary_key=True)
|
|
||||||
last_connected_date = Column(Integer(), nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelDB(SqlDB):
|
|
||||||
|
|
||||||
NUM_MAX_RECENT_PEERS = 20
|
|
||||||
|
|
||||||
def __init__(self, network: 'Network'):
|
|
||||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
|
||||||
super().__init__(network, path, Base)
|
|
||||||
self.num_nodes = 0
|
|
||||||
self.num_channels = 0
|
|
||||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
|
||||||
self.ca_verifier = LNChannelVerifier(network, self)
|
|
||||||
self.update_counts()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def update_counts(self):
|
|
||||||
self._update_counts()
|
|
||||||
|
|
||||||
def _update_counts(self):
|
|
||||||
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
|
||||||
self.num_policies = self.DBSession.query(Policy).count()
|
|
||||||
self.num_nodes = self.DBSession.query(NodeInfo).count()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def known_ids(self):
|
|
||||||
known = self.DBSession.query(ChannelInfo.short_channel_id).all()
|
|
||||||
return set(bfh(r.short_channel_id) for r in known)
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def add_recent_peer(self, peer: LNPeerAddr):
|
|
||||||
now = int(time.time())
|
|
||||||
node_id = peer.pubkey.hex()
|
|
||||||
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
|
||||||
if addr:
|
|
||||||
addr.last_connected_date = now
|
|
||||||
else:
|
|
||||||
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
|
||||||
self.DBSession.add(addr)
|
|
||||||
self.DBSession.commit()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
|
|
||||||
unshuffled = self.DBSession \
|
|
||||||
.query(NodeInfo) \
|
|
||||||
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
|
|
||||||
.limit(200) \
|
|
||||||
.all()
|
|
||||||
return random.sample(unshuffled, len(unshuffled))
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def nodes_get(self, node_id):
|
|
||||||
return self.DBSession \
|
|
||||||
.query(NodeInfo) \
|
|
||||||
.filter_by(node_id = node_id.hex()) \
|
|
||||||
.one_or_none()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
|
|
||||||
r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
|
|
||||||
if not r:
|
|
||||||
return None
|
|
||||||
addr = r[0]
|
|
||||||
return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_recent_peers(self):
|
|
||||||
r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
|
|
||||||
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def missing_channel_announcements(self) -> Set[int]:
|
|
||||||
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
|
|
||||||
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def missing_channel_updates(self) -> Set[int]:
|
|
||||||
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
|
|
||||||
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def add_verified_channel_info(self, short_id, capacity):
|
|
||||||
# called from lnchannelverifier
|
|
||||||
channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
|
|
||||||
channel_info.trusted = True
|
|
||||||
channel_info.capacity = capacity
|
|
||||||
self.DBSession.commit()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def on_channel_announcement(self, msg_payloads, trusted=True):
|
|
||||||
if type(msg_payloads) is dict:
|
|
||||||
msg_payloads = [msg_payloads]
|
|
||||||
new_channels = {}
|
|
||||||
for msg in msg_payloads:
|
|
||||||
short_channel_id = bh2u(msg['short_channel_id'])
|
|
||||||
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
|
|
||||||
continue
|
|
||||||
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
|
||||||
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
channel_info = ChannelInfo.from_msg(msg)
|
|
||||||
except UnknownEvenFeatureBits:
|
|
||||||
self.logger.info("unknown feature bits")
|
|
||||||
continue
|
|
||||||
channel_info.trusted = trusted
|
|
||||||
new_channels[short_channel_id] = channel_info
|
|
||||||
if not trusted:
|
|
||||||
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
|
||||||
for channel_info in new_channels.values():
|
|
||||||
self.DBSession.add(channel_info)
|
|
||||||
self.DBSession.commit()
|
|
||||||
self._update_counts()
|
|
||||||
self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_last_timestamp(self):
|
|
||||||
return self._get_last_timestamp()
|
|
||||||
|
|
||||||
def _get_last_timestamp(self):
|
|
||||||
from sqlalchemy.sql import func
|
|
||||||
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
|
|
||||||
return r.max_timestamp or 0
|
|
||||||
|
|
||||||
def print_change(self, old_policy, new_policy):
|
|
||||||
# print what changed between policies
|
|
||||||
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
|
|
||||||
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
|
|
||||||
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
|
|
||||||
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
|
|
||||||
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
|
|
||||||
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
|
|
||||||
if old_policy.fee_base_msat != new_policy.fee_base_msat:
|
|
||||||
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
|
|
||||||
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
|
|
||||||
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
|
|
||||||
if old_policy.channel_flags != new_policy.channel_flags:
|
|
||||||
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_info_for_updates(self, payloads):
|
|
||||||
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
|
|
||||||
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
|
|
||||||
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
|
||||||
return channel_infos
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_policies_for_updates(self, payloads):
|
|
||||||
out = {}
|
|
||||||
for payload in payloads:
|
|
||||||
short_channel_id = payload['short_channel_id'].hex()
|
|
||||||
start_node = payload['start_node'].hex()
|
|
||||||
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
|
|
||||||
if policy:
|
|
||||||
out[short_channel_id+start_node] = policy
|
|
||||||
return out
|
|
||||||
|
|
||||||
@profiler
|
|
||||||
def filter_channel_updates(self, payloads, max_age=None):
|
|
||||||
orphaned = [] # no channel announcement for channel update
|
|
||||||
expired = [] # update older than two weeks
|
|
||||||
deprecated = [] # update older than database entry
|
|
||||||
good = {} # good updates
|
|
||||||
to_delete = [] # database entries to delete
|
|
||||||
# filter orphaned and expired first
|
|
||||||
known = []
|
|
||||||
now = int(time.time())
|
|
||||||
channel_infos = self.get_info_for_updates(payloads)
|
|
||||||
for payload in payloads:
|
|
||||||
short_channel_id = payload['short_channel_id']
|
|
||||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
|
||||||
if max_age and now - timestamp > max_age:
|
|
||||||
expired.append(short_channel_id)
|
|
||||||
continue
|
|
||||||
channel_info = channel_infos.get(short_channel_id)
|
|
||||||
if not channel_info:
|
|
||||||
orphaned.append(short_channel_id)
|
|
||||||
continue
|
|
||||||
flags = int.from_bytes(payload['channel_flags'], 'big')
|
|
||||||
direction = flags & FLAG_DIRECTION
|
|
||||||
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
|
|
||||||
payload['start_node'] = bfh(start_node)
|
|
||||||
known.append(payload)
|
|
||||||
# compare updates to existing database entries
|
|
||||||
old_policies = self.get_policies_for_updates(known)
|
|
||||||
for payload in known:
|
|
||||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
|
||||||
start_node = payload['start_node']
|
|
||||||
short_channel_id = payload['short_channel_id']
|
|
||||||
key = (short_channel_id+start_node).hex()
|
|
||||||
old_policy = old_policies.get(key)
|
|
||||||
if old_policy:
|
|
||||||
if timestamp <= old_policy.timestamp:
|
|
||||||
deprecated.append(short_channel_id)
|
|
||||||
else:
|
|
||||||
good[key] = payload
|
|
||||||
to_delete.append(old_policy)
|
|
||||||
else:
|
|
||||||
good[key] = payload
|
|
||||||
good = list(good.values())
|
|
||||||
return orphaned, expired, deprecated, good, to_delete
|
|
||||||
|
|
||||||
def add_channel_update(self, payload):
|
|
||||||
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
|
|
||||||
assert len(good) == 1
|
|
||||||
self.update_policies(good, to_delete)
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def update_policies(self, to_add, to_delete):
|
|
||||||
for policy in to_delete:
|
|
||||||
self.DBSession.delete(policy)
|
|
||||||
self.DBSession.commit()
|
|
||||||
for payload in to_add:
|
|
||||||
policy = Policy.from_msg(payload)
|
|
||||||
self.DBSession.add(policy)
|
|
||||||
self.DBSession.commit()
|
|
||||||
self._update_counts()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def on_node_announcement(self, msg_payloads):
|
|
||||||
if type(msg_payloads) is dict:
|
|
||||||
msg_payloads = [msg_payloads]
|
|
||||||
old_addr = None
|
|
||||||
new_nodes = {}
|
|
||||||
new_addresses = {}
|
|
||||||
for msg_payload in msg_payloads:
|
|
||||||
try:
|
|
||||||
node_info, node_addresses = NodeInfo.from_msg(msg_payload)
|
|
||||||
except UnknownEvenFeatureBits:
|
|
||||||
continue
|
|
||||||
node_id = node_info.node_id
|
|
||||||
# Ignore node if it has no associated channel (DoS protection)
|
|
||||||
# FIXME this is slow
|
|
||||||
expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
|
|
||||||
if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
|
|
||||||
#self.logger.info('ignoring orphan node_announcement')
|
|
||||||
continue
|
|
||||||
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
|
|
||||||
if node and node.timestamp >= node_info.timestamp:
|
|
||||||
continue
|
|
||||||
node = new_nodes.get(node_id)
|
|
||||||
if node and node.timestamp >= node_info.timestamp:
|
|
||||||
continue
|
|
||||||
new_nodes[node_id] = node_info
|
|
||||||
for addr in node_addresses:
|
|
||||||
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
|
|
||||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
|
||||||
for node_info in new_nodes.values():
|
|
||||||
self.DBSession.add(node_info)
|
|
||||||
for new_addr in new_addresses.values():
|
|
||||||
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
|
||||||
if not old_addr:
|
|
||||||
self.DBSession.add(new_addr)
|
|
||||||
self.DBSession.commit()
|
|
||||||
self._update_counts()
|
|
||||||
|
|
||||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
|
||||||
short_channel_id: bytes) -> Optional[bytes]:
|
|
||||||
if not start_node_id or not short_channel_id: return None
|
|
||||||
channel_info = self.get_channel_info(short_channel_id)
|
|
||||||
if channel_info is not None:
|
|
||||||
return self.get_policy_for_node(short_channel_id, start_node_id)
|
|
||||||
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
|
||||||
if not msg:
|
|
||||||
return None
|
|
||||||
return Policy.from_msg(msg) # won't actually be written to DB
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def get_old_policies(self, delta):
|
|
||||||
timestamp = int(time.time()) - delta
|
|
||||||
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
|
||||||
return old_policies.distinct().count()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def prune_old_policies(self, delta):
|
|
||||||
# note: delete queries are order sensitive
|
|
||||||
timestamp = int(time.time()) - delta
|
|
||||||
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
|
||||||
delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
|
|
||||||
delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
|
|
||||||
self.DBSession.execute(delete_old_channels)
|
|
||||||
self.DBSession.execute(delete_old_policies)
|
|
||||||
self.DBSession.commit()
|
|
||||||
self._update_counts()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def get_orphaned_channels(self):
|
|
||||||
subquery = self.DBSession.query(Policy.short_channel_id)
|
|
||||||
orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
|
||||||
return orphaned.count()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def prune_orphaned_channels(self):
|
|
||||||
subquery = self.DBSession.query(Policy.short_channel_id)
|
|
||||||
delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
|
||||||
self.DBSession.execute(delete_orphaned)
|
|
||||||
self.DBSession.commit()
|
|
||||||
self._update_counts()
|
|
||||||
|
|
||||||
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
|
||||||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
|
||||||
return # ignore
|
|
||||||
short_channel_id = msg_payload['short_channel_id']
|
|
||||||
msg_payload['start_node'] = start_node_id
|
|
||||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def remove_channel(self, short_channel_id):
|
|
||||||
r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
|
|
||||||
if not r:
|
|
||||||
return
|
|
||||||
self.DBSession.delete(r)
|
|
||||||
self.DBSession.commit()
|
|
||||||
|
|
||||||
def print_graph(self, full_ids=False):
|
|
||||||
# used for debugging.
|
|
||||||
# FIXME there is a race here - iterables could change size from another thread
|
|
||||||
def other_node_id(node_id, channel_id):
|
|
||||||
channel_info = self.get_channel_info(channel_id)
|
|
||||||
if node_id == channel_info.node1_id:
|
|
||||||
other = channel_info.node2_id
|
|
||||||
else:
|
|
||||||
other = channel_info.node1_id
|
|
||||||
return other if full_ids else other[-4:]
|
|
||||||
|
|
||||||
print_msg('nodes')
|
|
||||||
for node in self.DBSession.query(NodeInfo).all():
|
|
||||||
print_msg(node)
|
|
||||||
|
|
||||||
print_msg('channels')
|
|
||||||
for channel_info in self.DBSession.query(ChannelInfo).all():
|
|
||||||
short_channel_id = channel_info.short_channel_id
|
|
||||||
node1 = channel_info.node1_id
|
|
||||||
node2 = channel_info.node2_id
|
|
||||||
direction1 = self.get_policy_for_node(channel_info, node1) is not None
|
|
||||||
direction2 = self.get_policy_for_node(channel_info, node2) is not None
|
|
||||||
if direction1 and direction2:
|
|
||||||
direction = 'both'
|
|
||||||
elif direction1:
|
|
||||||
direction = 'forward'
|
|
||||||
elif direction2:
|
|
||||||
direction = 'backward'
|
|
||||||
else:
|
|
||||||
direction = 'none'
|
|
||||||
print_msg('{}: {}, {}, {}'
|
|
||||||
.format(bh2u(short_channel_id),
|
|
||||||
bh2u(node1) if full_ids else bh2u(node1[-4:]),
|
|
||||||
bh2u(node2) if full_ids else bh2u(node2[-4:]),
|
|
||||||
direction))
|
|
||||||
|
|
||||||
|
|
||||||
@sql
|
|
||||||
def get_node_addresses(self, node_info):
|
|
||||||
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
|
|
||||||
|
|
||||||
@sql
|
|
||||||
@profiler
|
|
||||||
def load_data(self):
|
|
||||||
r = self.DBSession.query(ChannelInfo).all()
|
|
||||||
self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
|
|
||||||
r = self.DBSession.query(Policy).filter_by().all()
|
|
||||||
self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
|
|
||||||
self._channels_for_node = defaultdict(set)
|
|
||||||
for channel_info in self._channels.values():
|
|
||||||
self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
|
|
||||||
self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
|
|
||||||
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
|
|
||||||
|
|
||||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
|
||||||
return self._policies.get((node_id, short_channel_id))
|
|
||||||
|
|
||||||
def get_channel_info(self, channel_id: bytes):
|
|
||||||
return self._channels.get(channel_id)
|
|
||||||
|
|
||||||
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
|
||||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
|
||||||
return self._channels_for_node.get(node_id) or set()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
|
class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
|
||||||
('short_channel_id', bytes),
|
('short_channel_id', bytes),
|
||||||
|
|
|
@ -297,15 +297,17 @@ class Network(Logger):
|
||||||
self._set_status('disconnected')
|
self._set_status('disconnected')
|
||||||
|
|
||||||
# lightning network
|
# lightning network
|
||||||
from . import lnwatcher
|
|
||||||
from . import lnworker
|
|
||||||
from . import lnrouter
|
|
||||||
if self.config.get('lightning'):
|
if self.config.get('lightning'):
|
||||||
self.channel_db = lnrouter.ChannelDB(self)
|
from . import lnwatcher
|
||||||
|
from . import lnworker
|
||||||
|
from . import lnrouter
|
||||||
|
from . import channel_db
|
||||||
|
self.channel_db = channel_db.ChannelDB(self)
|
||||||
self.path_finder = lnrouter.LNPathFinder(self.channel_db)
|
self.path_finder = lnrouter.LNPathFinder(self.channel_db)
|
||||||
self.lnwatcher = lnwatcher.LNWatcher(self)
|
self.lnwatcher = lnwatcher.LNWatcher(self)
|
||||||
self.lngossip = lnworker.LNGossip(self)
|
self.lngossip = lnworker.LNGossip(self)
|
||||||
else:
|
else:
|
||||||
|
self.channel_db = None
|
||||||
self.lnwatcher = None
|
self.lnwatcher = None
|
||||||
self.lngossip = None
|
self.lngossip = None
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue