mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 17:31:36 +00:00
LNGossip: sync channel db using query_channel_range
This commit is contained in:
parent
95376226e8
commit
1011245c5e
3 changed files with 168 additions and 59 deletions
|
@ -57,9 +57,7 @@ class Peer(Logger):
|
|||
|
||||
def __init__(self, lnworker: 'LNWorker', pubkey:bytes, transport: LNTransportBase):
|
||||
self.initialized = asyncio.Event()
|
||||
self.node_anns = []
|
||||
self.chan_anns = []
|
||||
self.chan_upds = []
|
||||
self.querying_lock = asyncio.Lock()
|
||||
self.transport = transport
|
||||
self.pubkey = pubkey
|
||||
self.lnworker = lnworker
|
||||
|
@ -70,6 +68,7 @@ class Peer(Logger):
|
|||
self.lnwatcher = lnworker.network.lnwatcher
|
||||
self.channel_db = lnworker.network.channel_db
|
||||
self.ping_time = 0
|
||||
self.reply_channel_range = asyncio.Queue()
|
||||
self.shutdown_received = defaultdict(asyncio.Future)
|
||||
self.channel_accepted = defaultdict(asyncio.Queue)
|
||||
self.channel_reestablished = defaultdict(asyncio.Future)
|
||||
|
@ -89,7 +88,7 @@ class Peer(Logger):
|
|||
|
||||
def send_message(self, message_name: str, **kwargs):
|
||||
assert type(message_name) is str
|
||||
self.logger.info(f"Sending {message_name.upper()}")
|
||||
self.logger.debug(f"Sending {message_name.upper()}")
|
||||
self.transport.send_bytes(encode_msg(message_name, **kwargs))
|
||||
|
||||
async def initialize(self):
|
||||
|
@ -177,13 +176,13 @@ class Peer(Logger):
|
|||
self.initialized.set()
|
||||
|
||||
def on_node_announcement(self, payload):
|
||||
self.node_anns.append(payload)
|
||||
self.channel_db.node_anns.append(payload)
|
||||
|
||||
def on_channel_update(self, payload):
|
||||
self.chan_upds.append(payload)
|
||||
self.channel_db.chan_upds.append(payload)
|
||||
|
||||
def on_channel_announcement(self, payload):
|
||||
self.chan_anns.append(payload)
|
||||
self.channel_db.chan_anns.append(payload)
|
||||
|
||||
def on_announcement_signatures(self, payload):
|
||||
channel_id = payload['channel_id']
|
||||
|
@ -207,15 +206,11 @@ class Peer(Logger):
|
|||
@handle_disconnect
|
||||
async def main_loop(self):
|
||||
async with aiorpcx.TaskGroup() as group:
|
||||
await group.spawn(self._gossip_loop())
|
||||
await group.spawn(self._message_loop())
|
||||
# kill group if the peer times out
|
||||
await group.spawn(asyncio.wait_for(self.initialized.wait(), 10))
|
||||
|
||||
@log_exceptions
|
||||
async def _gossip_loop(self):
|
||||
await self.initialized.wait()
|
||||
timestamp = self.channel_db.get_last_timestamp()
|
||||
def request_gossip(self, timestamp=0):
|
||||
if timestamp == 0:
|
||||
self.logger.info('requesting whole channel graph')
|
||||
else:
|
||||
|
@ -225,28 +220,47 @@ class Peer(Logger):
|
|||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
first_timestamp=timestamp,
|
||||
timestamp_range=b'\xff'*4)
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
if self.node_anns:
|
||||
self.channel_db.on_node_announcement(self.node_anns)
|
||||
self.node_anns = []
|
||||
if self.chan_anns:
|
||||
self.channel_db.on_channel_announcement(self.chan_anns)
|
||||
self.chan_anns = []
|
||||
if self.chan_upds:
|
||||
self.channel_db.on_channel_update(self.chan_upds)
|
||||
self.chan_upds = []
|
||||
# todo: enable when db is fixed
|
||||
#need_to_get = sorted(self.channel_db.missing_short_chan_ids())
|
||||
#if need_to_get and not self.receiving_channels:
|
||||
# self.logger.info(f'missing {len(need_to_get)} channels')
|
||||
# zlibencoded = zlib.compress(bfh(''.join(need_to_get[0:100])))
|
||||
# self.send_message(
|
||||
# 'query_short_channel_ids',
|
||||
# chain_hash=constants.net.rev_genesis_bytes(),
|
||||
# len=1+len(zlibencoded),
|
||||
# encoded_short_ids=b'\x01' + zlibencoded)
|
||||
# self.receiving_channels = True
|
||||
|
||||
def query_channel_range(self, index, num):
|
||||
self.logger.info(f'query channel range')
|
||||
self.send_message(
|
||||
'query_channel_range',
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
first_blocknum=index,
|
||||
number_of_blocks=num)
|
||||
|
||||
def encode_short_ids(self, ids):
|
||||
return chr(1) + zlib.compress(bfh(''.join(ids)))
|
||||
|
||||
def decode_short_ids(self, encoded):
|
||||
if encoded[0] == 0:
|
||||
decoded = encoded[1:]
|
||||
elif encoded[0] == 1:
|
||||
decoded = zlib.decompress(encoded[1:])
|
||||
else:
|
||||
raise BaseException('zlib')
|
||||
ids = [decoded[i:i+8] for i in range(0, len(decoded), 8)]
|
||||
return ids
|
||||
|
||||
def on_reply_channel_range(self, payload):
|
||||
first = int.from_bytes(payload['first_blocknum'], 'big')
|
||||
num = int.from_bytes(payload['number_of_blocks'], 'big')
|
||||
complete = bool(payload['complete'])
|
||||
encoded = payload['encoded_short_ids']
|
||||
ids = self.decode_short_ids(encoded)
|
||||
self.reply_channel_range.put_nowait((first, num, complete, ids))
|
||||
|
||||
async def query_short_channel_ids(self, ids, compressed=True):
|
||||
await self.querying_lock.acquire()
|
||||
#self.logger.info('querying {} short_channel_ids'.format(len(ids)))
|
||||
s = b''.join(ids)
|
||||
encoded = zlib.compress(s) if compressed else s
|
||||
prefix = b'\x01' if compressed else b'\x00'
|
||||
self.send_message(
|
||||
'query_short_channel_ids',
|
||||
chain_hash=constants.net.rev_genesis_bytes(),
|
||||
len=1+len(encoded),
|
||||
encoded_short_ids=prefix+encoded)
|
||||
|
||||
async def _message_loop(self):
|
||||
try:
|
||||
|
@ -260,7 +274,7 @@ class Peer(Logger):
|
|||
self.ping_if_required()
|
||||
|
||||
def on_reply_short_channel_ids_end(self, payload):
|
||||
self.receiving_channels = False
|
||||
self.querying_lock.release()
|
||||
|
||||
def close_and_cleanup(self):
|
||||
try:
|
||||
|
|
|
@ -223,6 +223,20 @@ class ChannelDB(SqlDB):
|
|||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
self.update_counts()
|
||||
self.node_anns = []
|
||||
self.chan_anns = []
|
||||
self.chan_upds = []
|
||||
|
||||
def process_gossip(self):
|
||||
if self.node_anns:
|
||||
self.on_node_announcement(self.node_anns)
|
||||
self.node_anns = []
|
||||
if self.chan_anns:
|
||||
self.on_channel_announcement(self.chan_anns)
|
||||
self.chan_anns = []
|
||||
if self.chan_upds:
|
||||
self.on_channel_update(self.chan_upds)
|
||||
self.chan_upds = []
|
||||
|
||||
@sql
|
||||
def update_counts(self):
|
||||
|
@ -232,7 +246,32 @@ class ChannelDB(SqlDB):
|
|||
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
||||
self.num_policies = self.DBSession.query(Policy).count()
|
||||
self.num_nodes = self.DBSession.query(NodeInfo).count()
|
||||
self.logger.info(f'update counts {self.num_channels} {self.num_policies}')
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def purge_unknown_channels(self, channel_ids):
|
||||
ids = [x.hex() for x in channel_ids]
|
||||
missing = self.DBSession \
|
||||
.query(ChannelInfo) \
|
||||
.filter(not_(ChannelInfo.short_channel_id.in_(ids))) \
|
||||
.all()
|
||||
if missing:
|
||||
self.logger.info("deleting {} channels".format(len(missing)))
|
||||
delete_query = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(ids)))
|
||||
self.DBSession.execute(delete_query)
|
||||
self.DBSession.commit()
|
||||
|
||||
@sql
|
||||
@profiler
|
||||
def compare_channels(self, channel_ids):
|
||||
ids = [x.hex() for x in channel_ids]
|
||||
# I need to get the unknown, and also the channels that need refresh
|
||||
known = self.DBSession \
|
||||
.query(ChannelInfo) \
|
||||
.filter(ChannelInfo.short_channel_id.in_(ids)) \
|
||||
.all()
|
||||
known = [bfh(r.short_channel_id) for r in known]
|
||||
return known
|
||||
|
||||
@sql
|
||||
def add_recent_peer(self, peer: LNPeerAddr):
|
||||
|
@ -276,12 +315,14 @@ class ChannelDB(SqlDB):
|
|||
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
|
||||
|
||||
@sql
|
||||
def missing_short_chan_ids(self) -> Set[int]:
|
||||
def missing_channel_announcements(self) -> Set[int]:
|
||||
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
|
||||
chan_ids_from_policy = set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||
if chan_ids_from_policy:
|
||||
return chan_ids_from_policy
|
||||
return set()
|
||||
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
|
||||
|
||||
@sql
|
||||
def missing_channel_updates(self) -> Set[int]:
|
||||
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
|
||||
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
|
||||
|
||||
@sql
|
||||
def add_verified_channel_info(self, short_id, capacity):
|
||||
|
@ -316,8 +357,8 @@ class ChannelDB(SqlDB):
|
|||
for channel_info in new_channels.values():
|
||||
self.DBSession.add(channel_info)
|
||||
self.DBSession.commit()
|
||||
#self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
||||
self._update_counts()
|
||||
self.logger.info('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
|
||||
self.network.trigger_callback('ln_status')
|
||||
|
||||
@sql
|
||||
|
@ -370,7 +411,7 @@ class ChannelDB(SqlDB):
|
|||
self.DBSession.commit()
|
||||
if new_policies:
|
||||
self.logger.info(f'on_channel_update: {len(new_policies)}/{len(msg_payloads)}')
|
||||
self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
|
||||
#self.logger.info(f'last timestamp: {datetime.fromtimestamp(self._get_last_timestamp()).ctime()}')
|
||||
self._update_counts()
|
||||
|
||||
@sql
|
||||
|
|
|
@ -133,9 +133,7 @@ class LNWorker(Logger):
|
|||
self.channel_db = self.network.channel_db
|
||||
self._last_tried_peer = {} # LNPeerAddr -> unix timestamp
|
||||
self._add_peers_from_config()
|
||||
# wait until we see confirmations
|
||||
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.main_loop()), self.network.asyncio_loop)
|
||||
self.first_timestamp_requested = None
|
||||
|
||||
def _add_peers_from_config(self):
|
||||
peer_list = self.config.get('lightning_peers', [])
|
||||
|
@ -215,9 +213,24 @@ class LNWorker(Logger):
|
|||
self.logger.info('got {} ln peers from dns seed'.format(len(peers)))
|
||||
return peers
|
||||
|
||||
@staticmethod
|
||||
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
||||
assert len(addr_list) >= 1
|
||||
# choose first one that is an IP
|
||||
for addr_in_db in addr_list:
|
||||
host = addr_in_db.host
|
||||
port = addr_in_db.port
|
||||
if is_ip_address(host):
|
||||
return host, port
|
||||
# otherwise choose one at random
|
||||
# TODO maybe filter out onion if not on tor?
|
||||
choice = random.choice(addr_list)
|
||||
return choice.host, choice.port
|
||||
|
||||
|
||||
class LNGossip(LNWorker):
|
||||
# height of first channel announcements
|
||||
first_block = 497000
|
||||
|
||||
def __init__(self, network):
|
||||
seed = os.urandom(32)
|
||||
|
@ -226,6 +239,61 @@ class LNGossip(LNWorker):
|
|||
super().__init__(xprv)
|
||||
self.localfeatures |= LnLocalFeatures.GOSSIP_QUERIES_REQ
|
||||
|
||||
def start_network(self, network: 'Network'):
|
||||
super().start_network(network)
|
||||
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(self.gossip_task()), self.network.asyncio_loop)
|
||||
|
||||
async def gossip_task(self):
|
||||
req_index = self.first_block
|
||||
req_num = self.network.get_local_height() - req_index
|
||||
while len(self.peers) == 0:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
# todo: parallelize over peers
|
||||
peer = list(self.peers.values())[0]
|
||||
await peer.initialized.wait()
|
||||
# send channels_range query. peer will reply with several intervals
|
||||
peer.query_channel_range(req_index, req_num)
|
||||
intervals = []
|
||||
ids = set()
|
||||
# wait until requested range is covered
|
||||
while True:
|
||||
index, num, complete, _ids = await peer.reply_channel_range.get()
|
||||
ids.update(_ids)
|
||||
intervals.append((index, index+num))
|
||||
intervals.sort()
|
||||
while len(intervals) > 1:
|
||||
a,b = intervals[0]
|
||||
c,d = intervals[1]
|
||||
if b == c:
|
||||
intervals = [(a,d)] + intervals[2:]
|
||||
else:
|
||||
break
|
||||
if len(intervals) == 1:
|
||||
a, b = intervals[0]
|
||||
if a <= req_index and b >= req_index + req_num:
|
||||
break
|
||||
self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
|
||||
# TODO: filter results by date of last channel update, purge DB
|
||||
#if complete:
|
||||
# self.channel_db.purge_unknown_channels(ids)
|
||||
known = self.channel_db.compare_channels(ids)
|
||||
unknown = list(ids - set(known))
|
||||
total = len(unknown)
|
||||
N = 500
|
||||
while unknown:
|
||||
self.channel_db.process_gossip()
|
||||
await peer.query_short_channel_ids(unknown[0:N])
|
||||
unknown = unknown[N:]
|
||||
self.logger.info(f'Querying channels: {total - len(unknown)}/{total}. Count: {self.channel_db.num_channels}')
|
||||
|
||||
# request gossip fromm current time
|
||||
now = int(time.time())
|
||||
peer.request_gossip(now)
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
self.channel_db.process_gossip()
|
||||
|
||||
|
||||
class LNWallet(LNWorker):
|
||||
|
||||
|
@ -548,20 +616,6 @@ class LNWallet(LNWorker):
|
|||
def on_channels_updated(self):
|
||||
self.network.trigger_callback('channels')
|
||||
|
||||
@staticmethod
|
||||
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
||||
assert len(addr_list) >= 1
|
||||
# choose first one that is an IP
|
||||
for addr_in_db in addr_list:
|
||||
host = addr_in_db.host
|
||||
port = addr_in_db.port
|
||||
if is_ip_address(host):
|
||||
return host, port
|
||||
# otherwise choose one at random
|
||||
# TODO maybe filter out onion if not on tor?
|
||||
choice = random.choice(addr_list)
|
||||
return choice.host, choice.port
|
||||
|
||||
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, password=None, timeout=20):
|
||||
node_id, rest = extract_nodeid(connect_contents)
|
||||
peer = self.peers.get(node_id)
|
||||
|
|
Loading…
Add table
Reference in a new issue