lnworker: fix threading issues for .channels attribute

external code (commands/gui) did not always take lock when iterating lnworker.channels.
instead of exposing lock, let's take a copy internally (as with .peers)
This commit is contained in:
SomberNight 2020-04-30 21:08:26 +02:00
parent f5eb91900a
commit b9b53e7f76
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
3 changed files with 43 additions and 44 deletions

View file

@ -107,7 +107,7 @@ class Peer(Logger):
if not (message_name.startswith("update_") or is_commitment_signed): if not (message_name.startswith("update_") or is_commitment_signed):
return return
assert channel_id assert channel_id
chan = self.lnworker.channels[channel_id] # type: Channel chan = self.channels[channel_id]
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed) chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
if is_commitment_signed: if is_commitment_signed:
# saving now, to ensure replaying updates works (in case of channel reestablishment) # saving now, to ensure replaying updates works (in case of channel reestablishment)
@ -139,6 +139,8 @@ class Peer(Logger):
@property @property
def channels(self) -> Dict[bytes, Channel]: def channels(self) -> Dict[bytes, Channel]:
# FIXME this iterates over all channels in lnworker,
# so if we just want to lookup a channel by channel_id, it's wasteful
return self.lnworker.channels_for_peer(self.pubkey) return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self): def diagnostic_name(self):

View file

@ -491,22 +491,26 @@ class LNWallet(LNWorker):
self.enable_htlc_settle.set() self.enable_htlc_settle.set()
# note: accessing channels (besides simple lookup) needs self.lock! # note: accessing channels (besides simple lookup) needs self.lock!
self.channels = {} self._channels = {} # type: Dict[bytes, Channel]
channels = self.db.get_dict("channels") channels = self.db.get_dict("channels")
for channel_id, c in channels.items(): for channel_id, c in channels.items():
self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self) self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]] self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
@property
def channels(self) -> Mapping[bytes, Channel]:
"""Returns a read-only copy of channels."""
with self.lock:
return self._channels.copy()
@ignore_exceptions @ignore_exceptions
@log_exceptions @log_exceptions
async def sync_with_local_watchtower(self): async def sync_with_local_watchtower(self):
watchtower = self.network.local_watchtower watchtower = self.network.local_watchtower
if watchtower: if watchtower:
while True: while True:
with self.lock: for chan in self.channels.values():
channels = list(self.channels.values())
for chan in channels:
await self.sync_channel_with_watchtower(chan, watchtower.sweepstore) await self.sync_channel_with_watchtower(chan, watchtower.sweepstore)
await asyncio.sleep(5) await asyncio.sleep(5)
@ -524,12 +528,10 @@ class LNWallet(LNWorker):
watchtower_url = self.config.get('watchtower_url') watchtower_url = self.config.get('watchtower_url')
if not watchtower_url: if not watchtower_url:
continue continue
with self.lock:
channels = list(self.channels.values())
try: try:
async with make_aiohttp_session(proxy=self.network.proxy) as session: async with make_aiohttp_session(proxy=self.network.proxy) as session:
watchtower = myAiohttpClient(session, watchtower_url) watchtower = myAiohttpClient(session, watchtower_url)
for chan in channels: for chan in self.channels.values():
await self.sync_channel_with_watchtower(chan, watchtower) await self.sync_channel_with_watchtower(chan, watchtower)
except aiohttp.client_exceptions.ClientConnectorError: except aiohttp.client_exceptions.ClientConnectorError:
self.logger.info(f'could not contact remote watchtower {watchtower_url}') self.logger.info(f'could not contact remote watchtower {watchtower_url}')
@ -574,9 +576,7 @@ class LNWallet(LNWorker):
# return one item per payment_hash # return one item per payment_hash
# note: with AMP we will have several channels per payment # note: with AMP we will have several channels per payment
out = defaultdict(list) out = defaultdict(list)
with self.lock: for chan in self.channels.values():
channels = list(self.channels.values())
for chan in channels:
d = chan.get_settled_payments() d = chan.get_settled_payments()
for k, v in d.items(): for k, v in d.items():
out[k].append(v) out[k].append(v)
@ -628,9 +628,7 @@ class LNWallet(LNWorker):
def get_onchain_history(self): def get_onchain_history(self):
out = {} out = {}
# add funding events # add funding events
with self.lock: for chan in self.channels.values():
channels = list(self.channels.values())
for chan in channels:
item = chan.get_funding_height() item = chan.get_funding_height()
if item is None: if item is None:
continue continue
@ -693,8 +691,7 @@ class LNWallet(LNWorker):
def channels_for_peer(self, node_id): def channels_for_peer(self, node_id):
assert type(node_id) is bytes assert type(node_id) is bytes
with self.lock: return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
def channel_state_changed(self, chan): def channel_state_changed(self, chan):
self.save_channel(chan) self.save_channel(chan)
@ -708,9 +705,7 @@ class LNWallet(LNWorker):
util.trigger_callback('channel', chan) util.trigger_callback('channel', chan)
def channel_by_txo(self, txo): def channel_by_txo(self, txo):
with self.lock: for chan in self.channels.values():
channels = list(self.channels.values())
for chan in channels:
if chan.funding_outpoint.to_str() == txo: if chan.funding_outpoint.to_str() == txo:
return chan return chan
@ -762,7 +757,7 @@ class LNWallet(LNWorker):
def add_channel(self, chan): def add_channel(self, chan):
with self.lock: with self.lock:
self.channels[chan.channel_id] = chan self._channels[chan.channel_id] = chan
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
def add_new_channel(self, chan): def add_new_channel(self, chan):
@ -805,10 +800,9 @@ class LNWallet(LNWorker):
success = fut.result() success = fut.result()
def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel: def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
with self.lock: for chan in self.channels.values():
for chan in self.channels.values(): if chan.short_channel_id == short_channel_id:
if chan.short_channel_id == short_channel_id: return chan
return chan
async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool: async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool:
lnaddr = self._check_invoice(invoice, amount_sat) lnaddr = self._check_invoice(invoice, amount_sat)
@ -981,8 +975,7 @@ class LNWallet(LNWorker):
# if there are multiple hints, we will use the first one that works, # if there are multiple hints, we will use the first one that works,
# from a random permutation # from a random permutation
random.shuffle(r_tags) random.shuffle(r_tags)
with self.lock: channels = list(self.channels.values())
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None} if chan.short_channel_id is not None}
for private_route in r_tags: for private_route in r_tags:
@ -1196,8 +1189,7 @@ class LNWallet(LNWorker):
async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]): async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)""" """calculate routing hints (BOLT-11 'r' field)"""
routing_hints = [] routing_hints = []
with self.lock: channels = list(self.channels.values())
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None} if chan.short_channel_id is not None}
ignore_min_htlc_value = False ignore_min_htlc_value = False
@ -1251,24 +1243,27 @@ class LNWallet(LNWorker):
def get_balance(self): def get_balance(self):
with self.lock: with self.lock:
return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000 return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0
for chan in self.channels.values())) / 1000
def num_sats_can_send(self) -> Union[Decimal, int]: def num_sats_can_send(self) -> Union[Decimal, int]:
with self.lock: with self.lock:
return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0 return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0
for chan in self.channels.values()))/1000 if self.channels else 0
def num_sats_can_receive(self) -> Union[Decimal, int]: def num_sats_can_receive(self) -> Union[Decimal, int]:
with self.lock: with self.lock:
return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0 return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0
for chan in self.channels.values()))/1000 if self.channels else 0
async def close_channel(self, chan_id): async def close_channel(self, chan_id):
chan = self.channels[chan_id] chan = self._channels[chan_id]
peer = self._peers[chan.node_id] peer = self._peers[chan.node_id]
return await peer.close_channel(chan_id) return await peer.close_channel(chan_id)
async def force_close_channel(self, chan_id): async def force_close_channel(self, chan_id):
# returns txid or raises # returns txid or raises
chan = self.channels[chan_id] chan = self._channels[chan_id]
tx = chan.force_close_tx() tx = chan.force_close_tx()
await self.network.broadcast_transaction(tx) await self.network.broadcast_transaction(tx)
chan.set_state(ChannelState.FORCE_CLOSING) chan.set_state(ChannelState.FORCE_CLOSING)
@ -1276,16 +1271,16 @@ class LNWallet(LNWorker):
async def try_force_closing(self, chan_id): async def try_force_closing(self, chan_id):
# fails silently but sets the state, so that we will retry later # fails silently but sets the state, so that we will retry later
chan = self.channels[chan_id] chan = self._channels[chan_id]
tx = chan.force_close_tx() tx = chan.force_close_tx()
chan.set_state(ChannelState.FORCE_CLOSING) chan.set_state(ChannelState.FORCE_CLOSING)
await self.network.try_broadcasting(tx, 'force-close') await self.network.try_broadcasting(tx, 'force-close')
def remove_channel(self, chan_id): def remove_channel(self, chan_id):
chan = self.channels[chan_id] chan = self._channels[chan_id]
assert chan.get_state() == ChannelState.REDEEMED assert chan.get_state() == ChannelState.REDEEMED
with self.lock: with self.lock:
self.channels.pop(chan_id) self._channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex()) self.db.get('channels').pop(chan_id.hex())
util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
@ -1316,9 +1311,7 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self): async def reestablish_peers_and_channels(self):
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
with self.lock: for chan in self.channels.values():
channels = list(self.channels.values())
for chan in channels:
if chan.is_closed(): if chan.is_closed():
continue continue
# reestablish # reestablish
@ -1340,7 +1333,7 @@ class LNWallet(LNWorker):
return max(253, feerate_per_kvbyte // 4) return max(253, feerate_per_kvbyte // 4)
def create_channel_backup(self, channel_id): def create_channel_backup(self, channel_id):
chan = self.channels[channel_id] chan = self._channels[channel_id]
peer_addresses = list(chan.get_peer_addresses()) peer_addresses = list(chan.get_peer_addresses())
peer_addr = peer_addresses[0] peer_addr = peer_addresses[0]
return ChannelBackupStorage( return ChannelBackupStorage(

View file

@ -102,7 +102,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.remote_keypair = remote_keypair self.remote_keypair = remote_keypair
self.node_keypair = local_keypair self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue) self.network = MockNetwork(tx_queue)
self.channels = {chan.channel_id: chan} self._channels = {chan.channel_id: chan}
self.payments = {} self.payments = {}
self.logs = defaultdict(list) self.logs = defaultdict(list)
self.wallet = MockWallet() self.wallet = MockWallet()
@ -122,6 +122,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def lock(self): def lock(self):
return noop_lock() return noop_lock()
@property
def channels(self):
return self._channels
@property @property
def peers(self): def peers(self):
return self._peers return self._peers
@ -131,11 +135,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return {self.remote_keypair.pubkey: self.peer} return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey): def channels_for_peer(self, pubkey):
return self.channels return self._channels
def get_channel_by_short_id(self, short_channel_id): def get_channel_by_short_id(self, short_channel_id):
with self.lock: with self.lock:
for chan in self.channels.values(): for chan in self._channels.values():
if chan.short_channel_id == short_channel_id: if chan.short_channel_id == short_channel_id:
return chan return chan