mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 01:11:35 +00:00
ChannelDB: add self.lock and make it thread-safe
This commit is contained in:
parent
1ca6f6f306
commit
fd56fb9189
1 changed files with 46 additions and 25 deletions
|
@ -31,6 +31,7 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
|
|||
import binascii
|
||||
import base64
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
|
||||
from .sql_db import SqlDB, sql
|
||||
|
@ -247,17 +248,21 @@ class ChannelDB(SqlDB):
|
|||
def __init__(self, network: 'Network'):
|
||||
path = os.path.join(get_headers_dir(network.config), 'gossip_db')
|
||||
super().__init__(network, path, commit_interval=100)
|
||||
self.lock = threading.RLock()
|
||||
self.num_nodes = 0
|
||||
self.num_channels = 0
|
||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||
self.ca_verifier = LNChannelVerifier(network, self)
|
||||
|
||||
# initialized in load_data
|
||||
# note: modify/iterate needs self.lock
|
||||
self._channels = {} # type: Dict[bytes, ChannelInfo]
|
||||
self._policies = {} # type: Dict[Tuple[bytes, bytes], Policy] # (node_id, scid) -> Policy
|
||||
self._nodes = {} # type: Dict[bytes, NodeInfo] # node_id -> NodeInfo
|
||||
# node_id -> (host, port, ts)
|
||||
self._addresses = defaultdict(set) # type: Dict[bytes, Set[Tuple[str, int, int]]]
|
||||
self._channels_for_node = defaultdict(set) # type: Dict[bytes, Set[ShortChannelID]]
|
||||
|
||||
self.data_loaded = asyncio.Event()
|
||||
self.network = network # only for callback
|
||||
|
||||
|
@ -268,16 +273,19 @@ class ChannelDB(SqlDB):
|
|||
self.network.trigger_callback('channel_db', self.num_nodes, self.num_channels, self.num_policies)
|
||||
|
||||
def get_channel_ids(self):
|
||||
return set(self._channels.keys())
|
||||
with self.lock:
|
||||
return set(self._channels.keys())
|
||||
|
||||
def add_recent_peer(self, peer: LNPeerAddr):
|
||||
now = int(time.time())
|
||||
node_id = peer.pubkey
|
||||
self._addresses[node_id].add((peer.host, peer.port, now))
|
||||
with self.lock:
|
||||
self._addresses[node_id].add((peer.host, peer.port, now))
|
||||
self.save_node_address(node_id, peer, now)
|
||||
|
||||
def get_200_randomly_sorted_nodes_not_in(self, node_ids):
|
||||
unshuffled = set(self._nodes.keys()) - node_ids
|
||||
with self.lock:
|
||||
unshuffled = set(self._nodes.keys()) - node_ids
|
||||
return random.sample(unshuffled, min(200, len(unshuffled)))
|
||||
|
||||
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
|
||||
|
@ -296,8 +304,10 @@ class ChannelDB(SqlDB):
|
|||
# FIXME this does not reliably return "recent" peers...
|
||||
# Also, the list() cast over the whole dict (thousands of elements),
|
||||
# is really inefficient.
|
||||
with self.lock:
|
||||
_addresses_keys = list(self._addresses.keys())
|
||||
r = [self.get_last_good_address(node_id)
|
||||
for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
|
||||
for node_id in _addresses_keys[-self.NUM_MAX_RECENT_PEERS:]]
|
||||
return list(reversed(r))
|
||||
|
||||
# note: currently channel announcements are trusted by default (trusted=True);
|
||||
|
@ -336,9 +346,10 @@ class ChannelDB(SqlDB):
|
|||
except UnknownEvenFeatureBits:
|
||||
return
|
||||
channel_info = channel_info._replace(capacity_sat=capacity_sat)
|
||||
self._channels[channel_info.short_channel_id] = channel_info
|
||||
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
|
||||
with self.lock:
|
||||
self._channels[channel_info.short_channel_id] = channel_info
|
||||
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
|
||||
if 'raw' in msg:
|
||||
self.save_channel(channel_info.short_channel_id, msg['raw'])
|
||||
|
||||
|
@ -397,7 +408,8 @@ class ChannelDB(SqlDB):
|
|||
if verify:
|
||||
self.verify_channel_update(payload)
|
||||
policy = Policy.from_msg(payload)
|
||||
self._policies[key] = policy
|
||||
with self.lock:
|
||||
self._policies[key] = policy
|
||||
if 'raw' in payload:
|
||||
self.save_policy(policy.key, payload['raw'])
|
||||
#
|
||||
|
@ -492,32 +504,38 @@ class ChannelDB(SqlDB):
|
|||
if node and node.timestamp >= node_info.timestamp:
|
||||
continue
|
||||
# save
|
||||
self._nodes[node_id] = node_info
|
||||
with self.lock:
|
||||
self._nodes[node_id] = node_info
|
||||
if 'raw' in msg_payload:
|
||||
self.save_node_info(node_id, msg_payload['raw'])
|
||||
for addr in node_addresses:
|
||||
self._addresses[node_id].add((addr.host, addr.port, 0))
|
||||
with self.lock:
|
||||
for addr in node_addresses:
|
||||
self._addresses[node_id].add((addr.host, addr.port, 0))
|
||||
self.save_node_addresses(node_id, node_addresses)
|
||||
|
||||
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
|
||||
self.update_counts()
|
||||
|
||||
def get_old_policies(self, delta):
|
||||
with self.lock:
|
||||
_policies = self._policies.copy()
|
||||
now = int(time.time())
|
||||
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
|
||||
return list(k for k, v in _policies.items() if v.timestamp <= now - delta)
|
||||
|
||||
def prune_old_policies(self, delta):
|
||||
l = self.get_old_policies(delta)
|
||||
if l:
|
||||
for k in l:
|
||||
self._policies.pop(k)
|
||||
with self.lock:
|
||||
self._policies.pop(k)
|
||||
self.delete_policy(*k)
|
||||
self.update_counts()
|
||||
self.logger.info(f'Deleting {len(l)} old policies')
|
||||
|
||||
def get_orphaned_channels(self):
|
||||
ids = set(x[1] for x in self._policies.keys())
|
||||
return list(x for x in self._channels.keys() if x not in ids)
|
||||
with self.lock:
|
||||
ids = set(x[1] for x in self._policies.keys())
|
||||
return list(x for x in self._channels.keys() if x not in ids)
|
||||
|
||||
def prune_orphaned_channels(self):
|
||||
l = self.get_orphaned_channels()
|
||||
|
@ -535,10 +553,11 @@ class ChannelDB(SqlDB):
|
|||
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
|
||||
|
||||
def remove_channel(self, short_channel_id: ShortChannelID):
|
||||
channel_info = self._channels.pop(short_channel_id, None)
|
||||
if channel_info:
|
||||
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
|
||||
with self.lock:
|
||||
channel_info = self._channels.pop(short_channel_id, None)
|
||||
if channel_info:
|
||||
self._channels_for_node[channel_info.node1_id].remove(channel_info.short_channel_id)
|
||||
self._channels_for_node[channel_info.node2_id].remove(channel_info.short_channel_id)
|
||||
# delete from database
|
||||
self.delete_channel(short_channel_id)
|
||||
|
||||
|
@ -571,17 +590,19 @@ class ChannelDB(SqlDB):
|
|||
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
|
||||
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
|
||||
self.update_counts()
|
||||
self.count_incomplete_channels()
|
||||
self.logger.info(f'semi-orphaned channels: {self.get_num_incomplete_channels()}')
|
||||
self.data_loaded.set()
|
||||
|
||||
def count_incomplete_channels(self):
|
||||
out = set()
|
||||
for short_channel_id, ci in self._channels.items():
|
||||
def get_num_incomplete_channels(self) -> int:
|
||||
found = set()
|
||||
with self.lock:
|
||||
_channels = self._channels.copy()
|
||||
for short_channel_id, ci in _channels.items():
|
||||
p1 = self.get_policy_for_node(short_channel_id, ci.node1_id)
|
||||
p2 = self.get_policy_for_node(short_channel_id, ci.node2_id)
|
||||
if p1 is None or p2 is not None:
|
||||
out.add(short_channel_id)
|
||||
self.logger.info(f'semi-orphaned: {len(out)}')
|
||||
found.add(short_channel_id)
|
||||
return len(found)
|
||||
|
||||
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes, *,
|
||||
my_channels: Dict[ShortChannelID, 'Channel'] = None) -> Optional['Policy']:
|
||||
|
|
Loading…
Add table
Reference in a new issue