# Source: LBRY-Vault/electrum/channel_db.py
# Snapshot dated 2019-12-10 01:14:38 +01:00 (598 lines, 24 KiB, Python)
# -*- coding: utf-8 -*-
#
# Electrum - lightweight Bitcoin client
# Copyright (C) 2018 The Electrum developers
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import time
import random
import os
from collections import defaultdict
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set
import binascii
import base64
import asyncio
from .sql_db import SqlDB, sql
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
from .logging import Logger
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, format_short_channel_id, ShortChannelID
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
if TYPE_CHECKING:
from .network import Network
class UnknownEvenFeatureBits(Exception):
    """Raised when a gossip message sets an even feature bit we do not know."""
def validate_features(features: int):
    """Reject a feature vector that sets an even bit we do not know.

    Unknown *odd* bits are tolerated; only unknown even bits cause an error.

    :param features: feature bits as an integer
    :raises UnknownEvenFeatureBits: if any enabled even bit is missing from
        LN_GLOBAL_FEATURES_KNOWN_SET
    """
    has_unknown_even_bit = any(
        bit % 2 == 0 and (1 << bit) not in LN_GLOBAL_FEATURES_KNOWN_SET
        for bit in list_enabled_bits(features)
    )
    if has_unknown_even_bit:
        raise UnknownEvenFeatureBits()
# Bits of the `channel_flags` field of a channel_update.
FLAG_DIRECTION = 0b01  # clear: update starts at node1_id; set: at node2_id (see add_channel_updates)
FLAG_DISABLE = 0b10    # tested by Policy.is_disabled
class ChannelInfo(NamedTuple):
    """A public channel, as learned from a channel_announcement payload."""
    short_channel_id: ShortChannelID
    node1_id: bytes
    node2_id: bytes
    capacity_sat: Optional[int]

    @staticmethod
    def from_msg(payload):
        """Build a ChannelInfo from a decoded channel_announcement dict.

        ``capacity_sat`` is always left as None here.

        :raises UnknownEvenFeatureBits: via validate_features
        :raises AssertionError: if node_id_1 sorts after node_id_2
        """
        validate_features(int.from_bytes(payload['features'], 'big'))
        raw_scid = payload['short_channel_id']
        first_node, second_node = payload['node_id_1'], payload['node_id_2']
        # the two endpoints must be given in ascending order
        assert first_node <= second_node
        return ChannelInfo(
            short_channel_id=ShortChannelID.normalize(raw_scid),
            node1_id=first_node,
            node2_id=second_node,
            capacity_sat=None,
        )
class Policy(NamedTuple):
    """Routing policy for one direction of a channel, from a channel_update.

    ``key`` is the 8-byte short channel id followed by the id of the node at
    which this direction starts. The properties below split it apart again.
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: Optional[int]
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    message_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload):
        """Build a Policy from a decoded channel_update dict.

        The dict must also carry a ``'start_node'`` entry.
        ``htlc_maximum_msat`` is None when the payload has no such field.
        """
        def big_endian(field):
            return int.from_bytes(payload[field], "big")

        # keyword arguments are evaluated in this order, as before
        return Policy(
            key=payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta=big_endian('cltv_expiry_delta'),
            htlc_minimum_msat=big_endian('htlc_minimum_msat'),
            htlc_maximum_msat=(big_endian('htlc_maximum_msat')
                               if 'htlc_maximum_msat' in payload else None),
            fee_base_msat=big_endian('fee_base_msat'),
            fee_proportional_millionths=big_endian('fee_proportional_millionths'),
            message_flags=big_endian('message_flags'),
            channel_flags=big_endian('channel_flags'),
            timestamp=big_endian('timestamp'),
        )

    def is_disabled(self):
        """Return the (truthy when set) FLAG_DISABLE bit of channel_flags."""
        return FLAG_DISABLE & self.channel_flags

    @property
    def short_channel_id(self) -> ShortChannelID:
        """The short channel id: the first 8 bytes of ``key``."""
        scid_bytes = self.key[:8]
        return ShortChannelID.normalize(scid_bytes)

    @property
    def start_node(self):
        """The start node's id: everything in ``key`` after the first 8 bytes."""
        node_bytes = self.key[8:]
        return node_bytes
class NodeInfo(NamedTuple):
    """A Lightning node, as learned from a node_announcement payload."""
    node_id: bytes
    features: int
    timestamp: int
    alias: str  # NOTE(review): from_msg stores the stripped raw bytes here, not a str

    @staticmethod
    def from_msg(payload):
        """Parse a decoded node_announcement dict.

        :returns: ``(NodeInfo, [Address, ...])``. Each Address has
            ``last_connected_date=None``.
        :raises UnknownEvenFeatureBits: via validate_features
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        # trailing NUL bytes of the alias field are stripped
        alias = payload['alias'].rstrip(b'\x00')
        timestamp = int.from_bytes(payload['timestamp'], "big")
        return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
            Address(host=host, port=port, node_id=node_id, last_connected_date=None)
            for host, port in addresses]

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Decode the ``addresses`` field of a node_announcement.

        Each descriptor is a 1-byte type followed by the address bytes and a
        2-byte big-endian port.

        :returns: list of ``(host, port)`` tuples.

        Dropped entries:
        - type 0 bytes are skipped (padding);
        - IPv4/IPv6 entries are dropped if is_ip_address rejects them;
        - entries of any type with port 0 are dropped;
        - a truncated descriptor is dropped.

        Parsing stops at an unknown type, because its length cannot be known.
        """
        # length of each known descriptor after its type byte: address + 2-byte port
        descriptor_len = {1: 4 + 2, 2: 16 + 2, 3: 10 + 2, 4: 35 + 2}
        buf = addresses_field

        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:  # padding
                continue
            if atype not in descriptor_len:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
            if len(buf) < descriptor_len[atype]:
                # Truncated descriptor. Parsing it would produce a bogus host/port
                # (e.g. base32 '=' padding, port 0), and nothing can follow it.
                break
            if atype == 1:  # IPv4
                host = '.'.join('%d' % x for x in read(4))
                valid = is_ip_address(host)
            elif atype == 2:  # IPv6
                host = b':'.join(binascii.hexlify(read(2)) for _ in range(8)).decode('ascii')
                valid = is_ip_address(host)
            else:  # 3: onion v2 (10 bytes), 4: onion v3 (35 bytes)
                raw = read(10 if atype == 3 else 35)
                host = (base64.b32encode(raw) + b'.onion').decode('ascii').lower()
                valid = True
            port = int.from_bytes(read(2), 'big')
            # port 0 is unusable for connecting, whatever the address type
            if valid and port != 0:
                addresses.append((host, port))
        return addresses
class Address(NamedTuple):
    """A host/port at which a node may be reachable."""
    node_id: bytes
    host: str
    port: int
    # NOTE(review): presumably a unix timestamp; NodeInfo.from_msg sets it to None
    last_connected_date: Optional[int]
class CategorizedChannelUpdates(NamedTuple):
    """Channel_update payloads sorted into buckets by ChannelDB.add_channel_updates."""
    orphaned: List    # no channel announcement known for the update
    expired: List     # older than the caller's max_age (originally documented as two weeks)
    deprecated: List  # not newer than the policy already stored
    good: List        # accepted and stored
    to_delete: List   # database entries to delete
# TODO It would make more sense to store the raw gossip messages in the db.
# That is pretty much a pre-requisite of actively participating in gossip.

# DDL for the gossip tables, run by ChannelDB.create_database.
# Free-text columns are declared TEXT, not STRING. Assuming the SqlDB backend
# is SQLite, a declared type of "STRING" gets NUMERIC affinity. A host that
# looks like a number (e.g. "12345") would then be stored as an INTEGER.
# Because of IF NOT EXISTS, this only affects newly created databases.
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id VARCHAR(64),
node1_id VARCHAR(66),
node2_id VARCHAR(66),
capacity_sat INTEGER,
PRIMARY KEY(short_channel_id)
)"""

create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key VARCHAR(66),
cltv_expiry_delta INTEGER NOT NULL,
htlc_minimum_msat INTEGER NOT NULL,
htlc_maximum_msat INTEGER,
fee_base_msat INTEGER NOT NULL,
fee_proportional_millionths INTEGER NOT NULL,
channel_flags INTEGER NOT NULL,
message_flags INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
PRIMARY KEY(key)
)"""

create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id VARCHAR(66),
host TEXT(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id VARCHAR(66),
features INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
alias TEXT(64),
PRIMARY KEY(node_id)
)"""
class ChannelDB(SqlDB):
    """The Lightning gossip graph: channels, per-direction policies, nodes, addresses.

    The graph lives in in-memory dicts that ``load_data`` fills from the
    database. Persistent changes are written through the ``@sql``-decorated
    ``save_*`` / ``delete_*`` methods. Updates for private channels are kept
    in memory only.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        path = os.path.join(get_headers_dir(network.config), 'channel_db')
        super().__init__(network, path, commit_interval=100)
        self.num_nodes = 0
        self.num_channels = 0
        # also set by update_counts(); initialised here so the attribute always exists
        self.num_policies = 0
        # (start_node_id, short_channel_id) -> channel_update payload; never written to the DB
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)
        # initialized in load_data
        self._channels = {}  # type: Dict[bytes, ChannelInfo]
        # (start_node_id, short_channel_id) -> Policy
        self._policies = {}  # type: Dict[Tuple[bytes, bytes], Policy]
        self._nodes = {}  # type: Dict[bytes, NodeInfo]
        # node_id -> {(host, port, timestamp), ...}
        self._addresses = defaultdict(set)  # type: Dict[bytes, Set[Tuple[str, int, int]]]
        # node_id -> short_channel_ids of the channels the node is an endpoint of
        self._channels_for_node = defaultdict(set)  # type: Dict[bytes, Set[bytes]]
        self.data_loaded = asyncio.Event()
        self.network = network  # only for callback

    def update_counts(self):
        """Refresh the cached graph sizes and fire the 'channel_db' callback."""
        self.num_nodes = len(self._nodes)
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)

    def get_channel_ids(self):
        """Return a new set of all known short channel ids."""
        return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        """Record *peer*'s address with the current time, in memory and in the DB."""
        now = int(time.time())
        node_id = peer.pubkey
        self._addresses[node_id].add((peer.host, peer.port, now))
        self.save_node_address(node_id, peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        """Return up to 200 known node ids that are not in *node_ids*, in random order."""
        # random.sample() needs a sequence; passing a set raises TypeError on Python 3.11+
        candidates = list(self._nodes.keys() - set(node_ids))
        return random.sample(candidates, min(200, len(candidates)))

    def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
        """Return the address of *node_id* with the newest recorded timestamp.

        Returns None if no address is known or LNPeerAddr rejects it.
        """
        r = self._addresses.get(node_id)
        if not r:
            return None
        # the newest timestamp is the last good one (the old code picked the oldest)
        host, port, timestamp = max(r, key=lambda x: x[2])
        try:
            return LNPeerAddr(host, port, node_id)
        except ValueError:
            return None

    def get_recent_peers(self):
        """Return last-good addresses for the last NUM_MAX_RECENT_PEERS node ids in
        ``_addresses``, newest insertion first.

        Entries may be None (see get_last_good_address).
        """
        assert self.data_loaded.is_set(), "channelDB load_data did not finish yet!"
        # FIXME this does not reliably return "recent" peers...
        # Also, the list() cast over the whole dict (thousands of elements),
        # is really inefficient.
        r = [self.get_last_good_address(node_id)
             for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
        return list(reversed(r))

    # note: currently channel announcements are trusted by default (trusted=True);
    # they are not verified. Verifying them would make the gossip sync
    # even slower; especially as servers will start throttling us.
    # It would probably put significant strain on servers if all clients
    # verified the complete gossip.
    def add_channel_announcement(self, msg_payloads, *, trusted=True):
        """Add one channel_announcement dict, or an iterable of them.

        Skips channels that are already known, announcements for another chain,
        and announcements with unknown even feature bits. With ``trusted=False``,
        the announcement is passed to the channel verifier instead of being
        added directly.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = ShortChannelID(msg['short_channel_id'])
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                ChannelInfo.from_msg(msg)  # validation only; the result is rebuilt below
            except UnknownEvenFeatureBits:
                self.logger.info("unknown feature bits")
                continue
            if trusted:
                added += 1
                self.add_verified_channel_info(msg)
            else:
                added += self.ca_verifier.add_new_channel_info(short_channel_id, msg)
        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))

    def add_verified_channel_info(self, msg: dict, *, capacity_sat: int = None) -> None:
        """Store the channel from *msg* in memory and in the DB.

        Silently returns if *msg* has unknown even feature bits.
        """
        try:
            channel_info = ChannelInfo.from_msg(msg)
        except UnknownEvenFeatureBits:
            return
        channel_info = channel_info._replace(capacity_sat=capacity_sat)
        self._channels[channel_info.short_channel_id] = channel_info
        self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
        self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self.save_channel(channel_info)

    def print_change(self, old_policy: Policy, new_policy: Policy):
        """Log every policy field that differs between *old_policy* and *new_policy*."""
        for field in ('cltv_expiry_delta', 'htlc_minimum_msat', 'htlc_maximum_msat',
                      'fee_base_msat', 'fee_proportional_millionths',
                      'channel_flags', 'message_flags'):
            old_value = getattr(old_policy, field)
            new_value = getattr(new_policy, field)
            if old_value != new_value:
                self.logger.info(f'{field}: {old_value} -> {new_value}')

    def add_channel_updates(self, payloads, max_age=None, verify=True) -> CategorizedChannelUpdates:
        """Sort channel_update payloads into buckets and store the newer ones.

        Side effect: sets ``payload['start_node']`` on every payload whose
        channel is known.

        :param max_age: if set, updates older than this many seconds are 'expired'
        :param verify: check each stored update's chain hash and signature.
            verify_channel_update raises on failure, which aborts the batch.
        """
        orphaned = []
        expired = []
        deprecated = []
        good = []
        to_delete = []
        # filter orphaned and expired first
        known = []
        now = int(time.time())
        for payload in payloads:
            short_channel_id = ShortChannelID(payload['short_channel_id'])
            timestamp = int.from_bytes(payload['timestamp'], "big")
            if max_age and now - timestamp > max_age:
                expired.append(payload)
                continue
            channel_info = self._channels.get(short_channel_id)
            if not channel_info:
                orphaned.append(payload)
                continue
            # the direction bit selects which endpoint the update originates from
            flags = int.from_bytes(payload['channel_flags'], 'big')
            direction = flags & FLAG_DIRECTION
            start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
            payload['start_node'] = start_node
            known.append(payload)
        # compare updates to existing database entries
        for payload in known:
            timestamp = int.from_bytes(payload['timestamp'], "big")
            start_node = payload['start_node']
            short_channel_id = ShortChannelID(payload['short_channel_id'])
            key = (start_node, short_channel_id)
            old_policy = self._policies.get(key)
            if old_policy and timestamp <= old_policy.timestamp:
                deprecated.append(payload)
                continue
            good.append(payload)
            if verify:
                self.verify_channel_update(payload)
            policy = Policy.from_msg(payload)
            self._policies[key] = policy
            self.save_policy(policy)
        self.update_counts()
        return CategorizedChannelUpdates(
            orphaned=orphaned,
            expired=expired,
            deprecated=deprecated,
            good=good,
            to_delete=to_delete,
        )

    def add_channel_update(self, payload):
        """Add a single channel_update without signature verification."""
        # called from add_own_channel
        # the update may be categorized as deprecated because of caching
        self.add_channel_updates([payload], verify=False)

    def create_database(self):
        """Create the gossip tables if they do not exist yet."""
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def save_policy(self, policy):
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, message_flags, timestamp) VALUES (?,?,?,?,?,?,?,?,?)""", list(policy))

    @sql
    def delete_policy(self, node_id, short_channel_id):
        # same layout as Policy.key: short_channel_id followed by the start node id
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def save_channel(self, channel_info):
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))

    @sql
    def delete_channel(self, short_channel_id):
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def save_node(self, node_info):
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))

    @sql
    def save_node_address(self, node_id, peer, now):
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))

    @sql
    def save_node_addresses(self, node_id, node_addresses):
        """Insert each address that is not already stored, with timestamp 0."""
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
            r = c.fetchall()
            if r == []:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))

    def verify_channel_update(self, payload):
        """Raise Exception if *payload* is for another chain or its signature is bad."""
        short_channel_id = ShortChannelID(payload['short_channel_id'])
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise Exception('wrong chain hash')
        if not verify_sig_for_channel_update(payload, payload['start_node']):
            raise Exception(f'failed verifying channel update for {short_channel_id}')

    def add_node_announcement(self, msg_payloads):
        """Add one node_announcement dict, or an iterable of them.

        An announcement is skipped if any of the following holds:
        - it has unknown even feature bits;
        - the node has no known channel;
        - it is not newer than the stored NodeInfo.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        new_nodes = {}  # node_id -> NodeInfo accepted from this batch (used for logging)
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except UnknownEvenFeatureBits:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if node_id not in self._channels_for_node:
                continue
            # self._nodes is updated inside this loop, so this check also rejects
            # stale duplicates within the same batch
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            new_nodes[node_id] = node_info
            self._nodes[node_id] = node_info
            self.save_node(node_info)
            for addr in node_addresses:
                self._addresses[node_id].add((addr.host, addr.port, 0))
            self.save_node_addresses(node_id, node_addresses)
        self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
        self.update_counts()

    def get_routing_policy_for_channel(self, start_node_id: bytes,
                                       short_channel_id: bytes) -> Optional[Policy]:
        """Return the policy for the given direction of a public or private channel, or None."""
        if not start_node_id or not short_channel_id:
            return None
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:
            return self.get_policy_for_node(short_channel_id, start_node_id)
        msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
        if not msg:
            return None
        return Policy.from_msg(msg)  # won't actually be written to DB

    def get_old_policies(self, delta):
        """Return the keys of policies whose timestamp is at least *delta* seconds old."""
        cutoff = int(time.time()) - delta
        return [k for k, v in list(self._policies.items()) if v.timestamp <= cutoff]

    def prune_old_policies(self, delta):
        """Delete policies older than *delta* seconds from memory and the DB."""
        old_keys = self.get_old_policies(delta)
        if old_keys:
            for k in old_keys:
                self._policies.pop(k)
                self.delete_policy(*k)
            self.update_counts()
            self.logger.info(f'Deleting {len(old_keys)} old policies')

    def get_orphaned_channels(self):
        """Return ids of known channels that have no policy in either direction."""
        with_policy = {short_channel_id for (_node_id, short_channel_id) in self._policies}
        return [x for x in self._channels if x not in with_policy]

    def prune_orphaned_channels(self):
        """Remove every channel returned by get_orphaned_channels."""
        orphans = self.get_orphaned_channels()
        if orphans:
            for short_channel_id in orphans:
                self.remove_channel(short_channel_id)
            self.update_counts()
            self.logger.info(f'Deleting {len(orphans)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
        """Keep a validly signed update for a private channel, in memory only."""
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return  # ignore
        short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
        msg_payload['start_node'] = start_node_id
        self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

    def remove_channel(self, short_channel_id: ShortChannelID):
        """Forget a channel in memory and delete it from the DB."""
        channel_info = self._channels.pop(short_channel_id, None)
        if channel_info:
            for node_id in (channel_info.node1_id, channel_info.node2_id):
                # .get() rather than [] so the defaultdict gains no entry; this also
                # copes with node1_id == node2_id (the second pass finds nothing)
                chans = self._channels_for_node.get(node_id)
                if chans is None:
                    continue
                chans.discard(channel_info.short_channel_id)
                if not chans:
                    # drop empty sets so the node no longer passes the
                    # "has a channel" check in add_node_announcement
                    del self._channels_for_node[node_id]
        # delete from database
        self.delete_channel(short_channel_id)

    def get_node_addresses(self, node_id):
        """Return the set of (host, port, timestamp) for *node_id*, or None."""
        return self._addresses.get(node_id)

    @sql
    @profiler
    def load_data(self):
        """Populate the in-memory graph from the DB, then set ``data_loaded``."""
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
        c.execute("""SELECT * FROM channel_info""")
        for x in c:
            x = (ShortChannelID.normalize(x[0]), *x[1:])
            ci = ChannelInfo(*x)
            self._channels[ci.short_channel_id] = ci
        c.execute("""SELECT * FROM node_info""")
        for x in c:
            ni = NodeInfo(*x)
            self._nodes[ni.node_id] = ni
        c.execute("""SELECT * FROM policy""")
        for x in c:
            p = Policy(*x)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
        self.update_counts()
        self.count_incomplete_channels()
        self.data_loaded.set()

    def count_incomplete_channels(self) -> int:
        """Log and return how many channels lack a policy for at least one direction."""
        out = set()
        for short_channel_id, ci in self._channels.items():
            p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
            p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
            # a channel is incomplete if either direction's policy is missing
            if p1 is None or p2 is None:
                out.add(short_channel_id)
        self.logger.info(f'semi-orphaned: {len(out)}')
        return len(out)

    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
        """Return the policy of the direction that starts at *node_id*, or None."""
        return self._policies.get((node_id, short_channel_id))

    def get_channel_info(self, channel_id: bytes) -> Optional[ChannelInfo]:
        """Return the ChannelInfo for *channel_id*, or None if unknown."""
        return self._channels.get(channel_id)

    def get_channels_for_node(self, node_id) -> Set[bytes]:
        """Returns the set of channels that have node_id as one of the endpoints."""
        return self._channels_for_node.get(node_id) or set()