mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-30 08:51:32 +00:00
improve filter_channel_updates
blacklist channels that do not really get updated
This commit is contained in:
parent
f4b3d7627d
commit
eb4e6bb0de
2 changed files with 106 additions and 74 deletions
|
@ -241,12 +241,14 @@ class Peer(Logger):
|
|||
self.verify_node_announcements(node_anns)
|
||||
self.channel_db.on_node_announcement(node_anns)
|
||||
# channel updates
|
||||
good, bad = self.channel_db.filter_channel_updates(chan_upds)
|
||||
if bad:
|
||||
self.logger.info(f'adding {len(bad)} unknown channel ids')
|
||||
self.network.lngossip.add_new_ids(bad)
|
||||
self.verify_channel_updates(good)
|
||||
self.channel_db.on_channel_update(good)
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
|
||||
if orphaned:
|
||||
self.logger.info(f'adding {len(orphaned)} unknown channel ids')
|
||||
self.network.lngossip.add_new_ids(orphaned)
|
||||
if good:
|
||||
self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}')
|
||||
self.verify_channel_updates(good)
|
||||
self.channel_db.update_policies(good, to_delete)
|
||||
# refresh gui
|
||||
if chan_anns or node_anns or chan_upds:
|
||||
self.network.lngossip.refresh_gui()
|
||||
|
@ -273,7 +275,7 @@ class Peer(Logger):
|
|||
short_channel_id = payload['short_channel_id']
|
||||
if constants.net.rev_genesis_bytes() != payload['chain_hash']:
|
||||
raise Exception('wrong chain hash')
|
||||
if not verify_sig_for_channel_update(payload, payload['node_id']):
|
||||
if not verify_sig_for_channel_update(payload, payload['start_node']):
|
||||
raise BaseException('verify error')
|
||||
|
||||
@log_exceptions
|
||||
|
@ -990,21 +992,29 @@ class Peer(Logger):
|
|||
OnionFailureCode.EXPIRY_TOO_SOON: 2,
|
||||
OnionFailureCode.CHANNEL_DISABLED: 4,
|
||||
}
|
||||
offset = failure_codes.get(code)
|
||||
if offset:
|
||||
if code in failure_codes:
|
||||
offset = failure_codes[code]
|
||||
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
|
||||
message_type, payload = decode_msg(channel_update)
|
||||
payload['raw'] = channel_update
|
||||
try:
|
||||
self.logger.info(f"trying to apply channel update on our db {payload}")
|
||||
self.channel_db.add_channel_update(payload)
|
||||
self.logger.info("successfully applied channel update on our db")
|
||||
except NotFoundChanAnnouncementForUpdate:
|
||||
orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
|
||||
if good:
|
||||
self.verify_channel_updates(good)
|
||||
self.channel_db.update_policies(good, to_delete)
|
||||
self.logger.info("applied channel update on our db")
|
||||
elif orphaned:
|
||||
# maybe it is a private channel (and data in invoice was outdated)
|
||||
self.logger.info("maybe channel update is for private channel?")
|
||||
start_node_id = route[sender_idx].node_id
|
||||
self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
|
||||
elif expired:
|
||||
blacklist = True
|
||||
elif deprecated:
|
||||
self.logger.info(f'channel update is not more recent. blacklisting channel')
|
||||
blacklist = True
|
||||
else:
|
||||
blacklist = True
|
||||
if blacklist:
|
||||
# blacklist channel after reporter node
|
||||
# TODO this should depend on the error (even more granularity)
|
||||
# also, we need finer blacklisting (directed edges; nodes)
|
||||
|
|
|
@ -114,22 +114,16 @@ class Policy(Base):
|
|||
timestamp = Column(Integer, nullable=False)
|
||||
|
||||
@staticmethod
|
||||
def from_msg(payload, start_node, short_channel_id):
|
||||
cltv_expiry_delta = payload['cltv_expiry_delta']
|
||||
htlc_minimum_msat = payload['htlc_minimum_msat']
|
||||
fee_base_msat = payload['fee_base_msat']
|
||||
fee_proportional_millionths = payload['fee_proportional_millionths']
|
||||
channel_flags = payload['channel_flags']
|
||||
timestamp = payload['timestamp']
|
||||
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional
|
||||
|
||||
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
|
||||
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
|
||||
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
|
||||
fee_base_msat = int.from_bytes(fee_base_msat, "big")
|
||||
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
|
||||
channel_flags = int.from_bytes(channel_flags, "big")
|
||||
timestamp = int.from_bytes(timestamp, "big")
|
||||
def from_msg(payload):
|
||||
cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
|
||||
htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
|
||||
htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
|
||||
fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
|
||||
fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
|
||||
channel_flags = int.from_bytes(payload['channel_flags'], "big")
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
start_node = payload['start_node'].hex()
|
||||
short_channel_id = payload['short_channel_id'].hex()
|
||||
|
||||
return Policy(start_node=start_node,
|
||||
short_channel_id=short_channel_id,
|
||||
|
@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
|
|||
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
|
||||
return r.max_timestamp or 0
|
||||
|
||||
def print_change(self, old_policy, new_policy):
|
||||
# print what changed between policies
|
||||
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
|
||||
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
|
||||
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
|
||||
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
|
||||
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
|
||||
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
|
||||
if old_policy.fee_base_msat != new_policy.fee_base_msat:
|
||||
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
|
||||
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
|
||||
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
|
||||
if old_policy.channel_flags != new_policy.channel_flags:
|
||||
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
|
||||
|
||||
@sql
|
||||
def get_info_for_updates(self, msg_payloads):
|
||||
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
|
||||
def get_info_for_updates(self, payloads):
|
||||
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
|
||||
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
|
||||
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
|
||||
return channel_infos
|
||||
|
||||
@sql
|
||||
def get_policies_for_updates(self, payloads):
|
||||
out = {}
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id'].hex()
|
||||
start_node = payload['start_node'].hex()
|
||||
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
|
||||
if policy:
|
||||
out[short_channel_id+start_node] = policy
|
||||
return out
|
||||
|
||||
@profiler
|
||||
def filter_channel_updates(self, payloads):
|
||||
# add 'node_id' to payload
|
||||
channel_infos = self.get_info_for_updates(payloads)
|
||||
def filter_channel_updates(self, payloads, max_age=None):
|
||||
orphaned = [] # no channel announcement for channel update
|
||||
expired = [] # update older than two weeks
|
||||
deprecated = [] # update older than database entry
|
||||
good = [] # good updates
|
||||
to_delete = [] # database entries to delete
|
||||
# filter orphaned and expired first
|
||||
known = []
|
||||
unknown = []
|
||||
now = int(time.time())
|
||||
channel_infos = self.get_info_for_updates(payloads)
|
||||
for payload in payloads:
|
||||
short_channel_id = payload['short_channel_id']
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
if max_age and now - timestamp > max_age:
|
||||
expired.append(short_channel_id)
|
||||
continue
|
||||
channel_info = channel_infos.get(short_channel_id)
|
||||
if not channel_info:
|
||||
unknown.append(short_channel_id)
|
||||
orphaned.append(short_channel_id)
|
||||
continue
|
||||
flags = int.from_bytes(payload['channel_flags'], 'big')
|
||||
direction = flags & FLAG_DIRECTION
|
||||
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
|
||||
payload['node_id'] = node_id
|
||||
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
|
||||
payload['start_node'] = bfh(start_node)
|
||||
known.append(payload)
|
||||
return known, unknown
|
||||
# compare updates to existing database entries
|
||||
old_policies = self.get_policies_for_updates(known)
|
||||
for payload in known:
|
||||
timestamp = int.from_bytes(payload['timestamp'], "big")
|
||||
start_node = payload['start_node'].hex()
|
||||
short_channel_id = payload['short_channel_id'].hex()
|
||||
old_policy = old_policies.get(short_channel_id+start_node)
|
||||
if old_policy:
|
||||
if timestamp <= old_policy.timestamp:
|
||||
deprecated.append(short_channel_id)
|
||||
else:
|
||||
good.append(payload)
|
||||
to_delete.append(old_policy)
|
||||
else:
|
||||
good.append(payload)
|
||||
return orphaned, expired, deprecated, good, to_delete
|
||||
|
||||
def add_channel_update(self, payload):
|
||||
# called in tests/test_lnrouter
|
||||
good, bad = self.filter_channel_updates([payload])
|
||||
assert len(bad) == 0
|
||||
self.on_channel_update(good)
|
||||
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
|
||||
assert len(good) == 1
|
||||
self.update_policies(good, to_delete)
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def on_channel_update(self, msg_payloads):
|
||||
now = int(time.time())
|
||||
if type(msg_payloads) is dict:
|
||||
msg_payloads = [msg_payloads]
|
||||
new_policies = {}
|
||||
for msg_payload in msg_payloads:
|
||||
short_channel_id = msg_payload['short_channel_id'].hex()
|
||||
node_id = msg_payload['node_id'].hex()
|
||||
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
|
||||
# must not be older than two weeks
|
||||
if new_policy.timestamp < now - 14*24*3600:
|
||||
continue
|
||||
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
|
||||
if old_policy:
|
||||
if old_policy.timestamp >= new_policy.timestamp:
|
||||
continue
|
||||
self.DBSession.delete(old_policy)
|
||||
p = new_policies.get((short_channel_id, node_id))
|
||||
if p and p.timestamp >= new_policy.timestamp:
|
||||
continue
|
||||
new_policies[(short_channel_id, node_id)] = new_policy
|
||||
# commit pending removals
|
||||
def update_policies(self, to_add, to_delete):
|
||||
for policy in to_delete:
|
||||
self.DBSession.delete(policy)
|
||||
self.DBSession.commit()
|
||||
# add and commit new policies
|
||||
for new_policy in new_policies.values():
|
||||
self.DBSession.add(new_policy)
|
||||
for payload in to_add:
|
||||
policy = Policy.from_msg(payload)
|
||||
self.DBSession.add(policy)
|
||||
self.DBSession.commit()
|
||||
if new_policies:
|
||||
self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
|
||||
#self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
|
||||
self._update_counts()
|
||||
self._update_counts()
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
|
@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
|
|||
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
|
||||
if not msg:
|
||||
return None
|
||||
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB
|
||||
return Policy.from_msg(msg) # won't actually be written to DB
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
|
@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
|
|||
if not verify_sig_for_channel_update(msg_payload, start_node_id):
|
||||
return # ignore
|
||||
short_channel_id = msg_payload['short_channel_id']
|
||||
msg_payload['start_node'] = start_node_id
|
||||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||
|
||||
@sql
|
||||
|
|
Loading…
Add table
Reference in a new issue