lnworker: fix threading issues for .channels attribute

external code (commands/gui) did not always take lock when iterating lnworker.channels.
instead of exposing lock, let's take a copy internally (as with .peers)
This commit is contained in:
SomberNight 2020-04-30 21:08:26 +02:00
parent f5eb91900a
commit b9b53e7f76
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
3 changed files with 43 additions and 44 deletions

View file

@ -107,7 +107,7 @@ class Peer(Logger):
if not (message_name.startswith("update_") or is_commitment_signed):
return
assert channel_id
chan = self.lnworker.channels[channel_id] # type: Channel
chan = self.channels[channel_id]
chan.hm.store_local_update_raw_msg(raw_msg, is_commitment_signed=is_commitment_signed)
if is_commitment_signed:
# saving now, to ensure replaying updates works (in case of channel reestablishment)
@ -139,6 +139,8 @@ class Peer(Logger):
@property
def channels(self) -> Dict[bytes, Channel]:
# FIXME this iterates over all channels in lnworker,
# so if we just want to lookup a channel by channel_id, it's wasteful
return self.lnworker.channels_for_peer(self.pubkey)
def diagnostic_name(self):

View file

@ -491,22 +491,26 @@ class LNWallet(LNWorker):
self.enable_htlc_settle.set()
# note: accessing channels (besides simple lookup) needs self.lock!
self.channels = {}
self._channels = {} # type: Dict[bytes, Channel]
channels = self.db.get_dict("channels")
for channel_id, c in channels.items():
self.channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self._channels[bfh(channel_id)] = Channel(c, sweep_address=self.sweep_address, lnworker=self)
self.pending_payments = defaultdict(asyncio.Future) # type: Dict[bytes, asyncio.Future[BarePaymentAttemptLog]]
@property
def channels(self) -> Mapping[bytes, Channel]:
"""Returns a read-only copy of channels."""
with self.lock:
return self._channels.copy()
@ignore_exceptions
@log_exceptions
async def sync_with_local_watchtower(self):
watchtower = self.network.local_watchtower
if watchtower:
while True:
with self.lock:
channels = list(self.channels.values())
for chan in channels:
for chan in self.channels.values():
await self.sync_channel_with_watchtower(chan, watchtower.sweepstore)
await asyncio.sleep(5)
@ -524,12 +528,10 @@ class LNWallet(LNWorker):
watchtower_url = self.config.get('watchtower_url')
if not watchtower_url:
continue
with self.lock:
channels = list(self.channels.values())
try:
async with make_aiohttp_session(proxy=self.network.proxy) as session:
watchtower = myAiohttpClient(session, watchtower_url)
for chan in channels:
for chan in self.channels.values():
await self.sync_channel_with_watchtower(chan, watchtower)
except aiohttp.client_exceptions.ClientConnectorError:
self.logger.info(f'could not contact remote watchtower {watchtower_url}')
@ -574,9 +576,7 @@ class LNWallet(LNWorker):
# return one item per payment_hash
# note: with AMP we will have several channels per payment
out = defaultdict(list)
with self.lock:
channels = list(self.channels.values())
for chan in channels:
for chan in self.channels.values():
d = chan.get_settled_payments()
for k, v in d.items():
out[k].append(v)
@ -628,9 +628,7 @@ class LNWallet(LNWorker):
def get_onchain_history(self):
out = {}
# add funding events
with self.lock:
channels = list(self.channels.values())
for chan in channels:
for chan in self.channels.values():
item = chan.get_funding_height()
if item is None:
continue
@ -693,8 +691,7 @@ class LNWallet(LNWorker):
def channels_for_peer(self, node_id):
assert type(node_id) is bytes
with self.lock:
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
return {x: y for (x, y) in self.channels.items() if y.node_id == node_id}
def channel_state_changed(self, chan):
self.save_channel(chan)
@ -708,9 +705,7 @@ class LNWallet(LNWorker):
util.trigger_callback('channel', chan)
def channel_by_txo(self, txo):
with self.lock:
channels = list(self.channels.values())
for chan in channels:
for chan in self.channels.values():
if chan.funding_outpoint.to_str() == txo:
return chan
@ -762,7 +757,7 @@ class LNWallet(LNWorker):
def add_channel(self, chan):
with self.lock:
self.channels[chan.channel_id] = chan
self._channels[chan.channel_id] = chan
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
def add_new_channel(self, chan):
@ -805,10 +800,9 @@ class LNWallet(LNWorker):
success = fut.result()
def get_channel_by_short_id(self, short_channel_id: ShortChannelID) -> Channel:
with self.lock:
for chan in self.channels.values():
if chan.short_channel_id == short_channel_id:
return chan
for chan in self.channels.values():
if chan.short_channel_id == short_channel_id:
return chan
async def _pay(self, invoice, amount_sat=None, attempts=1) -> bool:
lnaddr = self._check_invoice(invoice, amount_sat)
@ -981,8 +975,7 @@ class LNWallet(LNWorker):
# if there are multiple hints, we will use the first one that works,
# from a random permutation
random.shuffle(r_tags)
with self.lock:
channels = list(self.channels.values())
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
for private_route in r_tags:
@ -1196,8 +1189,7 @@ class LNWallet(LNWorker):
async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
routing_hints = []
with self.lock:
channels = list(self.channels.values())
channels = list(self.channels.values())
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
ignore_min_htlc_value = False
@ -1251,24 +1243,27 @@ class LNWallet(LNWorker):
def get_balance(self):
with self.lock:
return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0 for chan in self.channels.values()))/1000
return Decimal(sum(chan.balance(LOCAL) if not chan.is_closed() else 0
for chan in self.channels.values())) / 1000
def num_sats_can_send(self) -> Union[Decimal, int]:
with self.lock:
return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
return Decimal(max(chan.available_to_spend(LOCAL) if chan.is_open() else 0
for chan in self.channels.values()))/1000 if self.channels else 0
def num_sats_can_receive(self) -> Union[Decimal, int]:
with self.lock:
return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0 for chan in self.channels.values()))/1000 if self.channels else 0
return Decimal(max(chan.available_to_spend(REMOTE) if chan.is_open() else 0
for chan in self.channels.values()))/1000 if self.channels else 0
async def close_channel(self, chan_id):
chan = self.channels[chan_id]
chan = self._channels[chan_id]
peer = self._peers[chan.node_id]
return await peer.close_channel(chan_id)
async def force_close_channel(self, chan_id):
# returns txid or raises
chan = self.channels[chan_id]
chan = self._channels[chan_id]
tx = chan.force_close_tx()
await self.network.broadcast_transaction(tx)
chan.set_state(ChannelState.FORCE_CLOSING)
@ -1276,16 +1271,16 @@ class LNWallet(LNWorker):
async def try_force_closing(self, chan_id):
# fails silently but sets the state, so that we will retry later
chan = self.channels[chan_id]
chan = self._channels[chan_id]
tx = chan.force_close_tx()
chan.set_state(ChannelState.FORCE_CLOSING)
await self.network.try_broadcasting(tx, 'force-close')
def remove_channel(self, chan_id):
chan = self.channels[chan_id]
chan = self._channels[chan_id]
assert chan.get_state() == ChannelState.REDEEMED
with self.lock:
self.channels.pop(chan_id)
self._channels.pop(chan_id)
self.db.get('channels').pop(chan_id.hex())
util.trigger_callback('channels_updated', self.wallet)
@ -1316,9 +1311,7 @@ class LNWallet(LNWorker):
async def reestablish_peers_and_channels(self):
while True:
await asyncio.sleep(1)
with self.lock:
channels = list(self.channels.values())
for chan in channels:
for chan in self.channels.values():
if chan.is_closed():
continue
# reestablish
@ -1340,7 +1333,7 @@ class LNWallet(LNWorker):
return max(253, feerate_per_kvbyte // 4)
def create_channel_backup(self, channel_id):
chan = self.channels[channel_id]
chan = self._channels[channel_id]
peer_addresses = list(chan.get_peer_addresses())
peer_addr = peer_addresses[0]
return ChannelBackupStorage(

View file

@ -102,7 +102,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
self.remote_keypair = remote_keypair
self.node_keypair = local_keypair
self.network = MockNetwork(tx_queue)
self.channels = {chan.channel_id: chan}
self._channels = {chan.channel_id: chan}
self.payments = {}
self.logs = defaultdict(list)
self.wallet = MockWallet()
@ -122,6 +122,10 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
def lock(self):
return noop_lock()
@property
def channels(self):
return self._channels
@property
def peers(self):
return self._peers
@ -131,11 +135,11 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
return {self.remote_keypair.pubkey: self.peer}
def channels_for_peer(self, pubkey):
return self.channels
return self._channels
def get_channel_by_short_id(self, short_channel_id):
with self.lock:
for chan in self.channels.values():
for chan in self._channels.values():
if chan.short_channel_id == short_channel_id:
return chan