From e829d6bbcfd65910ff778406a0f14ac9cececa77 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 11 Sep 2018 20:24:01 +0200 Subject: [PATCH] wallet: put Sync and Verifier in their own TaskGroup, and that into interface.group --- electrum/address_synchronizer.py | 52 +++++++++++++++++++------------- electrum/daemon.py | 2 +- electrum/gui/kivy/main_window.py | 2 +- electrum/gui/qt/__init__.py | 2 +- electrum/gui/stdio.py | 2 +- electrum/gui/text.py | 2 +- electrum/interface.py | 22 ++++++-------- electrum/network.py | 14 +-------- electrum/synchronizer.py | 8 ++--- electrum/util.py | 18 ++++++++--- electrum/verifier.py | 12 +++++--- run_electrum | 2 +- 12 files changed, 73 insertions(+), 65 deletions(-) diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index f976b1f55..4793074d4 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -26,11 +26,9 @@ import asyncio import itertools from collections import defaultdict -from aiorpcx import TaskGroup - from . import bitcoin from .bitcoin import COINBASE_MATURITY, TYPE_ADDRESS, TYPE_PUBKEY -from .util import PrintError, profiler, bfh, VerifiedTxInfo, TxMinedStatus +from .util import PrintError, profiler, bfh, VerifiedTxInfo, TxMinedStatus, aiosafe, CustomTaskGroup from .transaction import Transaction, TxOutput from .synchronizer import Synchronizer from .verifier import SPV @@ -62,6 +60,7 @@ class AddressSynchronizer(PrintError): self.synchronizer = None self.verifier = None self.sync_restart_lock = asyncio.Lock() + self.group = None # locks: if you need to take multiple ones, acquire them in the order they are defined here! self.lock = threading.RLock() self.transaction_lock = threading.RLock() @@ -138,34 +137,45 @@ class AddressSynchronizer(PrintError): # add it in case it was previously unconfirmed self.add_unverified_tx(tx_hash, tx_height) - async def on_default_server_changed(self, evt): + @aiosafe + async def on_default_server_changed(self, event): async with self.sync_restart_lock: - interface = self.network.interface - if interface is None: - return # we should get called again soon - self.verifier = SPV(self.network, self) - self.synchronizer = Synchronizer(self) - await interface.group.spawn(self.verifier.main(interface)) - await interface.group.spawn(self.synchronizer.send_subscriptions(interface)) - await interface.group.spawn(self.synchronizer.handle_status(interface)) - await interface.group.spawn(self.synchronizer.main()) + self.stop_threads() + await self._start_threads() - def start_threads(self, network): + def start_network(self, network): self.network = network if self.network is not None: self.network.register_callback(self.on_default_server_changed, ['default_server_changed']) - self.network.trigger_callback('default_server_changed') - else: - self.verifier = None - self.synchronizer = None + asyncio.run_coroutine_threadsafe(self._start_threads(), network.asyncio_loop) + + async def _start_threads(self): + interface = self.network.interface + if interface is None: + return # we should get called again soon + + self.verifier = SPV(self.network, self) + self.synchronizer = synchronizer = Synchronizer(self) + assert self.group is None, 'group already exists' + self.group = CustomTaskGroup() + + async def job(): + async with self.group as group: + await group.spawn(self.verifier.main(group)) + await group.spawn(self.synchronizer.send_subscriptions(group)) + await group.spawn(self.synchronizer.handle_status(group)) + await group.spawn(self.synchronizer.main()) + # we are being cancelled now + interface.session.unsubscribe(synchronizer.status_queue) + await interface.group.spawn(job) def stop_threads(self): if self.network: - #self.network.remove_jobs([self.verifier]) self.synchronizer = None self.verifier = None - # Now no references to the synchronizer or verifier - # remain so they will be GC-ed + if self.group: + asyncio.run_coroutine_threadsafe(self.group.cancel_remaining(), self.network.asyncio_loop) + self.group = None self.storage.put('stored_height', self.get_local_height()) self.save_transactions() self.save_verified_tx() diff --git a/electrum/daemon.py b/electrum/daemon.py index f2a9316e0..c939d1097 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -243,7 +243,7 @@ class Daemon(DaemonThread): if storage.get_action(): return wallet = Wallet(storage) - wallet.start_threads(self.network) + wallet.start_network(self.network) self.wallets[path] = wallet return wallet diff --git a/electrum/gui/kivy/main_window.py b/electrum/gui/kivy/main_window.py index 9ccabfbb8..0dbb5280f 100644 --- a/electrum/gui/kivy/main_window.py +++ b/electrum/gui/kivy/main_window.py @@ -512,7 +512,7 @@ class ElectrumWindow(App): def on_wizard_complete(self, wizard, wallet): if wallet: # wizard returned a wallet - wallet.start_threads(self.daemon.network) + wallet.start_network(self.daemon.network) self.daemon.add_wallet(wallet) self.load_wallet(wallet) elif not self.wallet: diff --git a/electrum/gui/qt/__init__.py b/electrum/gui/qt/__init__.py index b49826c66..3fd204d7d 100644 --- a/electrum/gui/qt/__init__.py +++ b/electrum/gui/qt/__init__.py @@ -236,7 +236,7 @@ class ElectrumGui: if not self.daemon.get_wallet(wallet.storage.path): # wallet was not in memory - wallet.start_threads(self.daemon.network) + wallet.start_network(self.daemon.network) self.daemon.add_wallet(wallet) try: for w in self.windows: diff --git a/electrum/gui/stdio.py b/electrum/gui/stdio.py index d3d16342c..cbb0d3403 100644 --- a/electrum/gui/stdio.py +++ b/electrum/gui/stdio.py @@ -34,7 +34,7 @@ class ElectrumGui: self.str_fee = "" self.wallet = Wallet(storage) - self.wallet.start_threads(self.network) + self.wallet.start_network(self.network) self.contacts = self.wallet.contacts self.network.register_callback(self.on_network, ['updated', 'banner']) diff --git a/electrum/gui/text.py b/electrum/gui/text.py index d00528ccc..981340139 100644 --- a/electrum/gui/text.py +++ b/electrum/gui/text.py @@ -30,7 +30,7 @@ class ElectrumGui: password = getpass.getpass('Password:', stream=None) storage.decrypt(password) self.wallet = Wallet(storage) - self.wallet.start_threads(self.network) + self.wallet.start_network(self.network) self.contacts = self.wallet.contacts locale.setlocale(locale.LC_ALL, '') diff --git a/electrum/interface.py b/electrum/interface.py index e35890004..da2e08ad8 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -24,24 +24,21 @@ # SOFTWARE. import os import re -import socket import ssl import sys import traceback import asyncio -import concurrent.futures from typing import Tuple, Union import aiorpcx -from aiorpcx import ClientSession, Notification, TaskGroup +from aiorpcx import ClientSession, Notification -from .util import PrintError, aiosafe, bfh, AIOSafeSilentException +from .util import PrintError, aiosafe, bfh, AIOSafeSilentException, CustomTaskGroup from . import util from . import x509 from . import pem from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from . import blockchain -from .blockchain import deserialize_header from . import constants @@ -83,6 +80,14 @@ class NotificationSession(ClientSession): self.cache[key] = result await queue.put(params + [result]) + def unsubscribe(self, queue): + """Unsubscribe a callback to free object references to enable GC.""" + # note: we can't unsubscribe from the server, so we keep receiving + # subsequent notifications + for v in self.subscriptions.values(): + if queue in v: + v.remove(queue) + # FIXME this is often raised inside a TaskGroup, but then it's not silent :( class GracefulDisconnect(AIOSafeSilentException): pass @@ -94,13 +99,6 @@ class ErrorParsingSSLCert(Exception): pass class ErrorGettingSSLCertFromServer(Exception): pass -class CustomTaskGroup(TaskGroup): - - def spawn(self, *args, **kwargs): - if self._closed: - raise asyncio.CancelledError() - return super().spawn(*args, **kwargs) - def deserialize_server(server_str: str) -> Tuple[str, str, str]: # host might be IPv6 address, hence do rsplit: diff --git a/electrum/network.py b/electrum/network.py index d2fc8beb8..20ebde59a 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -211,9 +211,6 @@ class Network(PrintError): self.banner = '' self.donation_address = '' self.relay_fee = None - # callbacks passed with subscriptions - self.subscriptions = defaultdict(list) # note: needs self.callback_lock - self.sub_cache = {} # note: needs self.interface_lock # callbacks set by the GUI self.callbacks = defaultdict(list) # note: needs self.callback_lock @@ -272,6 +269,7 @@ class Network(PrintError): callbacks = self.callbacks[event][:] for callback in callbacks: if asyncio.iscoroutinefunction(callback): + # FIXME: if callback throws, we will lose the traceback asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop) else: callback(event, *args) @@ -605,16 +603,6 @@ class Network(PrintError): """ hashable index for subscriptions and cache""" return str(method) + (':' + str(params[0]) if params else '') - def unsubscribe(self, callback): - '''Unsubscribe a callback to free object references to enable GC.''' - # Note: we can't unsubscribe from the server, so if we receive - # subsequent notifications process_response() will emit a harmless - # "received unexpected notification" warning - with self.callback_lock: - for v in self.subscriptions.values(): - if callback in v: - v.remove(callback) - @with_interface_lock def connection_down(self, server): '''A connection to server either went down, or was never made. diff --git a/electrum/synchronizer.py b/electrum/synchronizer.py index 8182f8321..a74961a31 100644 --- a/electrum/synchronizer.py +++ b/electrum/synchronizer.py @@ -144,16 +144,16 @@ class Synchronizer(PrintError): await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue) self.requested_addrs.remove(addr) - async def send_subscriptions(self, interface): + async def send_subscriptions(self, group: TaskGroup): while True: addr = await self.add_queue.get() - await interface.group.spawn(self.subscribe_to_address, addr) + await group.spawn(self.subscribe_to_address, addr) - async def handle_status(self, interface): + async def handle_status(self, group: TaskGroup): while True: h, status = await self.status_queue.get() addr = self.scripthash_to_address[h] - await interface.group.spawn(self.on_address_status, addr, status) + await group.spawn(self.on_address_status, addr, status) @property def session(self): diff --git a/electrum/util.py b/electrum/util.py index d252faa75..9d79ef99a 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -35,14 +35,15 @@ import stat import inspect from locale import localeconv import asyncio - -from .i18n import _ +import urllib.request, urllib.parse, urllib.error +import queue import aiohttp from aiohttp_socks import SocksConnector, SocksVer +from aiorpcx import TaskGroup + +from .i18n import _ -import urllib.request, urllib.parse, urllib.error -import queue def inv_dict(d): return {v: k for k, v in d.items()} @@ -972,3 +973,12 @@ def make_aiohttp_session(proxy): return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10), connector=connector) else: return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10)) + + +class CustomTaskGroup(TaskGroup): + + def spawn(self, *args, **kwargs): + # don't complain if group is already closed. + if self._closed: + raise asyncio.CancelledError() + return super().spawn(*args, **kwargs) diff --git a/electrum/verifier.py b/electrum/verifier.py index 493d402f0..3b502e7af 100644 --- a/electrum/verifier.py +++ b/electrum/verifier.py @@ -24,6 +24,8 @@ import asyncio from typing import Sequence, Optional +from aiorpcx import TaskGroup + from .util import ThreadJob, bh2u, VerifiedTxInfo from .bitcoin import Hash, hash_decode, hash_encode from .transaction import Transaction @@ -47,12 +49,12 @@ class SPV(ThreadJob): self.merkle_roots = {} # txid -> merkle root (once it has been verified) self.requested_merkle = set() # txid set of pending requests - async def main(self, interface): + async def main(self, group: TaskGroup): while True: - await self._request_proofs(interface) + await self._request_proofs(group) await asyncio.sleep(0.1) - async def _request_proofs(self, interface): + async def _request_proofs(self, group: TaskGroup): blockchain = self.network.blockchain() if not blockchain: self.print_error("no blockchain") @@ -70,12 +72,12 @@ class SPV(ThreadJob): if header is None: index = tx_height // 2016 if index < len(blockchain.checkpoints): - await interface.group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True)) + await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True)) elif (tx_hash not in self.requested_merkle and tx_hash not in self.merkle_roots): self.print_error('requested merkle', tx_hash) self.requested_merkle.add(tx_hash) - await interface.group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height) + await group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height) if self.network.blockchain() != self.blockchain: self.blockchain = self.network.blockchain() diff --git a/run_electrum b/run_electrum index 3db004e22..0a43adbcf 100755 --- a/run_electrum +++ b/run_electrum @@ -135,7 +135,7 @@ def run_non_RPC(config): if not config.get('offline'): network = Network(config) network.start() - wallet.start_threads(network) + wallet.start_network(network) print_msg("Recovering wallet...") wallet.synchronize() wallet.wait_until_synchronized()