improve filter_channel_updates

blacklist channels that do not really get updated
This commit is contained in:
ThomasV 2019-05-16 19:00:44 +02:00
parent f4b3d7627d
commit eb4e6bb0de
2 changed files with 106 additions and 74 deletions

View file

@ -241,12 +241,14 @@ class Peer(Logger):
self.verify_node_announcements(node_anns) self.verify_node_announcements(node_anns)
self.channel_db.on_node_announcement(node_anns) self.channel_db.on_node_announcement(node_anns)
# channel updates # channel updates
good, bad = self.channel_db.filter_channel_updates(chan_upds) orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds, max_age=self.network.lngossip.max_age)
if bad: if orphaned:
self.logger.info(f'adding {len(bad)} unknown channel ids') self.logger.info(f'adding {len(orphaned)} unknown channel ids')
self.network.lngossip.add_new_ids(bad) self.network.lngossip.add_new_ids(orphaned)
self.verify_channel_updates(good) if good:
self.channel_db.on_channel_update(good) self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds)}')
self.verify_channel_updates(good)
self.channel_db.update_policies(good, to_delete)
# refresh gui # refresh gui
if chan_anns or node_anns or chan_upds: if chan_anns or node_anns or chan_upds:
self.network.lngossip.refresh_gui() self.network.lngossip.refresh_gui()
@ -273,7 +275,7 @@ class Peer(Logger):
short_channel_id = payload['short_channel_id'] short_channel_id = payload['short_channel_id']
if constants.net.rev_genesis_bytes() != payload['chain_hash']: if constants.net.rev_genesis_bytes() != payload['chain_hash']:
raise Exception('wrong chain hash') raise Exception('wrong chain hash')
if not verify_sig_for_channel_update(payload, payload['node_id']): if not verify_sig_for_channel_update(payload, payload['start_node']):
raise BaseException('verify error') raise BaseException('verify error')
@log_exceptions @log_exceptions
@ -990,21 +992,29 @@ class Peer(Logger):
OnionFailureCode.EXPIRY_TOO_SOON: 2, OnionFailureCode.EXPIRY_TOO_SOON: 2,
OnionFailureCode.CHANNEL_DISABLED: 4, OnionFailureCode.CHANNEL_DISABLED: 4,
} }
offset = failure_codes.get(code) if code in failure_codes:
if offset: offset = failure_codes[code]
channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:] channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
message_type, payload = decode_msg(channel_update) message_type, payload = decode_msg(channel_update)
payload['raw'] = channel_update payload['raw'] = channel_update
try: orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
self.logger.info(f"trying to apply channel update on our db {payload}") if good:
self.channel_db.add_channel_update(payload) self.verify_channel_updates(good)
self.logger.info("successfully applied channel update on our db") self.channel_db.update_policies(good, to_delete)
except NotFoundChanAnnouncementForUpdate: self.logger.info("applied channel update on our db")
elif orphaned:
# maybe it is a private channel (and data in invoice was outdated) # maybe it is a private channel (and data in invoice was outdated)
self.logger.info("maybe channel update is for private channel?") self.logger.info("maybe channel update is for private channel?")
start_node_id = route[sender_idx].node_id start_node_id = route[sender_idx].node_id
self.channel_db.add_channel_update_for_private_channel(payload, start_node_id) self.channel_db.add_channel_update_for_private_channel(payload, start_node_id)
elif expired:
blacklist = True
elif deprecated:
self.logger.info(f'channel update is not more recent. blacklisting channel')
blacklist = True
else: else:
blacklist = True
if blacklist:
# blacklist channel after reporter node # blacklist channel after reporter node
# TODO this should depend on the error (even more granularity) # TODO this should depend on the error (even more granularity)
# also, we need finer blacklisting (directed edges; nodes) # also, we need finer blacklisting (directed edges; nodes)

View file

@ -114,22 +114,16 @@ class Policy(Base):
timestamp = Column(Integer, nullable=False) timestamp = Column(Integer, nullable=False)
@staticmethod @staticmethod
def from_msg(payload, start_node, short_channel_id): def from_msg(payload):
cltv_expiry_delta = payload['cltv_expiry_delta'] cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
htlc_minimum_msat = payload['htlc_minimum_msat'] htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
fee_base_msat = payload['fee_base_msat'] htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
fee_proportional_millionths = payload['fee_proportional_millionths'] fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
channel_flags = payload['channel_flags'] fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
timestamp = payload['timestamp'] channel_flags = int.from_bytes(payload['channel_flags'], "big")
htlc_maximum_msat = payload.get('htlc_maximum_msat') # optional timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'].hex()
cltv_expiry_delta = int.from_bytes(cltv_expiry_delta, "big") short_channel_id = payload['short_channel_id'].hex()
htlc_minimum_msat = int.from_bytes(htlc_minimum_msat, "big")
htlc_maximum_msat = int.from_bytes(htlc_maximum_msat, "big") if htlc_maximum_msat else None
fee_base_msat = int.from_bytes(fee_base_msat, "big")
fee_proportional_millionths = int.from_bytes(fee_proportional_millionths, "big")
channel_flags = int.from_bytes(channel_flags, "big")
timestamp = int.from_bytes(timestamp, "big")
return Policy(start_node=start_node, return Policy(start_node=start_node,
short_channel_id=short_channel_id, short_channel_id=short_channel_id,
@ -341,71 +335,98 @@ class ChannelDB(SqlDB):
r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one() r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
return r.max_timestamp or 0 return r.max_timestamp or 0
def print_change(self, old_policy, new_policy):
# print what changed between policies
if old_policy.cltv_expiry_delta != new_policy.cltv_expiry_delta:
self.logger.info(f'cltv_expiry_delta: {old_policy.cltv_expiry_delta} -> {new_policy.cltv_expiry_delta}')
if old_policy.htlc_minimum_msat != new_policy.htlc_minimum_msat:
self.logger.info(f'htlc_minimum_msat: {old_policy.htlc_minimum_msat} -> {new_policy.htlc_minimum_msat}')
if old_policy.htlc_maximum_msat != new_policy.htlc_maximum_msat:
self.logger.info(f'htlc_maximum_msat: {old_policy.htlc_maximum_msat} -> {new_policy.htlc_maximum_msat}')
if old_policy.fee_base_msat != new_policy.fee_base_msat:
self.logger.info(f'fee_base_msat: {old_policy.fee_base_msat} -> {new_policy.fee_base_msat}')
if old_policy.fee_proportional_millionths != new_policy.fee_proportional_millionths:
self.logger.info(f'fee_proportional_millionths: {old_policy.fee_proportional_millionths} -> {new_policy.fee_proportional_millionths}')
if old_policy.channel_flags != new_policy.channel_flags:
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
@sql @sql
def get_info_for_updates(self, msg_payloads): def get_info_for_updates(self, payloads):
short_channel_ids = [msg_payload['short_channel_id'].hex() for msg_payload in msg_payloads] short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
return channel_infos return channel_infos
@sql
def get_policies_for_updates(self, payloads):
out = {}
for payload in payloads:
short_channel_id = payload['short_channel_id'].hex()
start_node = payload['start_node'].hex()
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
if policy:
out[short_channel_id+start_node] = policy
return out
@profiler @profiler
def filter_channel_updates(self, payloads): def filter_channel_updates(self, payloads, max_age=None):
# add 'node_id' to payload orphaned = [] # no channel announcement for channel update
channel_infos = self.get_info_for_updates(payloads) expired = [] # update older than two weeks
deprecated = [] # update older than database entry
good = [] # good updates
to_delete = [] # database entries to delete
# filter orphaned and expired first
known = [] known = []
unknown = [] now = int(time.time())
channel_infos = self.get_info_for_updates(payloads)
for payload in payloads: for payload in payloads:
short_channel_id = payload['short_channel_id'] short_channel_id = payload['short_channel_id']
timestamp = int.from_bytes(payload['timestamp'], "big")
if max_age and now - timestamp > max_age:
expired.append(short_channel_id)
continue
channel_info = channel_infos.get(short_channel_id) channel_info = channel_infos.get(short_channel_id)
if not channel_info: if not channel_info:
unknown.append(short_channel_id) orphaned.append(short_channel_id)
continue continue
flags = int.from_bytes(payload['channel_flags'], 'big') flags = int.from_bytes(payload['channel_flags'], 'big')
direction = flags & FLAG_DIRECTION direction = flags & FLAG_DIRECTION
node_id = bfh(channel_info.node1_id if direction == 0 else channel_info.node2_id) start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
payload['node_id'] = node_id payload['start_node'] = bfh(start_node)
known.append(payload) known.append(payload)
return known, unknown # compare updates to existing database entries
old_policies = self.get_policies_for_updates(known)
for payload in known:
timestamp = int.from_bytes(payload['timestamp'], "big")
start_node = payload['start_node'].hex()
short_channel_id = payload['short_channel_id'].hex()
old_policy = old_policies.get(short_channel_id+start_node)
if old_policy:
if timestamp <= old_policy.timestamp:
deprecated.append(short_channel_id)
else:
good.append(payload)
to_delete.append(old_policy)
else:
good.append(payload)
return orphaned, expired, deprecated, good, to_delete
def add_channel_update(self, payload): def add_channel_update(self, payload):
# called in tests/test_lnrouter orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
good, bad = self.filter_channel_updates([payload]) assert len(good) == 1
assert len(bad) == 0 self.update_policies(good, to_delete)
self.on_channel_update(good)
@sql @sql
@profiler @profiler
def on_channel_update(self, msg_payloads): def update_policies(self, to_add, to_delete):
now = int(time.time()) for policy in to_delete:
if type(msg_payloads) is dict: self.DBSession.delete(policy)
msg_payloads = [msg_payloads]
new_policies = {}
for msg_payload in msg_payloads:
short_channel_id = msg_payload['short_channel_id'].hex()
node_id = msg_payload['node_id'].hex()
new_policy = Policy.from_msg(msg_payload, node_id, short_channel_id)
# must not be older than two weeks
if new_policy.timestamp < now - 14*24*3600:
continue
old_policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=node_id).one_or_none()
if old_policy:
if old_policy.timestamp >= new_policy.timestamp:
continue
self.DBSession.delete(old_policy)
p = new_policies.get((short_channel_id, node_id))
if p and p.timestamp >= new_policy.timestamp:
continue
new_policies[(short_channel_id, node_id)] = new_policy
# commit pending removals
self.DBSession.commit() self.DBSession.commit()
# add and commit new policies for payload in to_add:
for new_policy in new_policies.values(): policy = Policy.from_msg(payload)
self.DBSession.add(new_policy) self.DBSession.add(policy)
self.DBSession.commit() self.DBSession.commit()
if new_policies: self._update_counts()
self.logger.debug(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
#self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
self._update_counts()
@sql @sql
@profiler @profiler
@ -454,7 +475,7 @@ class ChannelDB(SqlDB):
msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id)) msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
if not msg: if not msg:
return None return None
return Policy.from_msg(msg, None, short_channel_id) # won't actually be written to DB return Policy.from_msg(msg) # won't actually be written to DB
@sql @sql
@profiler @profiler
@ -496,6 +517,7 @@ class ChannelDB(SqlDB):
if not verify_sig_for_channel_update(msg_payload, start_node_id): if not verify_sig_for_channel_update(msg_payload, start_node_id):
return # ignore return # ignore
short_channel_id = msg_payload['short_channel_id'] short_channel_id = msg_payload['short_channel_id']
msg_payload['start_node'] = start_node_id
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
@sql @sql