mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-03 12:30:07 +00:00
optimize channel_db:
- use python objects mirrored by sql database - write sql to file asynchronously - the sql decorator is awaited in sweepstore, not in channel_db
This commit is contained in:
parent
180f6d34be
commit
f2d58d0e3f
11 changed files with 435 additions and 454 deletions
|
@ -51,6 +51,7 @@ from .crypto import sha256d
|
|||
from . import ecc
|
||||
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
|
||||
NotFoundChanAnnouncementForUpdate)
|
||||
from .lnverifier import verify_sig_for_channel_update
|
||||
from .lnmsg import encode_msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -70,85 +71,83 @@ Base = declarative_base()
|
|||
FLAG_DISABLE = 1 << 1
|
||||
FLAG_DIRECTION = 1 << 0
|
||||
|
||||
class ChannelInfo(Base):
|
||||
__tablename__ = 'channel_info'
|
||||
short_channel_id = Column(String(64), primary_key=True)
|
||||
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
capacity_sat = Column(Integer)
|
||||
msg_payload_hex = Column(String(1024), nullable=False)
|
||||
trusted = Column(Boolean, nullable=False)
|
||||
class ChannelInfo(NamedTuple):
|
||||
short_channel_id: bytes
|
||||
node1_id: bytes
|
||||
node2_id: bytes
|
||||
capacity_sat: int
|
||||
msg_payload: bytes
|
||||
trusted: bool
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
features = int.from_bytes(payload['features'], 'big')
|
||||
validate_features(features)
|
||||
channel_id = payload['short_channel_id'].hex()
|
||||
node_id_1 = payload['node_id_1'].hex()
|
||||
node_id_2 = payload['node_id_2'].hex()
|
||||
channel_id = payload['short_channel_id']
|
||||
node_id_1 = payload['node_id_1']
|
||||
node_id_2 = payload['node_id_2']
|
||||
assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
|
||||
msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
|
||||
msg_payload = encode_msg('channel_announcement', **payload)
|
||||
capacity_sat = None
|
||||
return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
|
||||
node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
|
||||
trusted = False)
|
||||
|
||||
@property
|
||||
def msg_payload(self):
|
||||
return bytes.fromhex(self.msg_payload_hex)
|
||||
return ChannelInfo(
|
||||
short_channel_id = channel_id,
|
||||
node1_id = node_id_1,
|
||||
node2_id = node_id_2,
|
||||
capacity_sat = capacity_sat,
|
||||
msg_payload = msg_payload,
|
||||
trusted = False)
|
||||
|
||||
|
||||
class Policy(Base):
|
||||
__tablename__ = 'policy'
|
||||
start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||
short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
|
||||
cltv_expiry_delta = Column(Integer, nullable=False)
|
||||
htlc_minimum_msat = Column(Integer, nullable=False)
|
||||
htlc_maximum_msat = Column(Integer)
|
||||
fee_base_msat = Column(Integer, nullable=False)
|
||||
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||
channel_flags = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
|
||||
class Policy(NamedTuple):
|
||||
key: bytes
|
||||
cltv_expiry_delta: int
|
||||
htlc_minimum_msat: int
|
||||
htlc_maximum_msat: int
|
||||
fee_base_msat: int
|
||||
fee_proportional_millionths: int
|
||||
channel_flags: int
|
||||
timestamp: int
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
|
||||
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
|
||||
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
|
||||
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
|
||||
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
|
||||
channel_flags = int.from_bytes(payload['channel_flags'], "big")
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
start_node = payload['start_node'].hex()
|
||||
short_channel_id = payload['short_channel_id'].hex()
|
||||
|
||||
return Policy(start_node=start_node,
|
||||
short_channel_id=short_channel_id,
|
||||
cltv_expiry_delta=cltv_expiry_delta,
|
||||
htlc_minimum_msat=htlc_minimum_msat,
|
||||
fee_base_msat=fee_base_msat,
|
||||
fee_proportional_millionths=fee_proportional_millionths,
|
||||
channel_flags=channel_flags,
|
||||
timestamp=timestamp,
|
||||
htlc_maximum_msat=htlc_maximum_msat)
|
||||
return Policy(
|
||||
key = payload['short_channel_id'] + payload['start_node'],
|
||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
|
||||
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"),
|
||||
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
|
||||
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"),
|
||||
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
|
||||
channel_flags = int.from_bytes(payload['channel_flags'], "big"),
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
)
|
||||
|
||||
def is_disabled(self):
|
||||
return self.channel_flags & FLAG_DISABLE
|
||||
|
||||
class NodeInfo(Base):
|
||||
__tablename__ = 'node_info'
|
||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
features = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
alias = Column(String(64), nullable=False)
|
||||
@property
|
||||
def short_channel_id(self):
|
||||
return self.key[0:8]
|
||||
|
||||
@property
|
||||
def start_node(self):
|
||||
return self.key[8:]
|
||||
|
||||
|
||||
|
||||
class NodeInfo(NamedTuple):
|
||||
node_id: bytes
|
||||
features: int
|
||||
timestamp: int
|
||||
alias: str
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload):
|
||||
node_id = payload['node_id'].hex()
|
||||
node_id = payload['node_id']
|
||||
features = int.from_bytes(payload['features'], "big")
|
||||
validate_features(features)
|
||||
addresses = NodeInfo.parse_addresses_field(payload['addresses'])
|
||||
alias = payload['alias'].rstrip(b'\x00').hex()
|
||||
alias = payload['alias'].rstrip(b'\x00')
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
|
||||
Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
|
||||
|
@ -193,110 +192,136 @@ class NodeInfo(Base):
|
|||
break
|
||||
return addresses
|
||||
|
||||
class Address(Base):
|
||||
|
||||
class Address(NamedTuple):
|
||||
node_id: bytes
|
||||
host: str
|
||||
port: int
|
||||
last_connected_date: int
|
||||
|
||||
|
||||
class ChannelInfoBase(Base):
|
||||
__tablename__ = 'channel_info'
|
||||
short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
|
||||
capacity_sat = Column(Integer)
|
||||
msg_payload = Column(String(1024), nullable=False)
|
||||
trusted = Column(Boolean, nullable=False)
|
||||
|
||||
def to_nametuple(self):
|
||||
return ChannelInfo(
|
||||
short_channel_id=self.short_channel_id,
|
||||
node1_id=self.node1_id,
|
||||
node2_id=self.node2_id,
|
||||
capacity_sat=self.capacity_sat,
|
||||
msg_payload=self.msg_payload,
|
||||
trusted=self.trusted
|
||||
)
|
||||
|
||||
class PolicyBase(Base):
|
||||
__tablename__ = 'policy'
|
||||
key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
cltv_expiry_delta = Column(Integer, nullable=False)
|
||||
htlc_minimum_msat = Column(Integer, nullable=False)
|
||||
htlc_maximum_msat = Column(Integer)
|
||||
fee_base_msat = Column(Integer, nullable=False)
|
||||
fee_proportional_millionths = Column(Integer, nullable=False)
|
||||
channel_flags = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
|
||||
def to_nametuple(self):
|
||||
return Policy(
|
||||
key=self.key,
|
||||
cltv_expiry_delta=self.cltv_expiry_delta,
|
||||
htlc_minimum_msat=self.htlc_minimum_msat,
|
||||
htlc_maximum_msat=self.htlc_maximum_msat,
|
||||
fee_base_msat= self.fee_base_msat,
|
||||
fee_proportional_millionths = self.fee_proportional_millionths,
|
||||
channel_flags=self.channel_flags,
|
||||
timestamp=self.timestamp
|
||||
)
|
||||
|
||||
class NodeInfoBase(Base):
|
||||
__tablename__ = 'node_info'
|
||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
features = Column(Integer, nullable=False)
|
||||
timestamp = Column(Integer, nullable=False)
|
||||
alias = Column(String(64), nullable=False)
|
||||
|
||||
class AddressBase(Base):
|
||||
__tablename__ = 'address'
|
||||
node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
|
||||
node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
|
||||
host = Column(String(256), primary_key=True)
|
||||
port = Column(Integer, primary_key=True)
|
||||
last_connected_date = Column(Integer(), nullable=True)
|
||||
|
||||
|
||||
|
||||
class ChannelDB(SqlDB):
|
||||
|
||||
NUM_MAX_RECENT_PEERS = 20
|
||||
|
||||
def __init__(self, network: 'Network'):
|
||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||
super().__init__(network, path, Base)
|
||||
super().__init__(network, path, Base, commit_interval=100)
|
||||
self.num_nodes = 0
|
||||
self.num_channels = 0
|
||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
self.update_counts()
|
||||
# initialized in load_data
|
||||
self._channels = {}
|
||||
self._policies = {}
|
||||
self._nodes = {}
|
||||
self._addresses = defaultdict(set)
|
||||
self._channels_for_node = defaultdict(set)
|
||||
|
||||
@sql
|
||||
def update_counts(self):
|
||||
self._update_counts()
|
||||
self.num_channels = len(self._channels)
|
||||
self.num_policies = len(self._policies)
|
||||
self.num_nodes = len(self._nodes)
|
||||
|
||||
def _update_counts(self):
|
||||
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
||||
self.num_policies = self.DBSession.query(Policy).count()
|
||||
self.num_nodes = self.DBSession.query(NodeInfo).count()
|
||||
def get_channel_ids(self):
|
||||
return set(self._channels.keys())
|
||||
|
||||
@sql
|
||||
def known_ids(self):
|
||||
known = self.DBSession.query(ChannelInfo.short_channel_id).all()
|
||||
return set(bfh(r.short_channel_id) for r in known)
|
||||
|
||||
@sql
|
||||
def add_recent_peer(self, peer: LNPeerAddr):
|
||||
now = int(time.time())
|
||||
node_id = peer.pubkey.hex()
|
||||
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
||||
node_id = peer.pubkey
|
||||
self._addresses[node_id].add((peer.host, peer.port, now))
|
||||
self.save_address(node_id, peer, now)
|
||||
|
||||
@sql
|
||||
def save_address(self, node_id, peer, now):
|
||||
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
|
||||
if addr:
|
||||
addr.last_connected_date = now
|
||||
else:
|
||||
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
||||
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
|
||||
self.DBSession.add(addr)
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
|
||||
unshuffled = self.DBSession \
|
||||
.query(NodeInfo) \
|
||||
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
|
||||
.limit(200) \
|
||||
.all()
|
||||
return random.sample(unshuffled, len(unshuffled))
|
||||
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
|
||||
unshuffled = set(self._nodes.keys()) - node_ids
|
||||
return random.sample(unshuffled, min(200, len(unshuffled)))
|
||||
|
||||
@sql
|
||||
def nodes_get(self, node_id):
|
||||
return self.DBSession \
|
||||
.query(NodeInfo) \
|
||||
.filter_by(node_id = node_id.hex()) \
|
||||
.one_or_none()
|
||||
|
||||
@sql
|
||||
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
|
||||
r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
|
||||
r = self._addresses.get(node_id)
|
||||
if not r:
|
||||
return None
|
||||
addr = r[0]
|
||||
return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
|
||||
addr = sorted(list(r), key=lambda x: x[2])[0]
|
||||
host, port, timestamp = addr
|
||||
return LNPeerAddr(host, port, node_id)
|
||||
|
||||
@sql
|
||||
def get_recent_peers(self):
|
||||
r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
|
||||
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
|
||||
r = [self.get_last_good_address(x) for x in self._addresses.keys()]
|
||||
r = r[-self.NUM_MAX_RECENT_PEERS:]
|
||||
return r
|
||||
|
||||
@sql
|
||||
def missing_channel_announcements(self) -> Set[int]:
|
||||
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
|
||||
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||
|
||||
@sql
|
||||
def missing_channel_updates(self) -> Set[int]:
|
||||
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
|
||||
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
||||
|
||||
@sql
|
||||
def add_verified_channel_info(self, short_id, capacity):
|
||||
# called from lnchannelverifier
|
||||
channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
|
||||
channel_info.trusted = True
|
||||
channel_info.capacity = capacity
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def on_channel_announcement(self, msg_payloads, trusted=True):
|
||||
def add_channel_announcement(self, msg_payloads, trusted=True):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
new_channels = {}
|
||||
added = 0
|
||||
for msg in msg_payloads:
|
||||
short_channel_id = bh2u(msg['short_channel_id'])
|
||||
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
|
||||
short_channel_id = msg['short_channel_id']
|
||||
if short_channel_id in self._channels:
|
||||
continue
|
||||
if constants.net.rev_genesis_bytes() != msg['chain_hash']:
|
||||
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
|
||||
|
@ -306,24 +331,24 @@ class ChannelDB(SqlDB):
|
|||
except UnknownEvenFeatureBits:
|
||||
self.logger.info("unknown feature bits")
|
||||
continue
|
||||
channel_info.trusted = trusted
|
||||
new_channels[short_channel_id] = channel_info
|
||||
#channel_info.trusted = trusted
|
||||
added += 1
|
||||
self._channels[short_channel_id] = channel_info
|
||||
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
|
||||
self.save_channel(channel_info)
|
||||
if not trusted:
|
||||
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
|
||||
for channel_info in new_channels.values():
|
||||
self.DBSession.add(channel_info)
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
||||
|
||||
@sql
|
||||
def get_last_timestamp(self):
|
||||
return self._get_last_timestamp()
|
||||
self.update_counts()
|
||||
self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
|
||||
|
||||
def _get_last_timestamp(self):
|
||||
from sqlalchemy.sql import func
|
||||
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
|
||||
return r.max_timestamp or 0
|
||||
|
||||
#def add_verified_channel_info(self, short_id, capacity):
|
||||
# # called from lnchannelverifier
|
||||
# channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none()
|
||||
# channel_info.trusted = True
|
||||
# channel_info.capacity = capacity
|
||||
|
||||
def print_change(self, old_policy, new_policy):
|
||||
# print what changed between policies
|
||||
|
@ -340,89 +365,74 @@ class ChannelDB(SqlDB):
|
|||
if old_policy.channel_flags != new_policy.channel_flags:
|
||||
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
|
||||
|
||||
@sql
|
||||
def get_info_for_updates(self, payloads):
|
||||
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
|
||||
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
|
||||
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
||||
return channel_infos
|
||||
|
||||
@sql
|
||||
def get_policies_for_updates(self, payloads):
|
||||
out = {}
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id'].hex()
|
||||
start_node = payload['start_node'].hex()
|
||||
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
|
||||
if policy:
|
||||
out[short_channel_id+start_node] = policy
|
||||
return out
|
||||
|
||||
@profiler
|
||||
def filter_channel_updates(self, payloads, max_age=None):
|
||||
def add_channel_updates(self, payloads, max_age=None, verify=True):
|
||||
orphaned = [] # no channel announcement for channel update
|
||||
expired = [] # update older than two weeks
|
||||
deprecated = [] # update older than database entry
|
||||
good = {} # good updates
|
||||
good = [] # good updates
|
||||
to_delete = [] # database entries to delete
|
||||
# filter orphaned and expired first
|
||||
known = []
|
||||
now = int(time.time())
|
||||
channel_infos = self.get_info_for_updates(payloads)
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id']
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
if max_age and now - timestamp > max_age:
|
||||
expired.append(short_channel_id)
|
||||
continue
|
||||
channel_info = channel_infos.get(short_channel_id)
|
||||
channel_info = self._channels.get(short_channel_id)
|
||||
if not channel_info:
|
||||
orphaned.append(short_channel_id)
|
||||
continue
|
||||
flags = int.from_bytes(payload['channel_flags'], 'big')
|
||||
direction = flags & FLAG_DIRECTION
|
||||
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
|
||||
payload['start_node'] = bfh(start_node)
|
||||
payload['start_node'] = start_node
|
||||
known.append(payload)
|
||||
# compare updates to existing database entries
|
||||
old_policies = self.get_policies_for_updates(known)
|
||||
for payload in known:
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
start_node = payload['start_node']
|
||||
short_channel_id = payload['short_channel_id']
|
||||
key = (short_channel_id+start_node).hex()
|
||||
old_policy = old_policies.get(key)
|
||||
if old_policy:
|
||||
if timestamp <= old_policy.timestamp:
|
||||
deprecated.append(short_channel_id)
|
||||
else:
|
||||
good[key] = payload
|
||||
to_delete.append(old_policy)
|
||||
else:
|
||||
good[key] = payload
|
||||
good = list(good.values())
|
||||
key = (start_node, short_channel_id)
|
||||
old_policy = self._policies.get(key)
|
||||
if old_policy and timestamp <= old_policy.timestamp:
|
||||
deprecated.append(short_channel_id)
|
||||
continue
|
||||
good.append(payload)
|
||||
if verify:
|
||||
self.verify_channel_update(payload)
|
||||
policy = Policy.from_msg(payload)
|
||||
self._policies[key] = policy
|
||||
self.save_policy(policy)
|
||||
#
|
||||
self.update_counts()
|
||||
return orphaned, expired, deprecated, good, to_delete
|
||||
|
||||
def add_channel_update(self, payload):
|
||||
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
|
||||
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
|
||||
assert len(good) == 1
|
||||
self.update_policies(good, to_delete)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def update_policies(self, to_add, to_delete):
|
||||
for policy in to_delete:
|
||||
self.DBSession.delete(policy)
|
||||
self.DBSession.commit()
|
||||
for payload in to_add:
|
||||
policy = Policy.from_msg(payload)
|
||||
self.DBSession.add(policy)
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
def save_policy(self, policy):
|
||||
self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def on_node_announcement(self, msg_payloads):
|
||||
def delete_policy(self, short_channel_id, node_id):
|
||||
self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
|
||||
|
||||
@sql
|
||||
def save_channel(self, channel_info):
|
||||
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
|
||||
|
||||
def verify_channel_update(self, payload):
|
||||
short_channel_id = payload['short_channel_id']
|
||||
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
|
||||
raise Exception('wrong chain hash')
|
||||
if not verify_sig_for_channel_update(payload, payload['start_node']):
|
||||
raise BaseException('verify error')
|
||||
|
||||
def add_node_announcement(self, msg_payloads):
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
old_addr = None
|
||||
|
@ -435,29 +445,35 @@ class ChannelDB(SqlDB):
|
|||
continue
|
||||
node_id = node_info.node_id
|
||||
# Ignore node if it has no associated channel (DoS protection)
|
||||
# FIXME this is slow
|
||||
expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
|
||||
if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
|
||||
if node_id not in self._channels_for_node:
|
||||
#self.logger.info('ignoring orphan node_announcement')
|
||||
continue
|
||||
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
|
||||
node = self._nodes.get(node_id)
|
||||
if node and node.timestamp >= node_info.timestamp:
|
||||
continue
|
||||
node = new_nodes.get(node_id)
|
||||
if node and node.timestamp >= node_info.timestamp:
|
||||
continue
|
||||
new_nodes[node_id] = node_info
|
||||
# save
|
||||
self._nodes[node_id] = node_info
|
||||
self.save_node(node_info)
|
||||
for addr in node_addresses:
|
||||
new_addresses[(addr.node_id,addr.host,addr.port)] = addr
|
||||
self._addresses[node_id].add((addr.host, addr.port, 0))
|
||||
self.save_node_addresses(node_id, node_addresses)
|
||||
|
||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
for node_info in new_nodes.values():
|
||||
self.DBSession.add(node_info)
|
||||
for new_addr in new_addresses.values():
|
||||
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
||||
self.update_counts()
|
||||
|
||||
@sql
|
||||
def save_node_addresses(self, node_if, node_addresses):
|
||||
for new_addr in node_addresses:
|
||||
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
|
||||
if not old_addr:
|
||||
self.DBSession.add(new_addr)
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
|
||||
|
||||
@sql
|
||||
def save_node(self, node_info):
|
||||
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
|
||||
|
||||
def get_routing_policy_for_channel(self, start_node_id: bytes,
|
||||
short_channel_id: bytes) -> Optional[bytes]:
|
||||
|
@ -470,41 +486,28 @@ class ChannelDB(SqlDB):
|
|||
return None
|
||||
return Policy.from_msg(msg) # won't actually be written to DB
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def get_old_policies(self, delta):
|
||||
timestamp = int(time.time()) - delta
|
||||
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
||||
return old_policies.distinct().count()
|
||||
now = int(time.time())
|
||||
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def prune_old_policies(self, delta):
|
||||
# note: delete queries are order sensitive
|
||||
timestamp = int(time.time()) - delta
|
||||
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
|
||||
delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
|
||||
delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
|
||||
self.DBSession.execute(delete_old_channels)
|
||||
self.DBSession.execute(delete_old_policies)
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
l = self.get_old_policies(delta)
|
||||
for k in l:
|
||||
self._policies.pop(k)
|
||||
if l:
|
||||
self.logger.info(f'Deleting {len(l)} old policies')
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def get_orphaned_channels(self):
|
||||
subquery = self.DBSession.query(Policy.short_channel_id)
|
||||
orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
||||
return orphaned.count()
|
||||
ids = set(x[1] for x in self._policies.keys())
|
||||
return list(x for x in self._channels.keys() if x not in ids)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def prune_orphaned_channels(self):
|
||||
subquery = self.DBSession.query(Policy.short_channel_id)
|
||||
delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
|
||||
self.DBSession.execute(delete_orphaned)
|
||||
self.DBSession.commit()
|
||||
self._update_counts()
|
||||
l = self.get_orphaned_channels()
|
||||
for k in l:
|
||||
self._channels.pop(k)
|
||||
self.update_counts()
|
||||
if l:
|
||||
self.logger.info(f'Deleting {len(l)} orphaned channels')
|
||||
|
||||
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
|
||||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||
|
@ -513,67 +516,27 @@ class ChannelDB(SqlDB):
|
|||
msg_payload['start_node'] = start_node_id
|
||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||
|
||||
@sql
|
||||
def remove_channel(self, short_channel_id):
|
||||
r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
|
||||
if not r:
|
||||
return
|
||||
self.DBSession.delete(r)
|
||||
self.DBSession.commit()
|
||||
self._channels.pop(short_channel_id, None)
|
||||
|
||||
def print_graph(self, full_ids=False):
|
||||
# used for debugging.
|
||||
# FIXME there is a race here - iterables could change size from another thread
|
||||
def other_node_id(node_id, channel_id):
|
||||
channel_info = self.get_channel_info(channel_id)
|
||||
if node_id == channel_info.node1_id:
|
||||
other = channel_info.node2_id
|
||||
else:
|
||||
other = channel_info.node1_id
|
||||
return other if full_ids else other[-4:]
|
||||
|
||||
print_msg('nodes')
|
||||
for node in self.DBSession.query(NodeInfo).all():
|
||||
print_msg(node)
|
||||
|
||||
print_msg('channels')
|
||||
for channel_info in self.DBSession.query(ChannelInfo).all():
|
||||
short_channel_id = channel_info.short_channel_id
|
||||
node1 = channel_info.node1_id
|
||||
node2 = channel_info.node2_id
|
||||
direction1 = self.get_policy_for_node(channel_info, node1) is not None
|
||||
direction2 = self.get_policy_for_node(channel_info, node2) is not None
|
||||
if direction1 and direction2:
|
||||
direction = 'both'
|
||||
elif direction1:
|
||||
direction = 'forward'
|
||||
elif direction2:
|
||||
direction = 'backward'
|
||||
else:
|
||||
direction = 'none'
|
||||
print_msg('{}: {}, {}, {}'
|
||||
.format(bh2u(short_channel_id),
|
||||
bh2u(node1) if full_ids else bh2u(node1[-4:]),
|
||||
bh2u(node2) if full_ids else bh2u(node2[-4:]),
|
||||
direction))
|
||||
|
||||
|
||||
@sql
|
||||
def get_node_addresses(self, node_info):
|
||||
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
|
||||
def get_node_addresses(self, node_id):
|
||||
return self._addresses.get(node_id)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def load_data(self):
|
||||
r = self.DBSession.query(ChannelInfo).all()
|
||||
self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
|
||||
r = self.DBSession.query(Policy).filter_by().all()
|
||||
self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
|
||||
self._channels_for_node = defaultdict(set)
|
||||
for x in self.DBSession.query(AddressBase).all():
|
||||
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
|
||||
for x in self.DBSession.query(ChannelInfoBase).all():
|
||||
self._channels[x.short_channel_id] = x.to_nametuple()
|
||||
for x in self.DBSession.query(PolicyBase).filter_by().all():
|
||||
p = x.to_nametuple()
|
||||
self._policies[(p.start_node, p.short_channel_id)] = p
|
||||
for channel_info in self._channels.values():
|
||||
self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
|
||||
self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
|
||||
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
|
||||
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
|
||||
self.update_counts()
|
||||
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
|
||||
return self._policies.get((node_id, short_channel_id))
|
||||
|
@ -584,6 +547,3 @@ class ChannelDB(SqlDB):
|
|||
def get_channels_for_node(self, node_id) -> Set[bytes]:
|
||||
"""Returns the set of channels that have node_id as one of the endpoints."""
|
||||
return self._channels_for_node.get(node_id) or set()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -56,10 +56,11 @@ class WatcherList(MyTreeView):
|
|||
return
|
||||
self.model().clear()
|
||||
self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
|
||||
sweepstore = self.parent.lnwatcher.sweepstore
|
||||
for outpoint in sweepstore.list_sweep_tx():
|
||||
n = sweepstore.get_num_tx(outpoint)
|
||||
status = self.parent.lnwatcher.get_channel_status(outpoint)
|
||||
lnwatcher = self.parent.lnwatcher
|
||||
l = lnwatcher.list_sweep_tx()
|
||||
for outpoint in l:
|
||||
n = lnwatcher.get_num_tx(outpoint)
|
||||
status = lnwatcher.get_channel_status(outpoint)
|
||||
items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
|
||||
self.model().insertRow(self.model().rowCount(), items)
|
||||
|
||||
|
|
|
@ -258,14 +258,21 @@ class LnAddr(object):
|
|||
def get_min_final_cltv_expiry(self) -> int:
|
||||
return self._min_final_cltv_expiry
|
||||
|
||||
def get_description(self):
|
||||
def get_tag(self, tag):
|
||||
description = ''
|
||||
for k,v in self.tags:
|
||||
if k == 'd':
|
||||
if k == tag:
|
||||
description = v
|
||||
break
|
||||
return description
|
||||
|
||||
def get_description(self):
|
||||
return self.get_tag('d')
|
||||
|
||||
def get_expiry(self):
|
||||
return int(self.get_tag('x') or '3600')
|
||||
|
||||
|
||||
|
||||
def lndecode(a, verbose=False, expected_hrp=None):
|
||||
if expected_hrp is None:
|
||||
|
|
|
@ -163,8 +163,6 @@ class Channel(Logger):
|
|||
self._is_funding_txo_spent = None # "don't know"
|
||||
self._state = None
|
||||
self.set_state('DISCONNECTED')
|
||||
self.lnwatcher = None
|
||||
|
||||
self.local_commitment = None
|
||||
self.remote_commitment = None
|
||||
self.sweep_info = None
|
||||
|
@ -453,13 +451,10 @@ class Channel(Logger):
|
|||
return secret, point
|
||||
|
||||
def process_new_revocation_secret(self, per_commitment_secret: bytes):
|
||||
if not self.lnwatcher:
|
||||
return
|
||||
outpoint = self.funding_outpoint.to_str()
|
||||
ctx = self.remote_commitment_to_be_revoked # FIXME can't we just reconstruct it?
|
||||
sweeptxs = create_sweeptxs_for_their_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address)
|
||||
for tx in sweeptxs:
|
||||
self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
|
||||
return sweeptxs
|
||||
|
||||
def receive_revocation(self, revocation: RevokeAndAck):
|
||||
self.logger.info("receive_revocation")
|
||||
|
@ -477,9 +472,10 @@ class Channel(Logger):
|
|||
|
||||
# be robust to exceptions raised in lnwatcher
|
||||
try:
|
||||
self.process_new_revocation_secret(revocation.per_commitment_secret)
|
||||
sweeptxs = self.process_new_revocation_secret(revocation.per_commitment_secret)
|
||||
except Exception as e:
|
||||
self.logger.info("Could not process revocation secret: {}".format(repr(e)))
|
||||
sweeptxs = []
|
||||
|
||||
##### start applying fee/htlc changes
|
||||
|
||||
|
@ -505,6 +501,8 @@ class Channel(Logger):
|
|||
|
||||
self.set_remote_commitment()
|
||||
self.remote_commitment_to_be_revoked = prev_remote_commitment
|
||||
# return sweep transactions for watchtower
|
||||
return sweeptxs
|
||||
|
||||
def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None):
|
||||
"""
|
||||
|
|
|
@ -42,7 +42,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
|
|||
MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY)
|
||||
from .lntransport import LNTransport, LNTransportBase
|
||||
from .lnmsg import encode_msg, decode_msg
|
||||
from .lnverifier import verify_sig_for_channel_update
|
||||
from .interface import GracefulDisconnect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -242,22 +241,20 @@ class Peer(Logger):
|
|||
# channel announcements
|
||||
for chan_anns_chunk in chunks(chan_anns, 300):
|
||||
self.verify_channel_announcements(chan_anns_chunk)
|
||||
self.channel_db.on_channel_announcement(chan_anns_chunk)
|
||||
self.channel_db.add_channel_announcement(chan_anns_chunk)
|
||||
# node announcements
|
||||
for node_anns_chunk in chunks(node_anns, 100):
|
||||
self.verify_node_announcements(node_anns_chunk)
|
||||
self.channel_db.on_node_announcement(node_anns_chunk)
|
||||
self.channel_db.add_node_announcement(node_anns_chunk)
|
||||
# channel updates
|
||||
for chan_upds_chunk in chunks(chan_upds, 1000):
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds_chunk,
|
||||
max_age=self.network.lngossip.max_age)
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates(
|
||||
chan_upds_chunk, max_age=self.network.lngossip.max_age)
|
||||
if orphaned:
|
||||
self.logger.info(f'adding {len(orphaned)} unknown channel ids')
|
||||
self.network.lngossip.add_new_ids(orphaned)
|
||||
await self.network.lngossip.add_new_ids(orphaned)
|
||||
if good:
|
||||
self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds_chunk)}')
|
||||
self.verify_channel_updates(good)
|
||||
self.channel_db.update_policies(good, to_delete)
|
||||
# refresh gui
|
||||
if chan_anns or node_anns or chan_upds:
|
||||
self.network.lngossip.refresh_gui()
|
||||
|
@ -279,14 +276,6 @@ class Peer(Logger):
|
|||
if not ecc.verify_signature(pubkey, signature, h):
|
||||
raise Exception('signature failed')
|
||||
|
||||
def verify_channel_updates(self, chan_upds):
|
||||
for payload in chan_upds:
|
||||
short_channel_id = payload['short_channel_id']
|
||||
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
|
||||
raise Exception('wrong chain hash')
|
||||
if not verify_sig_for_channel_update(payload, payload['start_node']):
|
||||
raise BaseException('verify error')
|
||||
|
||||
async def query_gossip(self):
|
||||
try:
|
||||
await asyncio.wait_for(self.initialized.wait(), 10)
|
||||
|
@ -298,7 +287,7 @@ class Peer(Logger):
|
|||
except asyncio.TimeoutError as e:
|
||||
raise GracefulDisconnect("query_channel_range timed out") from e
|
||||
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
|
||||
self.lnworker.add_new_ids(ids)
|
||||
await self.lnworker.add_new_ids(ids)
|
||||
while True:
|
||||
todo = self.lnworker.get_ids_to_query()
|
||||
if not todo:
|
||||
|
@ -658,7 +647,7 @@ class Peer(Logger):
|
|||
)
|
||||
chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
|
||||
self.lnworker.save_channel(chan)
|
||||
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
|
||||
await self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
|
||||
self.lnworker.on_channels_updated()
|
||||
while True:
|
||||
try:
|
||||
|
@ -862,8 +851,6 @@ class Peer(Logger):
|
|||
bitcoin_key_2=bitcoin_keys[1]
|
||||
)
|
||||
|
||||
print("SENT CHANNEL ANNOUNCEMENT")
|
||||
|
||||
def mark_open(self, chan: Channel):
|
||||
assert chan.short_channel_id is not None
|
||||
if chan.get_state() == "OPEN":
|
||||
|
@ -872,6 +859,10 @@ class Peer(Logger):
|
|||
assert chan.config[LOCAL].funding_locked_received
|
||||
chan.set_state("OPEN")
|
||||
self.network.trigger_callback('channel', chan)
|
||||
asyncio.ensure_future(self.add_own_channel(chan))
|
||||
self.logger.info("CHANNEL OPENING COMPLETED")
|
||||
|
||||
async def add_own_channel(self, chan):
|
||||
# add channel to database
|
||||
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
|
@ -887,7 +878,7 @@ class Peer(Logger):
|
|||
# that the remote sends, even if the channel was not announced
|
||||
# (from BOLT-07: "MAY create a channel_update to communicate the channel
|
||||
# parameters to the final node, even though the channel has not yet been announced")
|
||||
self.channel_db.on_channel_announcement(
|
||||
self.channel_db.add_channel_announcement(
|
||||
{
|
||||
"short_channel_id": chan.short_channel_id,
|
||||
"node_id_1": node_ids[0],
|
||||
|
@ -922,8 +913,6 @@ class Peer(Logger):
|
|||
if pending_channel_update:
|
||||
self.channel_db.add_channel_update(pending_channel_update)
|
||||
|
||||
self.logger.info("CHANNEL OPENING COMPLETED")
|
||||
|
||||
def send_announcement_signatures(self, chan: Channel):
|
||||
|
||||
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
|
||||
|
@ -962,6 +951,21 @@ class Peer(Logger):
|
|||
def on_update_fail_htlc(self, payload):
|
||||
channel_id = payload["channel_id"]
|
||||
htlc_id = int.from_bytes(payload["id"], "big")
|
||||
chan = self.channels[channel_id]
|
||||
chan.receive_fail_htlc(htlc_id)
|
||||
local_ctn = chan.get_current_ctn(LOCAL)
|
||||
asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id))
|
||||
asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn))
|
||||
|
||||
@log_exceptions
|
||||
async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn):
|
||||
chan = self.channels[channel_id]
|
||||
await self.await_local(chan, local_ctn)
|
||||
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
|
||||
|
||||
@log_exceptions
|
||||
async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id):
|
||||
chan = self.channels[channel_id]
|
||||
key = (channel_id, htlc_id)
|
||||
try:
|
||||
route = self.attempted_route[key]
|
||||
|
@ -969,29 +973,12 @@ class Peer(Logger):
|
|||
# the remote might try to fail an htlc after we restarted...
|
||||
# attempted_route is not persisted, so we will get here then
|
||||
self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
|
||||
else:
|
||||
try:
|
||||
self._handle_error_code_from_failed_htlc(payload["reason"], route, channel_id, htlc_id)
|
||||
except Exception:
|
||||
# exceptions are suppressed as failing to handle an error code
|
||||
# should not block us from removing the htlc
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
# process update_fail_htlc on channel
|
||||
chan = self.channels[channel_id]
|
||||
chan.receive_fail_htlc(htlc_id)
|
||||
local_ctn = chan.get_current_ctn(LOCAL)
|
||||
asyncio.ensure_future(self._on_update_fail_htlc(chan, htlc_id, local_ctn))
|
||||
|
||||
@log_exceptions
|
||||
async def _on_update_fail_htlc(self, chan, htlc_id, local_ctn):
|
||||
await self.await_local(chan, local_ctn)
|
||||
self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
|
||||
|
||||
def _handle_error_code_from_failed_htlc(self, error_reason, route: List['RouteEdge'], channel_id, htlc_id):
|
||||
chan = self.channels[channel_id]
|
||||
failure_msg, sender_idx = decode_onion_error(error_reason,
|
||||
[x.node_id for x in route],
|
||||
chan.onion_keys[htlc_id])
|
||||
return
|
||||
error_reason = payload["reason"]
|
||||
failure_msg, sender_idx = decode_onion_error(
|
||||
error_reason,
|
||||
[x.node_id for x in route],
|
||||
chan.onion_keys[htlc_id])
|
||||
code, data = failure_msg.code, failure_msg.data
|
||||
self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}")
|
||||
self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}")
|
||||
|
@ -1009,11 +996,9 @@ class Peer(Logger):
|
|||
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
|
||||
message_type, payload = decode_msg(channel_update)
|
||||
payload['raw'] = channel_update
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates([payload])
|
||||
blacklist = False
|
||||
if good:
|
||||
self.verify_channel_updates(good)
|
||||
self.channel_db.update_policies(good, to_delete)
|
||||
self.logger.info("applied channel update on our db")
|
||||
elif orphaned:
|
||||
# maybe it is a private channel (and data in invoice was outdated)
|
||||
|
@ -1276,11 +1261,17 @@ class Peer(Logger):
|
|||
self.logger.info("on_revoke_and_ack")
|
||||
channel_id = payload["channel_id"]
|
||||
chan = self.channels[channel_id]
|
||||
chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
|
||||
sweeptxs = chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
|
||||
self._remote_changed_events[chan.channel_id].set()
|
||||
self._remote_changed_events[chan.channel_id].clear()
|
||||
self.lnworker.save_channel(chan)
|
||||
self.maybe_send_commitment(chan)
|
||||
asyncio.ensure_future(self._on_revoke_and_ack(chan, sweeptxs))
|
||||
|
||||
async def _on_revoke_and_ack(self, chan, sweeptxs):
|
||||
outpoint = chan.funding_outpoint.to_str()
|
||||
for tx in sweeptxs:
|
||||
await self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
|
||||
|
||||
def on_update_fee(self, payload):
|
||||
channel_id = payload["channel_id"]
|
||||
|
|
|
@ -37,7 +37,7 @@ import binascii
|
|||
import base64
|
||||
|
||||
from . import constants
|
||||
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
|
||||
from .util import bh2u, profiler, get_headers_dir, is_ip_address, list_enabled_bits, print_msg, chunks
|
||||
from .logging import Logger
|
||||
from .storage import JsonDB
|
||||
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
|
||||
|
@ -169,7 +169,6 @@ class LNPathFinder(Logger):
|
|||
To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
|
||||
i.e. an element reads as, "to get to node_id, travel through short_channel_id"
|
||||
"""
|
||||
self.channel_db.load_data()
|
||||
assert type(nodeA) is bytes
|
||||
assert type(nodeB) is bytes
|
||||
assert type(invoice_amount_msat) is int
|
||||
|
@ -195,11 +194,12 @@ class LNPathFinder(Logger):
|
|||
else: # payment incoming, on our channel. (funny business, cycle weirdness)
|
||||
assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
|
||||
pass # TODO?
|
||||
edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
|
||||
start_node=edge_startnode,
|
||||
end_node=edge_endnode,
|
||||
payment_amt_msat=amount_msat,
|
||||
ignore_costs=(edge_startnode == nodeA))
|
||||
edge_cost, fee_for_edge_msat = self._edge_cost(
|
||||
edge_channel_id,
|
||||
start_node=edge_startnode,
|
||||
end_node=edge_endnode,
|
||||
payment_amt_msat=amount_msat,
|
||||
ignore_costs=(edge_startnode == nodeA))
|
||||
alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
|
||||
if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
|
||||
distance_from_start[edge_startnode] = alt_dist_to_neighbour
|
||||
|
@ -219,9 +219,10 @@ class LNPathFinder(Logger):
|
|||
continue
|
||||
for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
|
||||
assert type(edge_channel_id) is bytes
|
||||
if edge_channel_id in self.blacklist: continue
|
||||
if edge_channel_id in self.blacklist:
|
||||
continue
|
||||
channel_info = self.channel_db.get_channel_info(edge_channel_id)
|
||||
edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
|
||||
edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
|
||||
inspect_edge()
|
||||
else:
|
||||
return None # no path found
|
||||
|
|
|
@ -70,11 +70,11 @@ class SweepStore(SqlDB):
|
|||
@sql
|
||||
def get_tx_by_index(self, funding_outpoint, index):
|
||||
r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
|
||||
return r.prevout, bh2u(r.tx)
|
||||
return str(r.prevout), bh2u(r.tx)
|
||||
|
||||
@sql
|
||||
def list_sweep_tx(self):
|
||||
return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
|
||||
return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
|
||||
|
||||
@sql
|
||||
def add_sweep_tx(self, funding_outpoint, prevout, tx):
|
||||
|
@ -84,7 +84,7 @@ class SweepStore(SqlDB):
|
|||
|
||||
@sql
|
||||
def get_num_tx(self, funding_outpoint):
|
||||
return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
|
||||
return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
|
||||
|
||||
@sql
|
||||
def remove_sweep_tx(self, funding_outpoint):
|
||||
|
@ -111,11 +111,11 @@ class SweepStore(SqlDB):
|
|||
@sql
|
||||
def get_address(self, outpoint):
|
||||
r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
|
||||
return r.address if r else None
|
||||
return str(r.address) if r else None
|
||||
|
||||
@sql
|
||||
def list_channel_info(self):
|
||||
return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
|
||||
return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
|
||||
|
||||
|
||||
class LNWatcher(AddressSynchronizer):
|
||||
|
@ -150,14 +150,21 @@ class LNWatcher(AddressSynchronizer):
|
|||
self.watchtower_queue = asyncio.Queue()
|
||||
|
||||
def get_num_tx(self, outpoint):
|
||||
return self.sweepstore.get_num_tx(outpoint)
|
||||
async def f():
|
||||
return await self.sweepstore.get_num_tx(outpoint)
|
||||
return self.network.run_from_another_thread(f())
|
||||
|
||||
def list_sweep_tx(self):
|
||||
async def f():
|
||||
return await self.sweepstore.list_sweep_tx()
|
||||
return self.network.run_from_another_thread(f())
|
||||
|
||||
@ignore_exceptions
|
||||
@log_exceptions
|
||||
async def watchtower_task(self):
|
||||
self.logger.info('watchtower task started')
|
||||
# initial check
|
||||
for address, outpoint in self.sweepstore.list_channel_info():
|
||||
for address, outpoint in await self.sweepstore.list_channel_info():
|
||||
await self.watchtower_queue.put(outpoint)
|
||||
while True:
|
||||
outpoint = await self.watchtower_queue.get()
|
||||
|
@ -165,30 +172,30 @@ class LNWatcher(AddressSynchronizer):
|
|||
continue
|
||||
# synchronize with remote
|
||||
try:
|
||||
local_n = self.sweepstore.get_num_tx(outpoint)
|
||||
local_n = await self.sweepstore.get_num_tx(outpoint)
|
||||
n = self.watchtower.get_num_tx(outpoint)
|
||||
if n == 0:
|
||||
address = self.sweepstore.get_address(outpoint)
|
||||
address = await self.sweepstore.get_address(outpoint)
|
||||
self.watchtower.add_channel(outpoint, address)
|
||||
self.logger.info("sending %d transactions to watchtower"%(local_n - n))
|
||||
for index in range(n, local_n):
|
||||
prevout, tx = self.sweepstore.get_tx_by_index(outpoint, index)
|
||||
prevout, tx = await self.sweepstore.get_tx_by_index(outpoint, index)
|
||||
self.watchtower.add_sweep_tx(outpoint, prevout, tx)
|
||||
except ConnectionRefusedError:
|
||||
self.logger.info('could not reach watchtower, will retry in 5s')
|
||||
await asyncio.sleep(5)
|
||||
await self.watchtower_queue.put(outpoint)
|
||||
|
||||
def add_channel(self, outpoint, address):
|
||||
async def add_channel(self, outpoint, address):
|
||||
self.add_address(address)
|
||||
with self.lock:
|
||||
if not self.sweepstore.has_channel(outpoint):
|
||||
self.sweepstore.add_channel(outpoint, address)
|
||||
if not await self.sweepstore.has_channel(outpoint):
|
||||
await self.sweepstore.add_channel(outpoint, address)
|
||||
|
||||
def unwatch_channel(self, address, funding_outpoint):
|
||||
async def unwatch_channel(self, address, funding_outpoint):
|
||||
self.logger.info(f'unwatching {funding_outpoint}')
|
||||
self.sweepstore.remove_sweep_tx(funding_outpoint)
|
||||
self.sweepstore.remove_channel(funding_outpoint)
|
||||
await self.sweepstore.remove_sweep_tx(funding_outpoint)
|
||||
await self.sweepstore.remove_channel(funding_outpoint)
|
||||
if funding_outpoint in self.tx_progress:
|
||||
self.tx_progress[funding_outpoint].all_done.set()
|
||||
|
||||
|
@ -202,7 +209,7 @@ class LNWatcher(AddressSynchronizer):
|
|||
return
|
||||
if not self.synchronizer.is_up_to_date():
|
||||
return
|
||||
for address, outpoint in self.sweepstore.list_channel_info():
|
||||
for address, outpoint in await self.sweepstore.list_channel_info():
|
||||
await self.check_onchain_situation(address, outpoint)
|
||||
|
||||
async def check_onchain_situation(self, address, funding_outpoint):
|
||||
|
@ -223,7 +230,7 @@ class LNWatcher(AddressSynchronizer):
|
|||
closing_height, closing_tx) # FIXME sooo many args..
|
||||
await self.do_breach_remedy(funding_outpoint, spenders)
|
||||
if not keep_watching:
|
||||
self.unwatch_channel(address, funding_outpoint)
|
||||
await self.unwatch_channel(address, funding_outpoint)
|
||||
else:
|
||||
#self.logger.info(f'we will keep_watching {funding_outpoint}')
|
||||
pass
|
||||
|
@ -260,7 +267,7 @@ class LNWatcher(AddressSynchronizer):
|
|||
for prevout, spender in spenders.items():
|
||||
if spender is not None:
|
||||
continue
|
||||
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
|
||||
sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
|
||||
for tx in sweep_txns:
|
||||
if not await self.broadcast_or_log(funding_outpoint, tx):
|
||||
self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}')
|
||||
|
@ -279,8 +286,8 @@ class LNWatcher(AddressSynchronizer):
|
|||
await self.tx_progress[funding_outpoint].tx_queue.put(tx)
|
||||
return txid
|
||||
|
||||
def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
|
||||
self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
|
||||
async def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
|
||||
await self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
|
||||
if self.watchtower:
|
||||
self.watchtower_queue.put_nowait(funding_outpoint)
|
||||
|
||||
|
|
|
@ -108,12 +108,14 @@ class LNWorker(Logger):
|
|||
|
||||
@log_exceptions
|
||||
async def main_loop(self):
|
||||
# fixme: only lngossip should do that
|
||||
await self.channel_db.load_data()
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
now = time.time()
|
||||
if len(self.peers) >= NUM_PEERS_TARGET:
|
||||
continue
|
||||
peers = self._get_next_peers_to_try()
|
||||
peers = await self._get_next_peers_to_try()
|
||||
for peer in peers:
|
||||
last_tried = self._last_tried_peer.get(peer, 0)
|
||||
if last_tried + PEER_RETRY_INTERVAL < now:
|
||||
|
@ -130,7 +132,8 @@ class LNWorker(Logger):
|
|||
peer = Peer(self, node_id, transport)
|
||||
await self.network.main_taskgroup.spawn(peer.main_loop())
|
||||
self.peers[node_id] = peer
|
||||
self.network.lngossip.refresh_gui()
|
||||
#if self.network.lngossip:
|
||||
# self.network.lngossip.refresh_gui()
|
||||
return peer
|
||||
|
||||
def start_network(self, network: 'Network'):
|
||||
|
@ -148,7 +151,7 @@ class LNWorker(Logger):
|
|||
self._add_peer(host, int(port), bfh(pubkey)),
|
||||
self.network.asyncio_loop)
|
||||
|
||||
def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
|
||||
async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
|
||||
now = time.time()
|
||||
recent_peers = self.channel_db.get_recent_peers()
|
||||
# maintenance for last tried times
|
||||
|
@ -158,19 +161,22 @@ class LNWorker(Logger):
|
|||
del self._last_tried_peer[peer]
|
||||
# first try from recent peers
|
||||
for peer in recent_peers:
|
||||
if peer.pubkey in self.peers: continue
|
||||
if peer in self._last_tried_peer: continue
|
||||
if peer.pubkey in self.peers:
|
||||
continue
|
||||
if peer in self._last_tried_peer:
|
||||
continue
|
||||
return [peer]
|
||||
# try random peer from graph
|
||||
unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
|
||||
if unconnected_nodes:
|
||||
for node in unconnected_nodes:
|
||||
addrs = self.channel_db.get_node_addresses(node)
|
||||
for node_id in unconnected_nodes:
|
||||
addrs = self.channel_db.get_node_addresses(node_id)
|
||||
if not addrs:
|
||||
continue
|
||||
host, port = self.choose_preferred_address(addrs)
|
||||
peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
|
||||
if peer in self._last_tried_peer: continue
|
||||
host, port, timestamp = self.choose_preferred_address(addrs)
|
||||
peer = LNPeerAddr(host, port, node_id)
|
||||
if peer in self._last_tried_peer:
|
||||
continue
|
||||
#self.logger.info('taking random ln peer from our channel db')
|
||||
return [peer]
|
||||
|
||||
|
@ -223,15 +229,13 @@ class LNWorker(Logger):
|
|||
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
||||
assert len(addr_list) >= 1
|
||||
# choose first one that is an IP
|
||||
for addr_in_db in addr_list:
|
||||
host = addr_in_db.host
|
||||
port = addr_in_db.port
|
||||
for host, port, timestamp in addr_list:
|
||||
if is_ip_address(host):
|
||||
return host, port
|
||||
return host, port, timestamp
|
||||
# otherwise choose one at random
|
||||
# TODO maybe filter out onion if not on tor?
|
||||
choice = random.choice(addr_list)
|
||||
return choice.host, choice.port
|
||||
return choice
|
||||
|
||||
|
||||
class LNGossip(LNWorker):
|
||||
|
@ -260,26 +264,19 @@ class LNGossip(LNWorker):
|
|||
self.network.trigger_callback('ln_status', num_peers, num_nodes, known, unknown)
|
||||
|
||||
async def maintain_db(self):
|
||||
n = self.channel_db.get_orphaned_channels()
|
||||
if n:
|
||||
self.logger.info(f'Deleting {n} orphaned channels')
|
||||
self.channel_db.prune_orphaned_channels()
|
||||
self.refresh_gui()
|
||||
self.channel_db.prune_orphaned_channels()
|
||||
while True:
|
||||
n = self.channel_db.get_old_policies(self.max_age)
|
||||
if n:
|
||||
self.logger.info(f'Deleting {n} old channels')
|
||||
self.channel_db.prune_old_policies(self.max_age)
|
||||
self.refresh_gui()
|
||||
self.channel_db.prune_old_policies(self.max_age)
|
||||
self.refresh_gui()
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def add_new_ids(self, ids):
|
||||
known = self.channel_db.known_ids()
|
||||
async def add_new_ids(self, ids):
|
||||
known = self.channel_db.get_channel_ids()
|
||||
new = set(ids) - set(known)
|
||||
self.unknown_ids.update(new)
|
||||
|
||||
def get_ids_to_query(self):
|
||||
N = 500
|
||||
N = 100
|
||||
l = list(self.unknown_ids)
|
||||
self.unknown_ids = set(l[N:])
|
||||
return l[0:N]
|
||||
|
@ -324,9 +321,10 @@ class LNWallet(LNWorker):
|
|||
self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee']) # thread safe
|
||||
self.network.register_callback(self.on_channel_open, ['channel_open'])
|
||||
self.network.register_callback(self.on_channel_closed, ['channel_closed'])
|
||||
|
||||
for chan_id, chan in self.channels.items():
|
||||
self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
|
||||
chan.lnwatcher = network.lnwatcher
|
||||
self.network.lnwatcher.add_address(chan.get_funding_address())
|
||||
|
||||
super().start_network(network)
|
||||
for coro in [
|
||||
self.maybe_listen(),
|
||||
|
@ -494,7 +492,7 @@ class LNWallet(LNWorker):
|
|||
chan = self.channel_by_txo(funding_outpoint)
|
||||
if not chan:
|
||||
return
|
||||
self.logger.debug(f'on_channel_open {funding_outpoint}')
|
||||
#self.logger.debug(f'on_channel_open {funding_outpoint}')
|
||||
self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, None, None, None
|
||||
self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
|
||||
chan.set_funding_txo_spentness(False)
|
||||
|
@ -606,7 +604,8 @@ class LNWallet(LNWorker):
|
|||
self.logger.info('REBROADCASTING CLOSING TX')
|
||||
await self.force_close_channel(chan.channel_id)
|
||||
|
||||
async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
|
||||
async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password):
|
||||
peer = await self.add_peer(connect_str)
|
||||
# peer might just have been connected to
|
||||
await asyncio.wait_for(peer.initialized.wait(), 5)
|
||||
chan = await peer.channel_establishment_flow(
|
||||
|
@ -615,24 +614,22 @@ class LNWallet(LNWorker):
|
|||
push_msat=push_sat * 1000,
|
||||
temp_channel_id=os.urandom(32))
|
||||
self.save_channel(chan)
|
||||
self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
|
||||
self.network.lnwatcher.add_address(chan.get_funding_address())
|
||||
await self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
|
||||
self.on_channels_updated()
|
||||
return chan
|
||||
|
||||
def on_channels_updated(self):
|
||||
self.network.trigger_callback('channels')
|
||||
|
||||
def add_peer(self, connect_str, timeout=20):
|
||||
async def add_peer(self, connect_str, timeout=20):
|
||||
node_id, rest = extract_nodeid(connect_str)
|
||||
peer = self.peers.get(node_id)
|
||||
if not peer:
|
||||
if rest is not None:
|
||||
host, port = split_host_port(rest)
|
||||
else:
|
||||
node_info = self.network.channel_db.nodes_get(node_id)
|
||||
if not node_info:
|
||||
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
|
||||
addrs = self.channel_db.get_node_addresses(node_info)
|
||||
addrs = self.channel_db.get_node_addresses(node_id)
|
||||
if len(addrs) == 0:
|
||||
raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
|
||||
host, port = self.choose_preferred_address(addrs)
|
||||
|
@ -640,18 +637,12 @@ class LNWallet(LNWorker):
|
|||
socket.getaddrinfo(host, int(port))
|
||||
except socket.gaierror:
|
||||
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
|
||||
peer_future = asyncio.run_coroutine_threadsafe(
|
||||
self._add_peer(host, port, node_id),
|
||||
self.network.asyncio_loop)
|
||||
try:
|
||||
peer = peer_future.result(timeout)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise Exception(_("add_peer timed out"))
|
||||
# add peer
|
||||
peer = await self._add_peer(host, port, node_id)
|
||||
return peer
|
||||
|
||||
def open_channel(self, connect_str, local_amt_sat, push_amt_sat, password=None, timeout=20):
|
||||
peer = self.add_peer(connect_str, timeout)
|
||||
coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, password)
|
||||
coro = self._open_channel_coroutine(connect_str, local_amt_sat, push_amt_sat, password)
|
||||
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||
try:
|
||||
chan = fut.result(timeout=timeout)
|
||||
|
@ -664,6 +655,9 @@ class LNWallet(LNWorker):
|
|||
Can be called from other threads
|
||||
Raises timeout exception if htlc is not fulfilled
|
||||
"""
|
||||
addr = self._check_invoice(invoice, amount_sat)
|
||||
self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
|
||||
self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._pay(invoice, attempts, amount_sat),
|
||||
self.network.asyncio_loop)
|
||||
|
@ -680,8 +674,6 @@ class LNWallet(LNWorker):
|
|||
|
||||
async def _pay(self, invoice, attempts=1, amount_sat=None):
|
||||
addr = self._check_invoice(invoice, amount_sat)
|
||||
self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
|
||||
self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
|
||||
for i in range(attempts):
|
||||
route = await self._create_route_from_invoice(decoded_invoice=addr)
|
||||
if not self.get_channel_by_short_id(route[0].short_channel_id):
|
||||
|
@ -691,7 +683,7 @@ class LNWallet(LNWorker):
|
|||
return True
|
||||
return False
|
||||
|
||||
async def _pay_to_route(self, route, addr, pay_req):
|
||||
async def _pay_to_route(self, route, addr, invoice):
|
||||
short_channel_id = route[0].short_channel_id
|
||||
chan = self.get_channel_by_short_id(short_channel_id)
|
||||
if not chan:
|
||||
|
@ -713,6 +705,9 @@ class LNWallet(LNWorker):
|
|||
raise InvoiceError("{}\n{}".format(
|
||||
_("Invoice wants us to risk locking funds for unreasonably long."),
|
||||
f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
|
||||
#now = int(time.time())
|
||||
#if addr.date + addr.get_expiry() > now:
|
||||
# raise InvoiceError(_('Invoice expired'))
|
||||
return addr
|
||||
|
||||
async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
|
||||
|
@ -730,11 +725,14 @@ class LNWallet(LNWorker):
|
|||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
for private_route in r_tags:
|
||||
if len(private_route) == 0: continue
|
||||
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue
|
||||
if len(private_route) == 0:
|
||||
continue
|
||||
if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
|
||||
continue
|
||||
border_node_pubkey = private_route[0][0]
|
||||
path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
|
||||
if not path: continue
|
||||
if not path:
|
||||
continue
|
||||
route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
|
||||
# we need to shift the node pubkey by one towards the destination:
|
||||
private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
|
||||
|
@ -770,10 +768,18 @@ class LNWallet(LNWorker):
|
|||
return route
|
||||
|
||||
def add_invoice(self, amount_sat, message):
|
||||
coro = self._add_invoice_coro(amount_sat, message)
|
||||
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||
try:
|
||||
return fut.result(timeout=5)
|
||||
except concurrent.futures.TimeoutError:
|
||||
raise Exception(_("add_invoice timed out"))
|
||||
|
||||
async def _add_invoice_coro(self, amount_sat, message):
|
||||
payment_preimage = os.urandom(32)
|
||||
payment_hash = sha256(payment_preimage)
|
||||
amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
|
||||
routing_hints = self._calc_routing_hints_for_invoice(amount_sat)
|
||||
routing_hints = await self._calc_routing_hints_for_invoice(amount_sat)
|
||||
if not routing_hints:
|
||||
self.logger.info("Warning. No routing hints added to invoice. "
|
||||
"Other clients will likely not be able to send to us.")
|
||||
|
@ -847,19 +853,20 @@ class LNWallet(LNWorker):
|
|||
})
|
||||
return out
|
||||
|
||||
def _calc_routing_hints_for_invoice(self, amount_sat):
|
||||
async def _calc_routing_hints_for_invoice(self, amount_sat):
|
||||
"""calculate routing hints (BOLT-11 'r' field)"""
|
||||
self.channel_db.load_data()
|
||||
routing_hints = []
|
||||
with self.lock:
|
||||
channels = list(self.channels.values())
|
||||
# note: currently we add *all* our channels; but this might be a privacy leak?
|
||||
for chan in channels:
|
||||
# check channel is open
|
||||
if chan.get_state() != "OPEN": continue
|
||||
if chan.get_state() != "OPEN":
|
||||
continue
|
||||
# check channel has sufficient balance
|
||||
# FIXME because of on-chain fees of ctx, this check is insufficient
|
||||
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue
|
||||
if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
|
||||
continue
|
||||
chan_id = chan.short_channel_id
|
||||
assert type(chan_id) is bytes, chan_id
|
||||
channel_info = self.channel_db.get_channel_info(chan_id)
|
||||
|
@ -949,14 +956,10 @@ class LNWallet(LNWorker):
|
|||
await self._add_peer(peer.host, peer.port, peer.pubkey)
|
||||
return
|
||||
# try random address for node_id
|
||||
node_info = self.channel_db.nodes_get(chan.node_id)
|
||||
if not node_info:
|
||||
return
|
||||
addresses = self.channel_db.get_node_addresses(node_info)
|
||||
addresses = self.channel_db.get_node_addresses(chan.node_id)
|
||||
if not addresses:
|
||||
return
|
||||
adr_obj = random.choice(addresses)
|
||||
host, port = adr_obj.host, adr_obj.port
|
||||
host, port, t = random.choice(list(addresses))
|
||||
peer = LNPeerAddr(host, port, chan.node_id)
|
||||
last_tried = self._last_tried_peer.get(peer, 0)
|
||||
if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import concurrent
|
||||
import queue
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
@ -18,28 +19,32 @@ def sql(func):
|
|||
"""wrapper for sql methods"""
|
||||
def wrapper(self, *args, **kwargs):
|
||||
assert threading.currentThread() != self.sql_thread
|
||||
f = concurrent.futures.Future()
|
||||
f = asyncio.Future()
|
||||
self.db_requests.put((f, func, args, kwargs))
|
||||
return f.result(timeout=10)
|
||||
return f
|
||||
return wrapper
|
||||
|
||||
class SqlDB(Logger):
|
||||
|
||||
def __init__(self, network, path, base):
|
||||
def __init__(self, network, path, base, commit_interval=None):
|
||||
Logger.__init__(self)
|
||||
self.base = base
|
||||
self.network = network
|
||||
self.path = path
|
||||
self.commit_interval = commit_interval
|
||||
self.db_requests = queue.Queue()
|
||||
self.sql_thread = threading.Thread(target=self.run_sql)
|
||||
self.sql_thread.start()
|
||||
|
||||
def run_sql(self):
|
||||
#return
|
||||
self.logger.info("SQL thread started")
|
||||
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
|
||||
DBSession = sessionmaker(bind=engine, autoflush=False)
|
||||
self.DBSession = DBSession()
|
||||
if not os.path.exists(self.path):
|
||||
self.base.metadata.create_all(engine)
|
||||
self.DBSession = DBSession()
|
||||
i = 0
|
||||
while self.network.asyncio_loop.is_running():
|
||||
try:
|
||||
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
||||
|
@ -50,7 +55,14 @@ class SqlDB(Logger):
|
|||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
continue
|
||||
future.set_result(result)
|
||||
if not future.cancelled():
|
||||
future.set_result(result)
|
||||
# note: in sweepstore session.commit() is called inside
|
||||
# the sql-decorated methods, so commiting to disk is awaited
|
||||
if self.commit_interval:
|
||||
i = (i + 1) % self.commit_interval
|
||||
if i == 0:
|
||||
self.DBSession.commit()
|
||||
# write
|
||||
self.DBSession.commit()
|
||||
self.logger.info("SQL thread terminated")
|
||||
|
|
|
@ -16,7 +16,8 @@ from electrum.lnpeer import Peer
|
|||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||
from electrum.lnutil import PaymentFailure, LnLocalFeatures
|
||||
from electrum.lnrouter import ChannelDB, LNPathFinder
|
||||
from electrum.lnrouter import LNPathFinder
|
||||
from electrum.channel_db import ChannelDB
|
||||
from electrum.lnworker import LNWallet
|
||||
from electrum.lnmsg import encode_msg, decode_msg
|
||||
from electrum.logging import console_stderr_handler
|
||||
|
|
|
@ -59,33 +59,33 @@ class Test_LNRouter(TestCaseForTestnet):
|
|||
cdb = fake_network.channel_db
|
||||
path_finder = lnrouter.LNPathFinder(cdb)
|
||||
self.assertEqual(cdb.num_channels, 0)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
|
||||
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
|
||||
'short_channel_id': bfh('0000000000000001'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||
self.assertEqual(cdb.num_channels, 1)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
'short_channel_id': bfh('0000000000000002'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
|
||||
'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
|
||||
'short_channel_id': bfh('0000000000000003'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
'bitcoin_key_1': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
'short_channel_id': bfh('0000000000000004'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
'bitcoin_key_1': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
|
||||
'short_channel_id': bfh('0000000000000005'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
'len': b'\x00\x00', 'features': b''}, trusted=True)
|
||||
cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
|
||||
'short_channel_id': bfh('0000000000000006'),
|
||||
'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
|
||||
|
|
Loading…
Add table
Reference in a new issue