optimize channel_db:

- use python objects mirrored by sql database
 - write sql to file asynchronously
 - the sql decorator is awaited in sweepstore, not in channel_db
This commit is contained in:
ThomasV 2019-06-18 13:49:31 +02:00
parent 180f6d34be
commit f2d58d0e3f
11 changed files with 435 additions and 454 deletions

View file

@ -51,6 +51,7 @@ from .crypto import sha256d
from . import ecc
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
NotFoundChanAnnouncementForUpdate)
from .lnverifier import verify_sig_for_channel_update
from .lnmsg import encode_msg
if TYPE_CHECKING:
@ -70,85 +71,83 @@ Base = declarative_base()
FLAG_DISABLE = 1 << 1
FLAG_DIRECTION = 1 << 0
class ChannelInfo(Base):
__tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True)
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
capacity_sat = Column(Integer)
msg_payload_hex = Column(String(1024), nullable=False)
trusted = Column(Boolean, nullable=False)
class ChannelInfo(NamedTuple):
short_channel_id: bytes
node1_id: bytes
node2_id: bytes
capacity_sat: int
msg_payload: bytes
trusted: bool
@staticmethod
def from_msg(payload):
features = int.from_bytes(payload['features'], 'big')
validate_features(features)
channel_id = payload['short_channel_id'].hex()
node_id_1 = payload['node_id_1'].hex()
node_id_2 = payload['node_id_2'].hex()
channel_id = payload['short_channel_id']
node_id_1 = payload['node_id_1']
node_id_2 = payload['node_id_2']
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
msg_payload = encode_msg('channel_announcement', **payload)
capacity_sat = None
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
trusted = False)
@property
def msg_payload(self):
return bytes.fromhex(self.msg_payload_hex)
return ChannelInfo(
short_channel_id = channel_id,
node1_id = node_id_1,
node2_id = node_id_2,
capacity_sat = capacity_sat,
msg_payload = msg_payload,
trusted = False)
class Policy(Base):
__tablename__ = 'policy'
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
cltv_expiry_delta = Column(Integer, nullable=False)
htlc_minimum_msat = Column(Integer, nullable=False)
htlc_maximum_msat = Column(Integer)
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
class Policy(NamedTuple):
key: bytes
cltv_expiry_delta: int
htlc_minimum_msat: int
htlc_maximum_msat: int
fee_base_msat: int
fee_proportional_millionths: int
channel_flags: int
timestamp: int
@staticmethod
def from_msg(payload):
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
channel_flags = int.from_bytes(payload['channel_flags'], "big")
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'].hex()
short_channel_id = payload['short_channel_id'].hex()
return Policy(start_node=start_node,
short_channel_id=short_channel_id,
cltv_expiry_delta=cltv_expiry_delta,
htlc_minimum_msat=htlc_minimum_msat,
fee_base_msat=fee_base_msat,
fee_proportional_millionths=fee_proportional_millionths,
channel_flags=channel_flags,
timestamp=timestamp,
htlc_maximum_msat=htlc_maximum_msat)
return Policy(
key = payload['short_channel_id'] + payload['start_node'],
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"),
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"),
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
channel_flags = int.from_bytes(payload['channel_flags'], "big"),
timestamp = int.from_bytes(payload['timestamp'], "big")
)
def is_disabled(self):
return self.channel_flags & FLAG_DISABLE
class NodeInfo(Base):
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
@property
def short_channel_id(self):
return self.key[0:8]
@property
def start_node(self):
return self.key[8:]
class NodeInfo(NamedTuple):
node_id: bytes
features: int
timestamp: int
alias: str
@staticmethod
def from_msg(payload):
node_id = payload['node_id'].hex()
node_id = payload['node_id']
features = int.from_bytes(payload['features'], "big")
validate_features(features)
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
alias = payload['alias'].rstrip(b'\x00').hex()
alias = payload['alias'].rstrip(b'\x00')
timestamp = int.from_bytes(payload['timestamp'], "big")
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
@ -193,110 +192,136 @@ class NodeInfo(Base):
break
return addresses
class Address(Base):
class Address(NamedTuple):
node_id: bytes
host: str
port: int
last_connected_date: int
class ChannelInfoBase(Base):
__tablename__ = 'channel_info'
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
capacity_sat = Column(Integer)
msg_payload = Column(String(1024), nullable=False)
trusted = Column(Boolean, nullable=False)
def to_nametuple(self):
return ChannelInfo(
short_channel_id=self.short_channel_id,
node1_id=self.node1_id,
node2_id=self.node2_id,
capacity_sat=self.capacity_sat,
msg_payload=self.msg_payload,
trusted=self.trusted
)
class PolicyBase(Base):
__tablename__ = 'policy'
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
cltv_expiry_delta = Column(Integer, nullable=False)
htlc_minimum_msat = Column(Integer, nullable=False)
htlc_maximum_msat = Column(Integer)
fee_base_msat = Column(Integer, nullable=False)
fee_proportional_millionths = Column(Integer, nullable=False)
channel_flags = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
def to_nametuple(self):
return Policy(
key=self.key,
cltv_expiry_delta=self.cltv_expiry_delta,
htlc_minimum_msat=self.htlc_minimum_msat,
htlc_maximum_msat=self.htlc_maximum_msat,
fee_base_msat= self.fee_base_msat,
fee_proportional_millionths = self.fee_proportional_millionths,
channel_flags=self.channel_flags,
timestamp=self.timestamp
)
class NodeInfoBase(Base):
__tablename__ = 'node_info'
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
features = Column(Integer, nullable=False)
timestamp = Column(Integer, nullable=False)
alias = Column(String(64), nullable=False)
class AddressBase(Base):
__tablename__ = 'address'
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
host = Column(String(256), primary_key=True)
port = Column(Integer, primary_key=True)
last_connected_date = Column(Integer(), nullable=True)
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'):
path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base)
super().__init__(network, path, Base, commit_interval=100)
self.num_nodes = 0
self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self)
self.update_counts()
# initialized in load_data
self._channels = {}
self._policies = {}
self._nodes = {}
self._addresses = defaultdict(set)
self._channels_for_node = defaultdict(set)
@sql
def update_counts(self):
self._update_counts()
self.num_channels = len(self._channels)
self.num_policies = len(self._policies)
self.num_nodes = len(self._nodes)
def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count()
self.num_policies = self.DBSession.query(Policy).count()
self.num_nodes = self.DBSession.query(NodeInfo).count()
def get_channel_ids(self):
return set(self._channels.keys())
@sql
def known_ids(self):
known = self.DBSession.query(ChannelInfo.short_channel_id).all()
return set(bfh(r.short_channel_id) for r in known)
@sql
def add_recent_peer(self, peer: LNPeerAddr):
now = int(time.time())
node_id = peer.pubkey.hex()
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
node_id = peer.pubkey
self._addresses[node_id].add((peer.host, peer.port, now))
self.save_address(node_id, peer, now)
@sql
def save_address(self, node_id, peer, now):
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
if addr:
addr.last_connected_date = now
else:
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
self.DBSession.add(addr)
self.DBSession.commit()
@sql
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
unshuffled = self.DBSession \
.query(NodeInfo) \
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
.limit(200) \
.all()
return random.sample(unshuffled, len(unshuffled))
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
unshuffled = set(self._nodes.keys()) - node_ids
return random.sample(unshuffled, min(200, len(unshuffled)))
@sql
def nodes_get(self, node_id):
return self.DBSession \
.query(NodeInfo) \
.filter_by(node_id = node_id.hex()) \
.one_or_none()
@sql
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
r = self._addresses.get(node_id)
if not r:
return None
addr = r[0]
return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
addr = sorted(list(r), key=lambda x: x[2])[0]
host, port, timestamp = addr
return LNPeerAddr(host, port, node_id)
@sql
def get_recent_peers(self):
r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
r = [self.get_last_good_address(x) for x in self._addresses.keys()]
r = r[-self.NUM_MAX_RECENT_PEERS:]
return r
@sql
def missing_channel_announcements(self) -> Set[int]:
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
@sql
def missing_channel_updates(self) -> Set[int]:
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
@sql
def add_verified_channel_info(self, short_id, capacity):
# called from lnchannelverifier
channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
channel_info.trusted = True
channel_info.capacity = capacity
self.DBSession.commit()
@sql
@profiler
def on_channel_announcement(self, msg_payloads, trusted=True):
def add_channel_announcement(self, msg_payloads, trusted=True):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
new_channels = {}
added = 0
for msg in msg_payloads:
short_channel_id = bh2u(msg['short_channel_id'])
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
short_channel_id = msg['short_channel_id']
if short_channel_id in self._channels:
continue
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
@ -306,24 +331,24 @@ class ChannelDB(SqlDB):
except UnknownEvenFeatureBits:
self.logger.info("unknown feature bits")
continue
channel_info.trusted = trusted
new_channels[short_channel_id] = channel_info
#channel_info.trusted = trusted
added += 1
self._channels[short_channel_id] = channel_info
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.save_channel(channel_info)
if not trusted:
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
for channel_info in new_channels.values():
self.DBSession.add(channel_info)
self.DBSession.commit()
self._update_counts()
self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
@sql
def get_last_timestamp(self):
return self._get_last_timestamp()
self.update_counts()
self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
def _get_last_timestamp(self):
from sqlalchemy.sql import func
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
return r.max_timestamp or 0
#def add_verified_channel_info(self, short_id, capacity):
# # called from lnchannelverifier
# channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none()
# channel_info.trusted = True
# channel_info.capacity = capacity
def print_change(self, old_policy, new_policy):
# print what changed between policies
@ -340,89 +365,74 @@ class ChannelDB(SqlDB):
if old_policy.channel_flags != new_policy.channel_flags:
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
@sql
def get_info_for_updates(self, payloads):
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
return channel_infos
@sql
def get_policies_for_updates(self, payloads):
out = {}
for payload in payloads:
short_channel_id = payload['short_channel_id'].hex()
start_node = payload['start_node'].hex()
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
if policy:
out[short_channel_id+start_node] = policy
return out
@profiler
def filter_channel_updates(self, payloads, max_age=None):
def add_channel_updates(self, payloads, max_age=None, verify=True):
orphaned = [] # no channel announcement for channel update
expired = [] # update older than two weeks
deprecated = [] # update older than database entry
good = {} # good updates
good = [] # good updates
to_delete = [] # database entries to delete
# filter orphaned and expired first
known = []
now = int(time.time())
channel_infos = self.get_info_for_updates(payloads)
for payload in payloads:
short_channel_id = payload['short_channel_id']
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(short_channel_id)
continue
channel_info = channel_infos.get(short_channel_id)
channel_info = self._channels.get(short_channel_id)
if not channel_info:
orphaned.append(short_channel_id)
continue
flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
payload['start_node'] = bfh(start_node)
payload['start_node'] = start_node
known.append(payload)
# compare updates to existing database entries
old_policies = self.get_policies_for_updates(known)
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node']
short_channel_id = payload['short_channel_id']
key = (short_channel_id+start_node).hex()
old_policy = old_policies.get(key)
if old_policy:
if timestamp <= old_policy.timestamp:
deprecated.append(short_channel_id)
else:
good[key] = payload
to_delete.append(old_policy)
else:
good[key] = payload
good = list(good.values())
key = (start_node, short_channel_id)
old_policy = self._policies.get(key)
if old_policy and timestamp <= old_policy.timestamp:
deprecated.append(short_channel_id)
continue
good.append(payload)
if verify:
self.verify_channel_update(payload)
policy = Policy.from_msg(payload)
self._policies[key] = policy
self.save_policy(policy)
#
self.update_counts()
return orphaned, expired, deprecated, good, to_delete
def add_channel_update(self, payload):
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
assert len(good) == 1
self.update_policies(good, to_delete)
@sql
@profiler
def update_policies(self, to_add, to_delete):
for policy in to_delete:
self.DBSession.delete(policy)
self.DBSession.commit()
for payload in to_add:
policy = Policy.from_msg(payload)
self.DBSession.add(policy)
self.DBSession.commit()
self._update_counts()
def save_policy(self, policy):
self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
@sql
@profiler
def on_node_announcement(self, msg_payloads):
def delete_policy(self, short_channel_id, node_id):
self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
@sql
def save_channel(self, channel_info):
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
def verify_channel_update(self, payload):
short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error')
def add_node_announcement(self, msg_payloads):
if type(msg_payloads) is dict:
msg_payloads = [msg_payloads]
old_addr = None
@ -435,29 +445,35 @@ class ChannelDB(SqlDB):
continue
node_id = node_info.node_id
# Ignore node if it has no associated channel (DoS protection)
# FIXME this is slow
expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
if node_id not in self._channels_for_node:
#self.logger.info('ignoring orphan node_announcement')
continue
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
node = self._nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
node = new_nodes.get(node_id)
if node and node.timestamp >= node_info.timestamp:
continue
new_nodes[node_id] = node_info
# save
self._nodes[node_id] = node_info
self.save_node(node_info)
for addr in node_addresses:
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
self._addresses[node_id].add((addr.host, addr.port, 0))
self.save_node_addresses(node_id, node_addresses)
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
for node_info in new_nodes.values():
self.DBSession.add(node_info)
for new_addr in new_addresses.values():
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
self.update_counts()
@sql
def save_node_addresses(self, node_if, node_addresses):
for new_addr in node_addresses:
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
if not old_addr:
self.DBSession.add(new_addr)
self.DBSession.commit()
self._update_counts()
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
@sql
def save_node(self, node_info):
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
def get_routing_policy_for_channel(self, start_node_id: bytes,
short_channel_id: bytes) -> Optional[bytes]:
@ -470,41 +486,28 @@ class ChannelDB(SqlDB):
return None
return Policy.from_msg(msg) # won't actually be written to DB
@sql
@profiler
def get_old_policies(self, delta):
timestamp = int(time.time()) - delta
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
return old_policies.distinct().count()
now = int(time.time())
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
@sql
@profiler
def prune_old_policies(self, delta):
# note: delete queries are order sensitive
timestamp = int(time.time()) - delta
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
self.DBSession.execute(delete_old_channels)
self.DBSession.execute(delete_old_policies)
self.DBSession.commit()
self._update_counts()
l = self.get_old_policies(delta)
for k in l:
self._policies.pop(k)
if l:
self.logger.info(f'Deleting {len(l)} old policies')
@sql
@profiler
def get_orphaned_channels(self):
subquery = self.DBSession.query(Policy.short_channel_id)
orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
return orphaned.count()
ids = set(x[1] for x in self._policies.keys())
return list(x for x in self._channels.keys() if x not in ids)
@sql
@profiler
def prune_orphaned_channels(self):
subquery = self.DBSession.query(Policy.short_channel_id)
delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
self.DBSession.execute(delete_orphaned)
self.DBSession.commit()
self._update_counts()
l = self.get_orphaned_channels()
for k in l:
self._channels.pop(k)
self.update_counts()
if l:
self.logger.info(f'Deleting {len(l)} orphaned channels')
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
if not verify_sig_for_channel_update(msg_payload, start_node_id):
@ -513,67 +516,27 @@ class ChannelDB(SqlDB):
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
@sql
def remove_channel(self, short_channel_id):
r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
if not r:
return
self.DBSession.delete(r)
self.DBSession.commit()
self._channels.pop(short_channel_id, None)
def print_graph(self, full_ids=False):
# used for debugging.
# FIXME there is a race here - iterables could change size from another thread
def other_node_id(node_id, channel_id):
channel_info = self.get_channel_info(channel_id)
if node_id == channel_info.node1_id:
other = channel_info.node2_id
else:
other = channel_info.node1_id
return other if full_ids else other[-4:]
print_msg('nodes')
for node in self.DBSession.query(NodeInfo).all():
print_msg(node)
print_msg('channels')
for channel_info in self.DBSession.query(ChannelInfo).all():
short_channel_id = channel_info.short_channel_id
node1 = channel_info.node1_id
node2 = channel_info.node2_id
direction1 = self.get_policy_for_node(channel_info, node1) is not None
direction2 = self.get_policy_for_node(channel_info, node2) is not None
if direction1 and direction2:
direction = 'both'
elif direction1:
direction = 'forward'
elif direction2:
direction = 'backward'
else:
direction = 'none'
print_msg('{}: {}, {}, {}'
.format(bh2u(short_channel_id),
bh2u(node1) if full_ids else bh2u(node1[-4:]),
bh2u(node2) if full_ids else bh2u(node2[-4:]),
direction))
@sql
def get_node_addresses(self, node_info):
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
def get_node_addresses(self, node_id):
return self._addresses.get(node_id)
@sql
@profiler
def load_data(self):
r = self.DBSession.query(ChannelInfo).all()
self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
r = self.DBSession.query(Policy).filter_by().all()
self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
self._channels_for_node = defaultdict(set)
for x in self.DBSession.query(AddressBase).all():
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
for x in self.DBSession.query(ChannelInfoBase).all():
self._channels[x.short_channel_id] = x.to_nametuple()
for x in self.DBSession.query(PolicyBase).filter_by().all():
p = x.to_nametuple()
self._policies[(p.start_node, p.short_channel_id)] = p
for channel_info in self._channels.values():
self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
self.update_counts()
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
return self._policies.get((node_id, short_channel_id))
@ -584,6 +547,3 @@ class ChannelDB(SqlDB):
def get_channels_for_node(self, node_id) -> Set[bytes]:
"""Returns the set of channels that have node_id as one of the endpoints."""
return self._channels_for_node.get(node_id) or set()

View file

@ -56,10 +56,11 @@ class WatcherList(MyTreeView):
return
self.model().clear()
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
sweepstore = self.parent.lnwatcher.sweepstore
for outpoint in sweepstore.list_sweep_tx():
n = sweepstore.get_num_tx(outpoint)
status = self.parent.lnwatcher.get_channel_status(outpoint)
lnwatcher = self.parent.lnwatcher
l = lnwatcher.list_sweep_tx()
for outpoint in l:
n = lnwatcher.get_num_tx(outpoint)
status = lnwatcher.get_channel_status(outpoint)
items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
self.model().insertRow(self.model().rowCount(), items)

View file

@ -258,14 +258,21 @@ class LnAddr(object):
def get_min_final_cltv_expiry(self) -> int:
return self._min_final_cltv_expiry
def get_description(self):
def get_tag(self, tag):
description = ''
for k,v in self.tags:
if k == 'd':
if k == tag:
description = v
break
return description
def get_description(self):
return self.get_tag('d')
def get_expiry(self):
return int(self.get_tag('x') or '3600')
def lndecode(a, verbose=False, expected_hrp=None):
if expected_hrp is None:

View file

@ -163,8 +163,6 @@ class Channel(Logger):
self._is_funding_txo_spent = None # "don't know"
self._state = None
self.set_state('DISCONNECTED')
self.lnwatcher = None
self.local_commitment = None
self.remote_commitment = None
self.sweep_info = None
@ -453,13 +451,10 @@ class Channel(Logger):
return secret, point
def process_new_revocation_secret(self, per_commitment_secret: bytes):
if not self.lnwatcher:
return
outpoint = self.funding_outpoint.to_str()
ctx = self.remote_commitment_to_be_revoked # FIXME can't we just reconstruct it?
sweeptxs = create_sweeptxs_for_their_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address)
for tx in sweeptxs:
self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
return sweeptxs
def receive_revocation(self, revocation: RevokeAndAck):
self.logger.info("receive_revocation")
@ -477,9 +472,10 @@ class Channel(Logger):
# be robust to exceptions raised in lnwatcher
try:
self.process_new_revocation_secret(revocation.per_commitment_secret)
sweeptxs = self.process_new_revocation_secret(revocation.per_commitment_secret)
except Exception as e:
self.logger.info("Could not process revocation secret: {}".format(repr(e)))
sweeptxs = []
##### start applying fee/htlc changes
@ -505,6 +501,8 @@ class Channel(Logger):
self.set_remote_commitment()
self.remote_commitment_to_be_revoked = prev_remote_commitment
# return sweep transactions for watchtower
return sweeptxs
def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None):
"""

View file

@ -42,7 +42,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY)
from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg
from .lnverifier import verify_sig_for_channel_update
from .interface import GracefulDisconnect
if TYPE_CHECKING:
@ -242,22 +241,20 @@ class Peer(Logger):
# channel announcements
for chan_anns_chunk in chunks(chan_anns, 300):
self.verify_channel_announcements(chan_anns_chunk)
self.channel_db.on_channel_announcement(chan_anns_chunk)
self.channel_db.add_channel_announcement(chan_anns_chunk)
# node announcements
for node_anns_chunk in chunks(node_anns, 100):
self.verify_node_announcements(node_anns_chunk)
self.channel_db.on_node_announcement(node_anns_chunk)
self.channel_db.add_node_announcement(node_anns_chunk)
# channel updates
for chan_upds_chunk in chunks(chan_upds, 1000):
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds_chunk,
max_age=self.network.lngossip.max_age)
orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates(
chan_upds_chunk, max_age=self.network.lngossip.max_age)
if orphaned:
self.logger.info(f'adding {len(orphaned)} unknown channel ids')
self.network.lngossip.add_new_ids(orphaned)
await self.network.lngossip.add_new_ids(orphaned)
if good:
self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds_chunk)}')
self.verify_channel_updates(good)
self.channel_db.update_policies(good, to_delete)
# refresh gui
if chan_anns or node_anns or chan_upds:
self.network.lngossip.refresh_gui()
@ -279,14 +276,6 @@ class Peer(Logger):
if not ecc.verify_signature(pubkey, signature, h):
raise Exception('signature failed')
def verify_channel_updates(self, chan_upds):
for payload in chan_upds:
short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error')
async def query_gossip(self):
try:
await asyncio.wait_for(self.initialized.wait(), 10)
@ -298,7 +287,7 @@ class Peer(Logger):
except asyncio.TimeoutError as e:
raise GracefulDisconnect("query_channel_range timed out") from e
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
self.lnworker.add_new_ids(ids)
await self.lnworker.add_new_ids(ids)
while True:
todo = self.lnworker.get_ids_to_query()
if not todo:
@ -658,7 +647,7 @@ class Peer(Logger):
)
chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
self.lnworker.save_channel(chan)
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
await self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
self.lnworker.on_channels_updated()
while True:
try:
@ -862,8 +851,6 @@ class Peer(Logger):
bitcoin_key_2=bitcoin_keys[1]
)
print("SENT CHANNEL ANNOUNCEMENT")
def mark_open(self, chan: Channel):
assert chan.short_channel_id is not None
if chan.get_state() == "OPEN":
@ -872,6 +859,10 @@ class Peer(Logger):
assert chan.config[LOCAL].funding_locked_received
chan.set_state("OPEN")
self.network.trigger_callback('channel', chan)
asyncio.ensure_future(self.add_own_channel(chan))
self.logger.info("CHANNEL OPENING COMPLETED")
async def add_own_channel(self, chan):
# add channel to database
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
sorted_node_ids = list(sorted(self.node_ids))
@ -887,7 +878,7 @@ class Peer(Logger):
# that the remote sends, even if the channel was not announced
# (from BOLT-07: "MAY create a channel_update to communicate the channel
# parameters to the final node, even though the channel has not yet been announced")
self.channel_db.on_channel_announcement(
self.channel_db.add_channel_announcement(
{
"short_channel_id": chan.short_channel_id,
"node_id_1": node_ids[0],
@ -922,8 +913,6 @@ class Peer(Logger):
if pending_channel_update:
self.channel_db.add_channel_update(pending_channel_update)
self.logger.info("CHANNEL OPENING COMPLETED")
def send_announcement_signatures(self, chan: Channel):
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
@ -962,6 +951,21 @@ class Peer(Logger):
def on_update_fail_htlc(self, payload):
channel_id = payload["channel_id"]
htlc_id = int.from_bytes(payload["id"], "big")
chan = self.channels[channel_id]
chan.receive_fail_htlc(htlc_id)
local_ctn = chan.get_current_ctn(LOCAL)
asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id))
asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn))
@log_exceptions
async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn):
chan = self.channels[channel_id]
await self.await_local(chan, local_ctn)
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
@log_exceptions
async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id):
chan = self.channels[channel_id]
key = (channel_id, htlc_id)
try:
route = self.attempted_route[key]
@ -969,29 +973,12 @@ class Peer(Logger):
# the remote might try to fail an htlc after we restarted...
# attempted_route is not persisted, so we will get here then
self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
else:
try:
self._handle_error_code_from_failed_htlc(payload["reason"], route, channel_id, htlc_id)
except Exception:
# exceptions are suppressed as failing to handle an error code
# should not block us from removing the htlc
traceback.print_exc(file=sys.stderr)
# process update_fail_htlc on channel
chan = self.channels[channel_id]
chan.receive_fail_htlc(htlc_id)
local_ctn = chan.get_current_ctn(LOCAL)
asyncio.ensure_future(self._on_update_fail_htlc(chan, htlc_id, local_ctn))
@log_exceptions
async def _on_update_fail_htlc(self, chan, htlc_id, local_ctn):
await self.await_local(chan, local_ctn)
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
def _handle_error_code_from_failed_htlc(self, error_reason, route: List['RouteEdge'], channel_id, htlc_id):
chan = self.channels[channel_id]
failure_msg, sender_idx = decode_onion_error(error_reason,
[x.node_id for x in route],
chan.onion_keys[htlc_id])
return
error_reason = payload["reason"]
failure_msg, sender_idx = decode_onion_error(
error_reason,
[x.node_id for x in route],
chan.onion_keys[htlc_id])
code, data = failure_msg.code, failure_msg.data
self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}")
self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}")
@ -1009,11 +996,9 @@ class Peer(Logger):
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
message_type, payload = decode_msg(channel_update)
payload['raw'] = channel_update
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates([payload])
blacklist = False
if good:
self.verify_channel_updates(good)
self.channel_db.update_policies(good, to_delete)
self.logger.info("applied channel update on our db")
elif orphaned:
# maybe it is a private channel (and data in invoice was outdated)
@ -1276,11 +1261,17 @@ class Peer(Logger):
self.logger.info("on_revoke_and_ack")
channel_id = payload["channel_id"]
chan = self.channels[channel_id]
chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
sweeptxs = chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
self._remote_changed_events[chan.channel_id].set()
self._remote_changed_events[chan.channel_id].clear()
self.lnworker.save_channel(chan)
self.maybe_send_commitment(chan)
asyncio.ensure_future(self._on_revoke_and_ack(chan, sweeptxs))
async def _on_revoke_and_ack(self, chan, sweeptxs):
outpoint = chan.funding_outpoint.to_str()
for tx in sweeptxs:
await self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
def on_update_fee(self, payload):
channel_id = payload["channel_id"]

View file

@ -37,7 +37,7 @@ import binascii
import base64
from . import constants
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
from .util import bh2u, profiler, get_headers_dir, is_ip_address, list_enabled_bits, print_msg, chunks
from .logging import Logger
from .storage import JsonDB
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
@ -169,7 +169,6 @@ class LNPathFinder(Logger):
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
"""
self.channel_db.load_data()
assert type(nodeA) is bytes
assert type(nodeB) is bytes
assert type(invoice_amount_msat) is int
@ -195,11 +194,12 @@ class LNPathFinder(Logger):
else: # payment incoming, on our channel. (funny business, cycle weirdness)
assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
pass # TODO?
edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
start_node=edge_startnode,
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA))
edge_cost, fee_for_edge_msat = self._edge_cost(
edge_channel_id,
start_node=edge_startnode,
end_node=edge_endnode,
payment_amt_msat=amount_msat,
ignore_costs=(edge_startnode == nodeA))
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
distance_from_start[edge_startnode] = alt_dist_to_neighbour
@ -219,9 +219,10 @@ class LNPathFinder(Logger):
continue
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
assert type(edge_channel_id) is bytes
if edge_channel_id in self.blacklist: continue
if edge_channel_id in self.blacklist:
continue
channel_info = self.channel_db.get_channel_info(edge_channel_id)
edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
inspect_edge()
else:
return None # no path found

View file

@ -70,11 +70,11 @@ class SweepStore(SqlDB):
@sql
def get_tx_by_index(self, funding_outpoint, index):
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
return r.prevout, bh2u(r.tx)
return str(r.prevout), bh2u(r.tx)
@sql
def list_sweep_tx(self):
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
@sql
def add_sweep_tx(self, funding_outpoint, prevout, tx):
@ -84,7 +84,7 @@ class SweepStore(SqlDB):
@sql
def get_num_tx(self, funding_outpoint):
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
@sql
def remove_sweep_tx(self, funding_outpoint):
@ -111,11 +111,11 @@ class SweepStore(SqlDB):
@sql
def get_address(self, outpoint):
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
return r.address if r else None
return str(r.address) if r else None
@sql
def list_channel_info(self):
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
class LNWatcher(AddressSynchronizer):
@ -150,14 +150,21 @@ class LNWatcher(AddressSynchronizer):
self.watchtower_queue = asyncio.Queue()
def get_num_tx(self, outpoint):
return self.sweepstore.get_num_tx(outpoint)
async def f():
return await self.sweepstore.get_num_tx(outpoint)
return self.network.run_from_another_thread(f())
def list_sweep_tx(self):
async def f():
return await self.sweepstore.list_sweep_tx()
return self.network.run_from_another_thread(f())
@ignore_exceptions
@log_exceptions
async def watchtower_task(self):
self.logger.info('watchtower task started')
# initial check
for address, outpoint in self.sweepstore.list_channel_info():
for address, outpoint in await self.sweepstore.list_channel_info():
await self.watchtower_queue.put(outpoint)
while True:
outpoint = await self.watchtower_queue.get()
@ -165,30 +172,30 @@ class LNWatcher(AddressSynchronizer):
continue
# synchronize with remote
try:
local_n = self.sweepstore.get_num_tx(outpoint)
local_n = await self.sweepstore.get_num_tx(outpoint)
n = self.watchtower.get_num_tx(outpoint)
if n == 0:
address = self.sweepstore.get_address(outpoint)
address = await self.sweepstore.get_address(outpoint)
self.watchtower.add_channel(outpoint, address)
self.logger.info("sending %d transactions to watchtower"%(local_n - n))
for index in range(n, local_n):
prevout, tx = self.sweepstore.get_tx_by_index(outpoint, index)
prevout, tx = await self.sweepstore.get_tx_by_index(outpoint, index)
self.watchtower.add_sweep_tx(outpoint, prevout, tx)
except ConnectionRefusedError:
self.logger.info('could not reach watchtower, will retry in 5s')
await asyncio.sleep(5)
await self.watchtower_queue.put(outpoint)
def add_channel(self, outpoint, address):
async def add_channel(self, outpoint, address):
self.add_address(address)
with self.lock:
if not self.sweepstore.has_channel(outpoint):
self.sweepstore.add_channel(outpoint, address)
if not await self.sweepstore.has_channel(outpoint):
await self.sweepstore.add_channel(outpoint, address)
def unwatch_channel(self, address, funding_outpoint):
async def unwatch_channel(self, address, funding_outpoint):
self.logger.info(f'unwatching {funding_outpoint}')
self.sweepstore.remove_sweep_tx(funding_outpoint)
self.sweepstore.remove_channel(funding_outpoint)
await self.sweepstore.remove_sweep_tx(funding_outpoint)
await self.sweepstore.remove_channel(funding_outpoint)
if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set()
@ -202,7 +209,7 @@ class LNWatcher(AddressSynchronizer):
return
if not self.synchronizer.is_up_to_date():
return
for address, outpoint in self.sweepstore.list_channel_info():
for address, outpoint in await self.sweepstore.list_channel_info():
await self.check_onchain_situation(address, outpoint)
async def check_onchain_situation(self, address, funding_outpoint):
@ -223,7 +230,7 @@ class LNWatcher(AddressSynchronizer):
closing_height, closing_tx) # FIXME sooo many args..
await self.do_breach_remedy(funding_outpoint, spenders)
if not keep_watching:
self.unwatch_channel(address, funding_outpoint)
await self.unwatch_channel(address, funding_outpoint)
else:
#self.logger.info(f'we will keep_watching {funding_outpoint}')
pass
@ -260,7 +267,7 @@ class LNWatcher(AddressSynchronizer):
for prevout, spender in spenders.items():
if spender is not None:
continue
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
for tx in sweep_txns:
if not await self.broadcast_or_log(funding_outpoint, tx):
self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}')
@ -279,8 +286,8 @@ class LNWatcher(AddressSynchronizer):
await self.tx_progress[funding_outpoint].tx_queue.put(tx)
return txid
def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
async def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
await self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
if self.watchtower:
self.watchtower_queue.put_nowait(funding_outpoint)

View file

@ -108,12 +108,14 @@ class LNWorker(Logger):
@log_exceptions
async def main_loop(self):
# fixme: only lngossip should do that
await self.channel_db.load_data()
while True:
await asyncio.sleep(1)
now = time.time()
if len(self.peers) >= NUM_PEERS_TARGET:
continue
peers = self._get_next_peers_to_try()
peers = await self._get_next_peers_to_try()
for peer in peers:
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL < now:
@ -130,7 +132,8 @@ class LNWorker(Logger):
peer = Peer(self, node_id, transport)
await self.network.main_taskgroup.spawn(peer.main_loop())
self.peers[node_id] = peer
self.network.lngossip.refresh_gui()
#if self.network.lngossip:
# self.network.lngossip.refresh_gui()
return peer
def start_network(self, network: 'Network'):
@ -148,7 +151,7 @@ class LNWorker(Logger):
self._add_peer(host, int(port), bfh(pubkey)),
self.network.asyncio_loop)
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
now = time.time()
recent_peers = self.channel_db.get_recent_peers()
# maintenance for last tried times
@ -158,19 +161,22 @@ class LNWorker(Logger):
del self._last_tried_peer[peer]
# first try from recent peers
for peer in recent_peers:
if peer.pubkey in self.peers: continue
if peer in self._last_tried_peer: continue
if peer.pubkey in self.peers:
continue
if peer in self._last_tried_peer:
continue
return [peer]
# try random peer from graph
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
if unconnected_nodes:
for node in unconnected_nodes:
addrs = self.channel_db.get_node_addresses(node)
for node_id in unconnected_nodes:
addrs = self.channel_db.get_node_addresses(node_id)
if not addrs:
continue
host, port = self.choose_preferred_address(addrs)
peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
if peer in self._last_tried_peer: continue
host, port, timestamp = self.choose_preferred_address(addrs)
peer = LNPeerAddr(host, port, node_id)
if peer in self._last_tried_peer:
continue
#self.logger.info('taking random ln peer from our channel db')
return [peer]
@ -223,15 +229,13 @@ class LNWorker(Logger):
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
assert len(addr_list) >= 1
# choose first one that is an IP
for addr_in_db in addr_list:
host = addr_in_db.host
port = addr_in_db.port
for host, port, timestamp in addr_list:
if is_ip_address(host):
return host, port
return host, port, timestamp
# otherwise choose one at random
# TODO maybe filter out onion if not on tor?
choice = random.choice(addr_list)
return choice.host, choice.port
return choice
class LNGossip(LNWorker):
@ -260,26 +264,19 @@ class LNGossip(LNWorker):
self.network.trigger_callback('ln_status', num_peers, num_nodes, known, unknown)
async def maintain_db(self):
n = self.channel_db.get_orphaned_channels()
if n:
self.logger.info(f'Deleting {n} orphaned channels')
self.channel_db.prune_orphaned_channels()
self.refresh_gui()
self.channel_db.prune_orphaned_channels()
while True:
n = self.channel_db.get_old_policies(self.max_age)
if n:
self.logger.info(f'Deleting {n} old channels')
self.channel_db.prune_old_policies(self.max_age)
self.refresh_gui()
self.channel_db.prune_old_policies(self.max_age)
self.refresh_gui()
await asyncio.sleep(5)
def add_new_ids(self, ids):
known = self.channel_db.known_ids()
async def add_new_ids(self, ids):
known = self.channel_db.get_channel_ids()
new = set(ids) - set(known)
self.unknown_ids.update(new)
def get_ids_to_query(self):
N = 500
N = 100
l = list(self.unknown_ids)
self.unknown_ids = set(l[N:])
return l[0:N]
@ -324,9 +321,10 @@ class LNWallet(LNWorker):
self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee']) # thread safe
self.network.register_callback(self.on_channel_open, ['channel_open'])
self.network.register_callback(self.on_channel_closed, ['channel_closed'])
for chan_id, chan in self.channels.items():
self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
chan.lnwatcher = network.lnwatcher
self.network.lnwatcher.add_address(chan.get_funding_address())
super().start_network(network)
for coro in [
self.maybe_listen(),
@ -494,7 +492,7 @@ class LNWallet(LNWorker):
chan = self.channel_by_txo(funding_outpoint)
if not chan:
return
self.logger.debug(f'on_channel_open {funding_outpoint}')
#self.logger.debug(f'on_channel_open {funding_outpoint}')
self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, None, None, None
self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
chan.set_funding_txo_spentness(False)
@ -606,7 +604,8 @@ class LNWallet(LNWorker):
self.logger.info('REBROADCASTING CLOSING TX')
await self.force_close_channel(chan.channel_id)
async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password):
peer = await self.add_peer(connect_str)
# peer might just have been connected to
await asyncio.wait_for(peer.initialized.wait(), 5)
chan = await peer.channel_establishment_flow(
@ -615,24 +614,22 @@ class LNWallet(LNWorker):
push_msat=push_sat * 1000,
temp_channel_id=os.urandom(32))
self.save_channel(chan)
self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
self.network.lnwatcher.add_address(chan.get_funding_address())
await self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
self.on_channels_updated()
return chan
def on_channels_updated(self):
self.network.trigger_callback('channels')
def add_peer(self, connect_str, timeout=20):
async def add_peer(self, connect_str, timeout=20):
node_id, rest = extract_nodeid(connect_str)
peer = self.peers.get(node_id)
if not peer:
if rest is not None:
host, port = split_host_port(rest)
else:
node_info = self.network.channel_db.nodes_get(node_id)
if not node_info:
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
addrs = self.channel_db.get_node_addresses(node_info)
addrs = self.channel_db.get_node_addresses(node_id)
if len(addrs) == 0:
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
host, port = self.choose_preferred_address(addrs)
@ -640,18 +637,12 @@ class LNWallet(LNWorker):
socket.getaddrinfo(host, int(port))
except socket.gaierror:
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
peer_future = asyncio.run_coroutine_threadsafe(
self._add_peer(host, port, node_id),
self.network.asyncio_loop)
try:
peer = peer_future.result(timeout)
except concurrent.futures.TimeoutError:
raise Exception(_("add_peer timed out"))
# add peer
peer = await self._add_peer(host, port, node_id)
return peer
def open_channel(self, connect_str, local_amt_sat, push_amt_sat, password=None, timeout=20):
peer = self.add_peer(connect_str, timeout)
coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, password)
coro = self._open_channel_coroutine(connect_str, local_amt_sat, push_amt_sat, password)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
try:
chan = fut.result(timeout=timeout)
@ -664,6 +655,9 @@ class LNWallet(LNWorker):
Can be called from other threads
Raises timeout exception if htlc is not fulfilled
"""
addr = self._check_invoice(invoice, amount_sat)
self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
fut = asyncio.run_coroutine_threadsafe(
self._pay(invoice, attempts, amount_sat),
self.network.asyncio_loop)
@ -680,8 +674,6 @@ class LNWallet(LNWorker):
async def _pay(self, invoice, attempts=1, amount_sat=None):
addr = self._check_invoice(invoice, amount_sat)
self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
for i in range(attempts):
route = await self._create_route_from_invoice(decoded_invoice=addr)
if not self.get_channel_by_short_id(route[0].short_channel_id):
@ -691,7 +683,7 @@ class LNWallet(LNWorker):
return True
return False
async def _pay_to_route(self, route, addr, pay_req):
async def _pay_to_route(self, route, addr, invoice):
short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
if not chan:
@ -713,6 +705,9 @@ class LNWallet(LNWorker):
raise InvoiceError("{}\n{}".format(
_("Invoice wants us to risk locking funds for unreasonably long."),
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
#now = int(time.time())
#if addr.date + addr.get_expiry() > now:
# raise InvoiceError(_('Invoice expired'))
return addr
async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
@ -730,11 +725,14 @@ class LNWallet(LNWorker):
with self.lock:
channels = list(self.channels.values())
for private_route in r_tags:
if len(private_route) == 0: continue
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue
if len(private_route) == 0:
continue
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
continue
border_node_pubkey = private_route[0][0]
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
if not path: continue
if not path:
continue
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
# we need to shift the node pubkey by one towards the destination:
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
@ -770,10 +768,18 @@ class LNWallet(LNWorker):
return route
def add_invoice(self, amount_sat, message):
coro = self._add_invoice_coro(amount_sat, message)
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
try:
return fut.result(timeout=5)
except concurrent.futures.TimeoutError:
raise Exception(_("add_invoice timed out"))
async def _add_invoice_coro(self, amount_sat, message):
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
routing_hints = self._calc_routing_hints_for_invoice(amount_sat)
routing_hints = await self._calc_routing_hints_for_invoice(amount_sat)
if not routing_hints:
self.logger.info("Warning. No routing hints added to invoice. "
"Other clients will likely not be able to send to us.")
@ -847,19 +853,20 @@ class LNWallet(LNWorker):
})
return out
def _calc_routing_hints_for_invoice(self, amount_sat):
async def _calc_routing_hints_for_invoice(self, amount_sat):
"""calculate routing hints (BOLT-11 'r' field)"""
self.channel_db.load_data()
routing_hints = []
with self.lock:
channels = list(self.channels.values())
# note: currently we add *all* our channels; but this might be a privacy leak?
for chan in channels:
# check channel is open
if chan.get_state() != "OPEN": continue
if chan.get_state() != "OPEN":
continue
# check channel has sufficient balance
# FIXME because of on-chain fees of ctx, this check is insufficient
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
continue
chan_id = chan.short_channel_id
assert type(chan_id) is bytes, chan_id
channel_info = self.channel_db.get_channel_info(chan_id)
@ -949,14 +956,10 @@ class LNWallet(LNWorker):
await self._add_peer(peer.host, peer.port, peer.pubkey)
return
# try random address for node_id
node_info = self.channel_db.nodes_get(chan.node_id)
if not node_info:
return
addresses = self.channel_db.get_node_addresses(node_info)
addresses = self.channel_db.get_node_addresses(chan.node_id)
if not addresses:
return
adr_obj = random.choice(addresses)
host, port = adr_obj.host, adr_obj.port
host, port, t = random.choice(list(addresses))
peer = LNPeerAddr(host, port, chan.node_id)
last_tried = self._last_tried_peer.get(peer, 0)
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:

View file

@ -2,6 +2,7 @@ import os
import concurrent
import queue
import threading
import asyncio
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
@ -18,28 +19,32 @@ def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
f = asyncio.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return f
return wrapper
class SqlDB(Logger):
def __init__(self, network, path, base):
def __init__(self, network, path, base, commit_interval=None):
Logger.__init__(self)
self.base = base
self.network = network
self.path = path
self.commit_interval = commit_interval
self.db_requests = queue.Queue()
self.sql_thread = threading.Thread(target=self.run_sql)
self.sql_thread.start()
def run_sql(self):
#return
self.logger.info("SQL thread started")
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
self.base.metadata.create_all(engine)
self.DBSession = DBSession()
i = 0
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
@ -50,7 +55,14 @@ class SqlDB(Logger):
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
if not future.cancelled():
future.set_result(result)
# note: in sweepstore session.commit() is called inside
# the sql-decorated methods, so commiting to disk is awaited
if self.commit_interval:
i = (i + 1) % self.commit_interval
if i == 0:
self.DBSession.commit()
# write
self.DBSession.commit()
self.logger.info("SQL thread terminated")

View file

@ -16,7 +16,8 @@ from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure, LnLocalFeatures
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnrouter import LNPathFinder
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet
from electrum.lnmsg import encode_msg, decode_msg
from electrum.logging import console_stderr_handler

View file

@ -59,33 +59,33 @@ class Test_LNRouter(TestCaseForTestnet):
cdb = fake_network.channel_db
path_finder = lnrouter.LNPathFinder(cdb)
self.assertEqual(cdb.num_channels, 0)
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
'short_channel_id': bfh('0000000000000001'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
'len': b'\x00\x00', 'features': b''}, trusted=True)
self.assertEqual(cdb.num_channels, 1)
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'short_channel_id': bfh('0000000000000002'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
'len': b'\x00\x00', 'features': b''}, trusted=True)
cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
'short_channel_id': bfh('0000000000000003'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
'len': b'\x00\x00', 'features': b''}, trusted=True)
cdb.on_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
cdb.add_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
'bitcoin_key_1': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
'short_channel_id': bfh('0000000000000004'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
'len': b'\x00\x00', 'features': b''}, trusted=True)
cdb.on_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
cdb.add_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'bitcoin_key_1': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
'short_channel_id': bfh('0000000000000005'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
'len': b'\x00\x00', 'features': b''}, trusted=True)
cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
'short_channel_id': bfh('0000000000000006'),
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),