Mirror of https://github.com/LBRYFoundation/LBRY-Vault.git
wallet: put Sync and Verifier in their own TaskGroup, and that into interface.group
parent 09dfb0fd1d
commit e829d6bbcf

12 changed files with 73 additions and 65 deletions
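The change in one picture: the wallet's Synchronizer and SPV verifier are no longer spawned straight into interface.group; they go into the wallet's own CustomTaskGroup (self.group), and that group is submitted to interface.group as a single job() task. Cancelling self.group (wallet closes, server switch) tears down only the wallet's tasks, while a dying interface still takes the whole sub-group down with it. A minimal, self-contained sketch of that nesting, using aiorpcx's TaskGroup as the diff does — the worker names and timings are illustrative only, not code from the commit:

# Sketch of the nesting this commit introduces (illustrative names; only the
# TaskGroup API usage mirrors the diff).
import asyncio
from aiorpcx import TaskGroup

async def worker(name):
    # stands in for Synchronizer.main / SPV.main
    while True:
        await asyncio.sleep(1)

async def main():
    parent = TaskGroup()   # plays the role of interface.group
    sub = TaskGroup()      # plays the role of AddressSynchronizer.group

    async def job():
        # run the wallet's own group as one task inside the parent group
        async with sub as group:
            await group.spawn(worker('synchronizer'))
            await group.spawn(worker('verifier'))
        # reached once `sub` has been cancelled or has finished

    async with parent:
        await parent.spawn(job)
        await asyncio.sleep(2)
        # cancel only the wallet's sub-group; the parent group stays usable
        await sub.cancel_remaining()

asyncio.run(main())

In the diff below, stop_threads() performs exactly the cancel_remaining() step, scheduled onto the network's event loop.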
@@ -26,11 +26,9 @@ import asyncio
 import itertools
 from collections import defaultdict
 
-from aiorpcx import TaskGroup
-
 from . import bitcoin
 from .bitcoin import COINBASE_MATURITY, TYPE_ADDRESS, TYPE_PUBKEY
-from .util import PrintError, profiler, bfh, VerifiedTxInfo, TxMinedStatus
+from .util import PrintError, profiler, bfh, VerifiedTxInfo, TxMinedStatus, aiosafe, CustomTaskGroup
 from .transaction import Transaction, TxOutput
 from .synchronizer import Synchronizer
 from .verifier import SPV
@@ -62,6 +60,7 @@ class AddressSynchronizer(PrintError):
         self.synchronizer = None
         self.verifier = None
         self.sync_restart_lock = asyncio.Lock()
+        self.group = None
         # locks: if you need to take multiple ones, acquire them in the order they are defined here!
         self.lock = threading.RLock()
         self.transaction_lock = threading.RLock()
@@ -138,34 +137,45 @@ class AddressSynchronizer(PrintError):
             # add it in case it was previously unconfirmed
             self.add_unverified_tx(tx_hash, tx_height)
 
-    async def on_default_server_changed(self, evt):
+    @aiosafe
+    async def on_default_server_changed(self, event):
         async with self.sync_restart_lock:
-            interface = self.network.interface
-            if interface is None:
-                return  # we should get called again soon
-            self.verifier = SPV(self.network, self)
-            self.synchronizer = Synchronizer(self)
-            await interface.group.spawn(self.verifier.main(interface))
-            await interface.group.spawn(self.synchronizer.send_subscriptions(interface))
-            await interface.group.spawn(self.synchronizer.handle_status(interface))
-            await interface.group.spawn(self.synchronizer.main())
+            self.stop_threads()
+            await self._start_threads()
 
-    def start_threads(self, network):
+    def start_network(self, network):
         self.network = network
         if self.network is not None:
             self.network.register_callback(self.on_default_server_changed, ['default_server_changed'])
-            self.network.trigger_callback('default_server_changed')
-        else:
-            self.verifier = None
-            self.synchronizer = None
+            asyncio.run_coroutine_threadsafe(self._start_threads(), network.asyncio_loop)
+
+    async def _start_threads(self):
+        interface = self.network.interface
+        if interface is None:
+            return  # we should get called again soon
+
+        self.verifier = SPV(self.network, self)
+        self.synchronizer = synchronizer = Synchronizer(self)
+        assert self.group is None, 'group already exists'
+        self.group = CustomTaskGroup()
+
+        async def job():
+            async with self.group as group:
+                await group.spawn(self.verifier.main(group))
+                await group.spawn(self.synchronizer.send_subscriptions(group))
+                await group.spawn(self.synchronizer.handle_status(group))
+                await group.spawn(self.synchronizer.main())
+            # we are being cancelled now
+            interface.session.unsubscribe(synchronizer.status_queue)
+        await interface.group.spawn(job)
 
     def stop_threads(self):
         if self.network:
-            #self.network.remove_jobs([self.verifier])
             self.synchronizer = None
             self.verifier = None
-            # Now no references to the synchronizer or verifier
-            # remain so they will be GC-ed
+            if self.group:
+                asyncio.run_coroutine_threadsafe(self.group.cancel_remaining(), self.network.asyncio_loop)
+                self.group = None
         self.storage.put('stored_height', self.get_local_height())
         self.save_transactions()
         self.save_verified_tx()
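Both new entry points cross a thread boundary: start_network() and stop_threads() are called from GUI/daemon threads, so they hand their coroutines to the network's asyncio loop with asyncio.run_coroutine_threadsafe rather than awaiting them. A stripped-down illustration of that hand-off — the loop, thread and function here are stand-ins, not the wallet code:

# Sketch of the thread -> event-loop hand-off (loop/thread stand in for
# network.asyncio_loop and the network thread).
import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def _start_threads():
    print('running on the network thread')

# called from a GUI/daemon thread; returns a concurrent.futures.Future
future = asyncio.run_coroutine_threadsafe(_start_threads(), loop)
future.result(timeout=5)          # optional: block the caller until done
loop.call_soon_threadsafe(loop.stop)

The returned concurrent.futures.Future can be ignored, as the diff does, or waited on when the caller needs the result.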
@@ -243,7 +243,7 @@ class Daemon(DaemonThread):
         if storage.get_action():
             return
         wallet = Wallet(storage)
-        wallet.start_threads(self.network)
+        wallet.start_network(self.network)
         self.wallets[path] = wallet
         return wallet
 
@@ -512,7 +512,7 @@ class ElectrumWindow(App):
 
     def on_wizard_complete(self, wizard, wallet):
         if wallet:  # wizard returned a wallet
-            wallet.start_threads(self.daemon.network)
+            wallet.start_network(self.daemon.network)
             self.daemon.add_wallet(wallet)
             self.load_wallet(wallet)
         elif not self.wallet:
@@ -236,7 +236,7 @@ class ElectrumGui:
 
         if not self.daemon.get_wallet(wallet.storage.path):
             # wallet was not in memory
-            wallet.start_threads(self.daemon.network)
+            wallet.start_network(self.daemon.network)
             self.daemon.add_wallet(wallet)
         try:
             for w in self.windows:
@@ -34,7 +34,7 @@ class ElectrumGui:
         self.str_fee = ""
 
         self.wallet = Wallet(storage)
-        self.wallet.start_threads(self.network)
+        self.wallet.start_network(self.network)
         self.contacts = self.wallet.contacts
 
         self.network.register_callback(self.on_network, ['updated', 'banner'])
@@ -30,7 +30,7 @@ class ElectrumGui:
             password = getpass.getpass('Password:', stream=None)
             storage.decrypt(password)
         self.wallet = Wallet(storage)
-        self.wallet.start_threads(self.network)
+        self.wallet.start_network(self.network)
         self.contacts = self.wallet.contacts
 
         locale.setlocale(locale.LC_ALL, '')
@@ -24,24 +24,21 @@
 # SOFTWARE.
 import os
 import re
 import socket
 import ssl
 import sys
 import traceback
 import asyncio
 import concurrent.futures
 from typing import Tuple, Union
 
 import aiorpcx
-from aiorpcx import ClientSession, Notification, TaskGroup
+from aiorpcx import ClientSession, Notification
 
-from .util import PrintError, aiosafe, bfh, AIOSafeSilentException
+from .util import PrintError, aiosafe, bfh, AIOSafeSilentException, CustomTaskGroup
 from . import util
 from . import x509
 from . import pem
 from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
 from . import blockchain
 from .blockchain import deserialize_header
 from . import constants
 
 
@@ -83,6 +80,14 @@ class NotificationSession(ClientSession):
             self.cache[key] = result
         await queue.put(params + [result])
 
+    def unsubscribe(self, queue):
+        """Unsubscribe a callback to free object references to enable GC."""
+        # note: we can't unsubscribe from the server, so we keep receiving
+        # subsequent notifications
+        for v in self.subscriptions.values():
+            if queue in v:
+                v.remove(queue)
+
 
 # FIXME this is often raised inside a TaskGroup, but then it's not silent :(
 class GracefulDisconnect(AIOSafeSilentException): pass
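NotificationSession.unsubscribe only drops the local reference to a status queue; the server keeps sending notifications, which are then simply ignored. That is what lets a torn-down Synchronizer (and its status_queue) be garbage-collected after job() finishes. A simplified illustration of that bookkeeping — just the shape of the subscriptions registry, not the real session internals:

# Simplified illustration (not the real NotificationSession internals).
import asyncio
from collections import defaultdict

subscriptions = defaultdict(list)   # subscription key -> interested queues

def unsubscribe(queue):
    # we cannot unsubscribe server-side; just drop our reference so the
    # queue (and whatever holds it) can be garbage-collected
    for v in subscriptions.values():
        if queue in v:
            v.remove(queue)

status_queue = asyncio.Queue()
subscriptions['blockchain.scripthash.subscribe:...'].append(status_queue)
unsubscribe(status_queue)           # later notifications are simply dropped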
@@ -94,13 +99,6 @@ class ErrorParsingSSLCert(Exception): pass
 class ErrorGettingSSLCertFromServer(Exception): pass
 
 
-class CustomTaskGroup(TaskGroup):
-
-    def spawn(self, *args, **kwargs):
-        if self._closed:
-            raise asyncio.CancelledError()
-        return super().spawn(*args, **kwargs)
-
 
 def deserialize_server(server_str: str) -> Tuple[str, str, str]:
     # host might be IPv6 address, hence do rsplit:
@@ -211,9 +211,6 @@ class Network(PrintError):
         self.banner = ''
         self.donation_address = ''
         self.relay_fee = None
-        # callbacks passed with subscriptions
-        self.subscriptions = defaultdict(list)  # note: needs self.callback_lock
-        self.sub_cache = {}  # note: needs self.interface_lock
         # callbacks set by the GUI
         self.callbacks = defaultdict(list)  # note: needs self.callback_lock
@@ -272,6 +269,7 @@ class Network(PrintError):
             callbacks = self.callbacks[event][:]
         for callback in callbacks:
             if asyncio.iscoroutinefunction(callback):
+                # FIXME: if callback throws, we will lose the traceback
                 asyncio.run_coroutine_threadsafe(callback(event, *args), self.asyncio_loop)
             else:
                 callback(event, *args)
@@ -605,16 +603,6 @@ class Network(PrintError):
         """ hashable index for subscriptions and cache"""
         return str(method) + (':' + str(params[0]) if params else '')
 
-    def unsubscribe(self, callback):
-        '''Unsubscribe a callback to free object references to enable GC.'''
-        # Note: we can't unsubscribe from the server, so if we receive
-        # subsequent notifications process_response() will emit a harmless
-        # "received unexpected notification" warning
-        with self.callback_lock:
-            for v in self.subscriptions.values():
-                if callback in v:
-                    v.remove(callback)
-
     @with_interface_lock
     def connection_down(self, server):
         '''A connection to server either went down, or was never made.
@@ -144,16 +144,16 @@ class Synchronizer(PrintError):
         await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue)
         self.requested_addrs.remove(addr)
 
-    async def send_subscriptions(self, interface):
+    async def send_subscriptions(self, group: TaskGroup):
         while True:
             addr = await self.add_queue.get()
-            await interface.group.spawn(self.subscribe_to_address, addr)
+            await group.spawn(self.subscribe_to_address, addr)
 
-    async def handle_status(self, interface):
+    async def handle_status(self, group: TaskGroup):
         while True:
             h, status = await self.status_queue.get()
             addr = self.scripthash_to_address[h]
-            await interface.group.spawn(self.on_address_status, addr, status)
+            await group.spawn(self.on_address_status, addr, status)
 
     @property
     def session(self):
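send_subscriptions and handle_status keep their shape — an endless loop pulling from an asyncio queue — but now spawn each unit of work into whatever TaskGroup they were handed instead of reaching for interface.group. A small runnable sketch of that consumer pattern, with illustrative names (the real handlers are subscribe_to_address and on_address_status):

# Runnable sketch of the queue-consumer pattern (illustrative names).
import asyncio
from aiorpcx import TaskGroup

async def handle(item):
    print('handled', item)

async def consume(queue, group):
    while True:
        item = await queue.get()
        # one task per unit of work, spawned into the caller-supplied group
        await group.spawn(handle, item)

async def main():
    queue = asyncio.Queue()
    for i in range(3):
        queue.put_nowait(i)
    async with TaskGroup() as group:
        await group.spawn(consume, queue, group)
        await asyncio.sleep(0.1)
        await group.cancel_remaining()

asyncio.run(main())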
@@ -35,14 +35,15 @@ import stat
 import inspect
 from locale import localeconv
 import asyncio
 
-from .i18n import _
-import urllib.request, urllib.parse, urllib.error
-import queue
-
 import aiohttp
 from aiohttp_socks import SocksConnector, SocksVer
+from aiorpcx import TaskGroup
+
+from .i18n import _
+
+import urllib.request, urllib.parse, urllib.error
+import queue
 
 def inv_dict(d):
     return {v: k for k, v in d.items()}
@@ -972,3 +973,12 @@ def make_aiohttp_session(proxy):
         return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10), connector=connector)
     else:
         return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10))
+
+
+class CustomTaskGroup(TaskGroup):
+
+    def spawn(self, *args, **kwargs):
+        # don't complain if group is already closed.
+        if self._closed:
+            raise asyncio.CancelledError()
+        return super().spawn(*args, **kwargs)
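CustomTaskGroup (moved here from interface.py) changes one behaviour: spawning onto a group that has already been closed raises CancelledError instead of whatever complaint stock aiorpcx would make. Loops such as Synchronizer.send_subscriptions, which sit in a `while True: await group.spawn(...)` pattern, therefore unwind like any other cancelled task once their group is torn down. A sketch of that effect, under the assumption (which the `_closed` check above relies on) that aiorpcx marks a TaskGroup closed once it has been joined:

# Sketch only; assumes the group is marked closed after it has been joined.
import asyncio
from aiorpcx import TaskGroup

class CustomTaskGroup(TaskGroup):
    def spawn(self, *args, **kwargs):
        # don't complain if group is already closed.
        if self._closed:
            raise asyncio.CancelledError()
        return super().spawn(*args, **kwargs)

async def main():
    group = CustomTaskGroup()
    async with group:
        await group.spawn(asyncio.sleep, 0)
    # the group has been joined and is closed at this point (assumption)
    try:
        await group.spawn(asyncio.sleep, 0)
    except asyncio.CancelledError:
        print('late spawn is treated as a cancellation, not an error')

asyncio.run(main())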
@@ -24,6 +24,8 @@
 import asyncio
 from typing import Sequence, Optional
 
+from aiorpcx import TaskGroup
+
 from .util import ThreadJob, bh2u, VerifiedTxInfo
 from .bitcoin import Hash, hash_decode, hash_encode
 from .transaction import Transaction
@@ -47,12 +49,12 @@ class SPV(ThreadJob):
         self.merkle_roots = {}  # txid -> merkle root (once it has been verified)
         self.requested_merkle = set()  # txid set of pending requests
 
-    async def main(self, interface):
+    async def main(self, group: TaskGroup):
         while True:
-            await self._request_proofs(interface)
+            await self._request_proofs(group)
             await asyncio.sleep(0.1)
 
-    async def _request_proofs(self, interface):
+    async def _request_proofs(self, group: TaskGroup):
         blockchain = self.network.blockchain()
         if not blockchain:
             self.print_error("no blockchain")
@@ -70,12 +72,12 @@ class SPV(ThreadJob):
             if header is None:
                 index = tx_height // 2016
                 if index < len(blockchain.checkpoints):
-                    await interface.group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
+                    await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
             elif (tx_hash not in self.requested_merkle
                   and tx_hash not in self.merkle_roots):
                 self.print_error('requested merkle', tx_hash)
                 self.requested_merkle.add(tx_hash)
-                await interface.group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
+                await group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
 
         if self.network.blockchain() != self.blockchain:
             self.blockchain = self.network.blockchain()
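SPV.main() keeps its polling shape: wake up every 0.1 s, see whether any transactions still need merkle proofs, and spawn the requests into the group it was given. A compact sketch of that polling-and-spawning loop (fetch_proof and the pending list are placeholders, not verifier code):

# Compact sketch of a polling loop that spawns work into a passed-in group.
import asyncio
from aiorpcx import TaskGroup

async def fetch_proof(tx):
    await asyncio.sleep(0)            # placeholder for the network round-trip

async def spv_main(group, pending):
    while True:
        while pending:
            await group.spawn(fetch_proof, pending.pop())
        await asyncio.sleep(0.1)

async def main():
    async with TaskGroup() as group:
        await group.spawn(spv_main, group, ['tx1', 'tx2'])
        await asyncio.sleep(0.3)
        await group.cancel_remaining()

asyncio.run(main())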
@@ -135,7 +135,7 @@ def run_non_RPC(config):
         if not config.get('offline'):
             network = Network(config)
             network.start()
-            wallet.start_threads(network)
+            wallet.start_network(network)
             print_msg("Recovering wallet...")
             wallet.synchronize()
             wallet.wait_until_synchronized()