improve filter_channel_updates

blacklist channels that do not really get updated

parent f4b3d7627d
commit eb4e6bb0de

2 changed files with 106 additions and 74 deletions

@@ -241,12 +241,14 @@ class Peer(Logger):
         self.verify_node_announcements(node_anns)
         self.channel_db.on_node_announcement(node_anns)
         # channel updates
-        good, bad = self.channel_db.filter_channel_updates(chan_upds)
-        if bad:
-            self.logger.info(f'adding {len(bad)} unknown channel ids')
-            self.network.lngossip.add_new_ids(bad)
-        self.verify_channel_updates(good)
-        self.channel_db.on_channel_update(good)
+        orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
+        if orphaned:
+            self.logger.info(f'adding {len(orphaned)} unknown channel ids')
+            self.network.lngossip.add_new_ids(orphaned)
+        if good:
+            self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}')
+            self.verify_channel_updates(good)
+            self.channel_db.update_policies(good, to_delete)
         # refresh gui
         if chan_anns or node_anns or chan_upds:
             self.network.lngossip.refresh_gui()
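
For context, the interesting part of this hunk is the new max_age cutoff threaded in from the gossip loop. A rough, runnable sketch of the age test (MAX_AGE and is_expired are illustrative names, not from this commit; the two-week value mirrors the horizon the old on_channel_update hard-coded below):

import time

MAX_AGE = 14 * 24 * 3600  # two weeks

def is_expired(update_timestamp, max_age=MAX_AGE, now=None):
    # an update is "expired" once it falls outside the allowed horizon
    now = int(time.time()) if now is None else now
    return now - update_timestamp > max_age

print(is_expired(int(time.time()) - 15 * 24 * 3600))  # True: older than two weeks
print(is_expired(int(time.time()) - 3600))            # False: one hour old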

@@ -273,7 +275,7 @@ class Peer(Logger):
         short_channel_id = payload['short_channel_id']
         if constants.net.rev_genesis_bytes() != payload['chain_hash']:
             raise Exception('wrong chain hash')
-        if not verify_sig_for_channel_update(payload, payload['node_id']):
+        if not verify_sig_for_channel_update(payload, payload['start_node']):
             raise BaseException('verify error')

     @log_exceptions

@@ -990,21 +992,29 @@ class Peer(Logger):
             OnionFailureCode.EXPIRY_TOO_SOON: 2,
             OnionFailureCode.CHANNEL_DISABLED: 4,
         }
-        offset = failure_codes.get(code)
-        if offset:
+        if code in failure_codes:
+            offset = failure_codes[code]
             channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
             message_type, payload = decode_msg(channel_update)
             payload['raw'] = channel_update
-            try:
-                self.logger.info(f"trying to apply channel update on our db {payload}")
-                self.channel_db.add_channel_update(payload)
-                self.logger.info("successfully applied channel update on our db")
-            except NotFoundChanAnnouncementForUpdate:
+            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
+            if good:
+                self.verify_channel_updates(good)
+                self.channel_db.update_policies(good, to_delete)
+                self.logger.info("applied channel update on our db")
+            elif orphaned:
                 # maybe it is a private channel (and data in invoice was outdated)
                 self.logger.info("maybe channel update is for private channel?")
                 start_node_id = route[sender_idx].node_id
                 self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
+            elif expired:
+                blacklist = True
+            elif deprecated:
+                self.logger.info(f'channel update is not more recent. blacklisting channel')
+                blacklist = True
         else:
+            blacklist = True
+        if blacklist:
             # blacklist channel after reporter node
             # TODO this should depend on the error (even more granularity)
             # also, we need finer blacklisting (directed edges; nodes)
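
Taken together, the new branches blacklist the reporting channel unless the update was actually applied or plausibly belongs to a private channel. A condensed sketch of that decision as a pure function (a hypothetical helper, not part of the diff):

def should_blacklist(good, orphaned, expired, deprecated):
    if good:
        return False  # a strictly newer update was applied to our db
    if orphaned:
        return False  # maybe a private channel; handled by the elif branch above
    # expired, deprecated, or unparseable: the peer gave us nothing newer,
    # so stop routing through this channel
    return True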

@@ -114,22 +114,16 @@ class Policy(Base):
     timestamp = Column(Integer, nullable=False)

     @staticmethod
-    def from_msg(payload, start_node, short_channel_id):
-        cltv_expiry_delta = payload['cltv_expiry_delta']
-        htlc_minimum_msat = payload['htlc_minimum_msat']
-        fee_base_msat = payload['fee_base_msat']
-        fee_proportional_millionths = payload['fee_proportional_millionths']
-        channel_flags = payload['channel_flags']
-        timestamp = payload['timestamp']
-        htlc_maximum_msat = payload.get('htlc_maximum_msat')  # optional
-
-        cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big")
-        htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
-        htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
-        fee_base_msat = int.from_bytes(fee_base_msat, "big")
-        fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
-        channel_flags = int.from_bytes(channel_flags, "big")
-        timestamp = int.from_bytes(timestamp, "big")
+    def from_msg(payload):
+        cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
+        htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
+        htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
+        fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
+        fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
+        channel_flags = int.from_bytes(payload['channel_flags'], "big")
+        timestamp = int.from_bytes(payload['timestamp'], "big")
+        start_node = payload['start_node'].hex()
+        short_channel_id = payload['short_channel_id'].hex()

         return Policy(start_node=start_node,
                       short_channel_id=short_channel_id,
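
from_msg now decodes the big-endian wire fields itself and reads start_node and short_channel_id out of the payload instead of taking them as arguments. A standalone sketch of that decoding, with made-up sample values:

# BOLT-7 channel_update fields arrive as big-endian byte strings
payload = {
    'cltv_expiry_delta': (144).to_bytes(2, "big"),
    'fee_base_msat': (1000).to_bytes(4, "big"),
    'timestamp': (1550000000).to_bytes(4, "big"),
    'short_channel_id': bytes(8),
}
print(int.from_bytes(payload['cltv_expiry_delta'], "big"))  # 144
print(int.from_bytes(payload['fee_base_msat'], "big"))      # 1000
print(int.from_bytes(payload['timestamp'], "big"))          # 1550000000
print(payload['short_channel_id'].hex())                    # 0000000000000000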

@@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
         r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
         return r.max_timestamp or 0

+    def print_change(self, old_policy, new_policy):
+        # print what changed between policies
+        if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
+            self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
+        if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
+            self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
+        if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
+            self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
+        if old_policy.fee_base_msat != new_policy.fee_base_msat:
+            self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
+        if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
+            self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
+        if old_policy.channel_flags != new_policy.channel_flags:
+            self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
+
     @sql
-    def get_info_for_updates(self, msg_payloads):
-        short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads]
+    def get_info_for_updates(self, payloads):
+        short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
         channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
         channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
         return channel_infos

+    @sql
+    def get_policies_for_updates(self, payloads):
+        out = {}
+        for payload in payloads:
+            short_channel_id = payload['short_channel_id'].hex()
+            start_node = payload['start_node'].hex()
+            policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
+            if policy:
+                out[short_channel_id+start_node] = policy
+        return out
+
     @profiler
-    def filter_channel_updates(self, payloads):
-        # add 'node_id' to payload
-        channel_infos = self.get_info_for_updates(payloads)
+    def filter_channel_updates(self, payloads, max_age=None):
+        orphaned = []  # no channel announcement for channel update
+        expired = []  # update older than two weeks
+        deprecated = []  # update older than database entry
+        good = []  # good updates
+        to_delete = []  # database entries to delete
+        # filter orphaned and expired first
         known = []
-        unknown = []
+        now = int(time.time())
+        channel_infos = self.get_info_for_updates(payloads)
         for payload in payloads:
             short_channel_id = payload['short_channel_id']
+            timestamp = int.from_bytes(payload['timestamp'], "big")
+            if max_age and now - timestamp > max_age:
+                expired.append(short_channel_id)
+                continue
             channel_info = channel_infos.get(short_channel_id)
             if not channel_info:
-                unknown.append(short_channel_id)
+                orphaned.append(short_channel_id)
                 continue
             flags = int.from_bytes(payload['channel_flags'], 'big')
             direction = flags & FLAG_DIRECTION
-            node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id)
-            payload['node_id'] = node_id
+            start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
+            payload['start_node'] = bfh(start_node)
             known.append(payload)
-        return known, unknown
+        # compare updates to existing database entries
+        old_policies = self.get_policies_for_updates(known)
+        for payload in known:
+            timestamp = int.from_bytes(payload['timestamp'], "big")
+            start_node = payload['start_node'].hex()
+            short_channel_id = payload['short_channel_id'].hex()
+            old_policy = old_policies.get(short_channel_id+start_node)
+            if old_policy:
+                if timestamp <= old_policy.timestamp:
+                    deprecated.append(short_channel_id)
+                else:
+                    good.append(payload)
+                    to_delete.append(old_policy)
+            else:
+                good.append(payload)
+        return orphaned, expired, deprecated, good, to_delete

     def add_channel_update(self, payload):
-        # called in tests/test_lnrouter
-        good, bad = self.filter_channel_updates([payload])
-        assert len(bad) == 0
-        self.on_channel_update(good)
+        orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
+        assert len(good) == 1
+        self.update_policies(good, to_delete)

     @sql
     @profiler
-    def on_channel_update(self, msg_payloads):
-        now = int(time.time())
-        if type(msg_payloads) is dict:
-            msg_payloads = [msg_payloads]
-        new_policies = {}
-        for msg_payload in msg_payloads:
-            short_channel_id = msg_payload['short_channel_id'].hex()
-            node_id = msg_payload['node_id'].hex()
-            new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
-            # must not be older than two weeks
-            if new_policy.timestamp < now - 14*24*3600:
-                continue
-            old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
-            if old_policy:
-                if old_policy.timestamp >= new_policy.timestamp:
-                    continue
-                self.DBSession.delete(old_policy)
-            p = new_policies.get((short_channel_id, node_id))
-            if p and p.timestamp >= new_policy.timestamp:
-                continue
-            new_policies[(short_channel_id, node_id)] = new_policy
-        # commit pending removals
+    def update_policies(self, to_add, to_delete):
+        for policy in to_delete:
+            self.DBSession.delete(policy)
         self.DBSession.commit()
-        # add and commit new policies
-        for new_policy in new_policies.values():
-            self.DBSession.add(new_policy)
+        for payload in to_add:
+            policy = Policy.from_msg(payload)
+            self.DBSession.add(policy)
         self.DBSession.commit()
-        if new_policies:
-            self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
-            #self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
-            self._update_counts()
+        self._update_counts()

     @sql
     @profiler
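
This five-way split is the heart of the commit: updates are classified before any database write, and the caller decides what to do with each bucket. A database-free sketch of the same classification, using plain dicts keyed only by short_channel_id (the real code keys policies by short_channel_id plus start_node; all names and sample values here are illustrative):

import time

def classify(payloads, channel_infos, old_timestamps, max_age=None):
    orphaned, expired, deprecated, good, to_delete = [], [], [], [], []
    now = int(time.time())
    for p in payloads:
        scid = p['short_channel_id']
        ts = p['timestamp']
        if max_age and now - ts > max_age:
            expired.append(scid)       # too old to be useful
            continue
        if scid not in channel_infos:
            orphaned.append(scid)      # no channel announcement seen yet
            continue
        old_ts = old_timestamps.get(scid)
        if old_ts is not None and ts <= old_ts:
            deprecated.append(scid)    # not newer than what we store
            continue
        good.append(p)                 # apply this update
        if old_ts is not None:
            to_delete.append(scid)     # stale entry to replace
    return orphaned, expired, deprecated, good, to_delete

now = int(time.time())
updates = [
    {'short_channel_id': 'a', 'timestamp': now - 20 * 24 * 3600},  # too old
    {'short_channel_id': 'b', 'timestamp': now},                   # unknown channel
    {'short_channel_id': 'c', 'timestamp': now - 60},              # older than stored
    {'short_channel_id': 'd', 'timestamp': now},                   # fresh and newer
]
print(classify(updates, channel_infos={'c', 'd'},
               old_timestamps={'c': now, 'd': now - 3600},
               max_age=14 * 24 * 3600))
# -> (['b'], ['a'], ['c'], [<payload d>], ['d'])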

@@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
         msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
         if not msg:
             return None
-        return Policy.from_msg(msg, None, short_channel_id)  # won't actually be written to DB
+        return Policy.from_msg(msg)  # won't actually be written to DB

     @sql
     @profiler

@@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
         if not verify_sig_for_channel_update(msg_payload, start_node_id):
             return  # ignore
         short_channel_id = msg_payload['short_channel_id']
+        msg_payload['start_node'] = start_node_id
         self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

     @sql
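
The added 'start_node' key is what makes the one-argument Policy.from_msg(msg) call in the previous hunk work for private channels too: the stashed payload now carries everything the constructor reads. A minimal sketch of the cache shape (module-level dict, purely illustrative):

# illustration only: the private-update cache maps
# (start_node_id, short_channel_id) -> channel_update payload
_channel_updates_for_private_channels = {}

def stash_private_update(msg_payload, start_node_id):
    short_channel_id = msg_payload['short_channel_id']
    msg_payload['start_node'] = start_node_id  # lets Policy.from_msg(msg) find it
    _channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

stash_private_update({'short_channel_id': bytes(8)}, b'\x02' * 33)
print(len(_channel_updates_for_private_channels))  # 1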