Mirror of https://github.com/LBRYFoundation/LBRY-Vault.git (synced 2025-08-31 17:31:36 +00:00)

lnrouter: load data before finding path

parent dac686b11d
commit 34f22e6681

2 changed files with 38 additions and 36 deletions
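In outline, the diff below drops ChannelDB's SQL-backed getters (get_channel_info, get_channels_for_node, get_policy_for_node and the _chan_query_for_id helper) and replaces them with dict lookups over data preloaded by a new load_data() method; LNPathFinder.find_path and LNWorker._calc_routing_hints_for_invoice then call load_data() before touching the graph. The following is a minimal standalone sketch of the cache layout that load_data() builds, under simplified assumptions: ChannelRow, PolicyRow and ChannelCache are illustrative stand-ins (not Electrum classes), the rows are passed in rather than queried from self.DBSession, and bfh() is reimplemented here as bytes.fromhex.

    from collections import defaultdict, namedtuple

    # Simplified stand-ins for the ORM rows seen in the diff; the real ChannelInfo /
    # Policy classes, the DBSession queries and the bfh() helper live in Electrum itself.
    ChannelRow = namedtuple('ChannelRow', ['short_channel_id', 'node1_id', 'node2_id'])
    PolicyRow = namedtuple('PolicyRow', ['start_node', 'short_channel_id', 'fee_base_msat'])

    def bfh(x: str) -> bytes:
        # Electrum's bfh() is essentially bytes.fromhex.
        return bytes.fromhex(x)

    class ChannelCache:
        """Sketch of the in-memory index that the new load_data() builds."""

        def __init__(self, channel_rows, policy_rows):
            # In the real ChannelDB these rows come from self.DBSession queries.
            self._channel_rows = channel_rows
            self._policy_rows = policy_rows

        def load_data(self):
            # short_channel_id (bytes) -> channel row
            self._channels = {bfh(c.short_channel_id): c for c in self._channel_rows}
            # (start_node, short_channel_id) (bytes pair) -> policy row
            self._policies = {(bfh(p.start_node), bfh(p.short_channel_id)): p
                              for p in self._policy_rows}
            # node_id (bytes) -> set of short_channel_ids that node participates in
            self._channels_for_node = defaultdict(set)
            for c in self._channels.values():
                scid = bfh(c.short_channel_id)
                self._channels_for_node[bfh(c.node1_id)].add(scid)
                self._channels_for_node[bfh(c.node2_id)].add(scid)

        # The former SQL-backed getters become plain dict lookups.
        def get_channel_info(self, channel_id: bytes):
            return self._channels.get(channel_id)

        def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes):
            return self._policies.get((node_id, short_channel_id))

        def get_channels_for_node(self, node_id: bytes):
            return self._channels_for_node.get(node_id)

The real load_data() in the diff pulls the rows from self.DBSession inside an @sql/@profiler-decorated method; only the resulting dictionary shapes are mirrored here.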
@@ -229,7 +229,9 @@ class ChannelDB(SqlDB):
 
     def _update_counts(self):
         self.num_channels = self.DBSession.query(ChannelInfo).count()
+        self.num_policies = self.DBSession.query(Policy).count()
         self.num_nodes = self.DBSession.query(NodeInfo).count()
+        self.print_error('update counts', self.num_channels, self.num_policies)
 
     @sql
     def add_recent_peer(self, peer: LNPeerAddr):
@@ -272,19 +274,6 @@ class ChannelDB(SqlDB):
         r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
         return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
 
-    @sql
-    def get_channel_info(self, channel_id: bytes):
-        return self._chan_query_for_id(channel_id).one_or_none()
-
-    @sql
-    def get_channels_for_node(self, node_id):
-        """Returns the set of channels that have node_id as one of the endpoints."""
-        condition = or_(
-            ChannelInfo.node1_id == node_id.hex(),
-            ChannelInfo.node2_id == node_id.hex())
-        rows = self.DBSession.query(ChannelInfo).filter(condition).all()
-        return [bytes.fromhex(x.short_channel_id) for x in rows]
-
     @sql
     def missing_short_chan_ids(self) -> Set[int]:
         expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
@@ -296,7 +285,7 @@ class ChannelDB(SqlDB):
     @sql
     def add_verified_channel_info(self, short_id, capacity):
         # called from lnchannelverifier
-        channel_info = self._chan_query_for_id(short_id).one_or_none()
+        channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
         channel_info.trusted = True
         channel_info.capacity = capacity
         self.DBSession.commit()
@@ -372,7 +361,6 @@ class ChannelDB(SqlDB):
             if p and p.timestamp >= new_policy.timestamp:
                 continue
             new_policies[(short_channel_id, node_id)] = new_policy
-        #self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads)))
         # commit pending removals
         self.DBSession.commit()
         # add and commit new policies
@@ -380,7 +368,9 @@ class ChannelDB(SqlDB):
             self.DBSession.add(new_policy)
         self.DBSession.commit()
         if new_policies:
+            self.print_error('on_channel_update: %d/%d'%(len(new_policies), len(msg_payloads)))
             self.print_error('last timestamp:', datetime.fromtimestamp(self._get_last_timestamp()).ctime())
+            self._update_counts()
 
     @sql
     #@profiler
@@ -432,7 +422,7 @@ class ChannelDB(SqlDB):
         if not start_node_id or not short_channel_id: return None
         channel_info = self.get_channel_info(short_channel_id)
         if channel_info is not None:
-            return self.get_policy_for_node(channel_info, start_node_id)
+            return self.get_policy_for_node(short_channel_id, start_node_id)
         msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
         if not msg:
             return None
@@ -446,12 +436,12 @@ class ChannelDB(SqlDB):
 
     @sql
     def remove_channel(self, short_channel_id):
-        self._chan_query_for_id(short_channel_id).delete('evaluate')
+        r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
+        if not r:
+            return
+        self.DBSession.delete(r)
         self.DBSession.commit()
 
-    def _chan_query_for_id(self, short_channel_id) -> Query:
-        return self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex())
-
     def print_graph(self, full_ids=False):
         # used for debugging.
         # FIXME there is a race here - iterables could change size from another thread
@@ -488,23 +478,33 @@ class ChannelDB(SqlDB):
                              direction))
 
 
-    @sql
-    def get_policy_for_node(self, channel_info, node) -> Optional['Policy']:
-        """
-        raises when initiator/non-initiator both unequal node
-        """
-        if node.hex() not in (channel_info.node1_id, channel_info.node2_id):
-            raise Exception("the given node is not a party in this channel")
-        n1 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node1_id).one_or_none()
-        if n1:
-            return n1
-        n2 = self.DBSession.query(Policy).filter_by(short_channel_id = channel_info.short_channel_id, start_node = channel_info.node2_id).one_or_none()
-        return n2
-
     @sql
     def get_node_addresses(self, node_info):
         return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
 
+    @sql
+    @profiler
+    def load_data(self):
+        r = self.DBSession.query(ChannelInfo).all()
+        self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
+        r = self.DBSession.query(Policy).filter_by().all()
+        self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
+        self._channels_for_node = defaultdict(set)
+        for channel_info in self._channels.values():
+            self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id))
+            self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id))
+        self.print_error('load data', len(self._channels), len(self._policies), len(self._channels_for_node))
+
+    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
+        return self._policies.get((node_id, short_channel_id))
+
+    def get_channel_info(self, channel_id: bytes):
+        return self._channels.get(channel_id)
+
+    def get_channels_for_node(self, node_id):
+        """Returns the set of channels that have node_id as one of the endpoints."""
+        return self._channels_for_node.get(node_id)
+
 
 
 class RouteEdge(NamedTuple("RouteEdge", [('node_id', bytes),
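A small usage snippet, continuing the ChannelCache stand-in from the commit description above (all values hypothetical). It highlights that the policy dict is keyed (start_node, short_channel_id) while the getter takes (short_channel_id, node_id) and swaps its arguments into that key, matching the hunk above.

    scid_hex = '0000000000000001'                  # 8-byte short_channel_id, hex-encoded
    node1_hex, node2_hex = 'aa' * 33, 'bb' * 33    # 33-byte node ids, hex-encoded

    cache = ChannelCache(
        channel_rows=[ChannelRow(scid_hex, node1_hex, node2_hex)],
        policy_rows=[PolicyRow(node1_hex, scid_hex, 1000)],
    )
    cache.load_data()

    scid, node1 = bytes.fromhex(scid_hex), bytes.fromhex(node1_hex)
    # The getter takes (short_channel_id, node_id) but looks up (node_id, short_channel_id).
    assert cache.get_policy_for_node(scid, node1).fee_base_msat == 1000
    assert cache.get_channels_for_node(node1) == {scid}
    assert cache.get_channel_info(scid) is not None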
@@ -586,9 +586,9 @@ class LNPathFinder(PrintError):
         channel_info = self.channel_db.get_channel_info(short_channel_id) # type: ChannelInfo
         if channel_info is None:
             return float('inf'), 0
-
-        channel_policy = self.channel_db.get_policy_for_node(channel_info, start_node)
-        if channel_policy is None: return float('inf'), 0
+        channel_policy = self.channel_db.get_policy_for_node(short_channel_id, start_node)
+        if channel_policy is None:
+            return float('inf'), 0
         if channel_policy.is_disabled(): return float('inf'), 0
         route_edge = RouteEdge.from_channel_policy(channel_policy, short_channel_id, end_node)
         if payment_amt_msat < channel_policy.htlc_minimum_msat:
@@ -618,6 +618,7 @@
         To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
         i.e. an element reads as, "to get to node_id, travel through short_channel_id"
         """
+        self.channel_db.load_data()
         assert type(nodeA) is bytes
         assert type(nodeB) is bytes
         assert type(invoice_amount_msat) is int

@@ -611,6 +611,7 @@ class LNWorker(PrintError):
 
     def _calc_routing_hints_for_invoice(self, amount_sat):
         """calculate routing hints (BOLT-11 'r' field)"""
+        self.channel_db.load_data()
         routing_hints = []
         with self.lock:
             channels = list(self.channels.values())
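Taken together, the two new call sites mean the channel graph is snapshotted into memory once per search rather than queried row-by-row while the search runs. A rough sketch of that ordering follows; the PathFinderSketch class and its deliberately simplified unweighted BFS are illustrative assumptions standing in for Electrum's actual cost-based path search, and only the load_data()-before-lookup ordering mirrors the diff.

    from collections import deque
    from typing import Dict, List, Optional

    class PathFinderSketch:
        """Illustrative only: Electrum's LNPathFinder runs a fee-weighted search;
        this plain BFS just shows where load_data() sits relative to the per-edge lookups."""

        def __init__(self, channel_db):
            # Duck-typed: anything with load_data / get_channels_for_node /
            # get_channel_info works, e.g. the ChannelCache sketch above.
            self.channel_db = channel_db

        def find_path(self, node_a: bytes, node_b: bytes) -> Optional[List[bytes]]:
            # Key point of the commit: snapshot the graph once, up front, so every
            # lookup inside the search loop is a dict access rather than a SQL query.
            self.channel_db.load_data()
            prev: Dict[bytes, bytes] = {}
            seen = {node_a}
            queue = deque([node_a])
            while queue:
                node = queue.popleft()
                if node == node_b:
                    # Reconstruct node_a -> node_b from the predecessor map.
                    path = [node]
                    while path[-1] != node_a:
                        path.append(prev[path[-1]])
                    return path[::-1]
                for scid in self.channel_db.get_channels_for_node(node) or set():
                    chan = self.channel_db.get_channel_info(scid)
                    if chan is None:
                        continue
                    for peer_hex in (chan.node1_id, chan.node2_id):
                        peer = bytes.fromhex(peer_hex)
                        if peer not in seen:
                            seen.add(peer)
                            prev[peer] = node
                            queue.append(peer)
            return None

With the sample rows from the earlier snippet, PathFinderSketch(cache).find_path(node1, bytes.fromhex(node2_hex)) would return the two-node path [node1, node2].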