mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
restructure synchronizer
fix CLI notify cmd. fix merchant websockets.
This commit is contained in:
parent
788b5b04fe
commit
02f108d927
9 changed files with 249 additions and 178 deletions
|
@ -56,11 +56,9 @@ class AddressSynchronizer(PrintError):
|
||||||
def __init__(self, storage):
|
def __init__(self, storage):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.network = None
|
self.network = None
|
||||||
# verifier (SPV) and synchronizer are started in start_threads
|
# verifier (SPV) and synchronizer are started in start_network
|
||||||
self.synchronizer = None
|
self.synchronizer = None # type: Synchronizer
|
||||||
self.verifier = None
|
self.verifier = None # type: SPV
|
||||||
self.sync_restart_lock = asyncio.Lock()
|
|
||||||
self.group = None
|
|
||||||
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
|
# locks: if you need to take multiple ones, acquire them in the order they are defined here!
|
||||||
self.lock = threading.RLock()
|
self.lock = threading.RLock()
|
||||||
self.transaction_lock = threading.RLock()
|
self.transaction_lock = threading.RLock()
|
||||||
|
@ -143,45 +141,20 @@ class AddressSynchronizer(PrintError):
|
||||||
# add it in case it was previously unconfirmed
|
# add it in case it was previously unconfirmed
|
||||||
self.add_unverified_tx(tx_hash, tx_height)
|
self.add_unverified_tx(tx_hash, tx_height)
|
||||||
|
|
||||||
@aiosafe
|
|
||||||
async def on_default_server_changed(self, event):
|
|
||||||
async with self.sync_restart_lock:
|
|
||||||
self.stop_threads(write_to_disk=False)
|
|
||||||
await self._start_threads()
|
|
||||||
|
|
||||||
def start_network(self, network):
|
def start_network(self, network):
|
||||||
self.network = network
|
self.network = network
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
self.network.register_callback(self.on_default_server_changed, ['default_server_changed'])
|
self.synchronizer = Synchronizer(self)
|
||||||
asyncio.run_coroutine_threadsafe(self._start_threads(), network.asyncio_loop)
|
|
||||||
|
|
||||||
async def _start_threads(self):
|
|
||||||
interface = self.network.interface
|
|
||||||
if interface is None:
|
|
||||||
return # we should get called again soon
|
|
||||||
|
|
||||||
self.verifier = SPV(self.network, self)
|
self.verifier = SPV(self.network, self)
|
||||||
self.synchronizer = synchronizer = Synchronizer(self)
|
|
||||||
assert self.group is None, 'group already exists'
|
|
||||||
self.group = SilentTaskGroup()
|
|
||||||
|
|
||||||
async def job():
|
|
||||||
async with self.group as group:
|
|
||||||
await group.spawn(self.verifier.main(group))
|
|
||||||
await group.spawn(self.synchronizer.send_subscriptions(group))
|
|
||||||
await group.spawn(self.synchronizer.handle_status(group))
|
|
||||||
await group.spawn(self.synchronizer.main())
|
|
||||||
# we are being cancelled now
|
|
||||||
interface.session.unsubscribe(synchronizer.status_queue)
|
|
||||||
await interface.group.spawn(job)
|
|
||||||
|
|
||||||
def stop_threads(self, write_to_disk=True):
|
def stop_threads(self, write_to_disk=True):
|
||||||
if self.network:
|
if self.network:
|
||||||
|
if self.synchronizer:
|
||||||
|
asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop)
|
||||||
self.synchronizer = None
|
self.synchronizer = None
|
||||||
|
if self.verifier:
|
||||||
|
asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
|
||||||
self.verifier = None
|
self.verifier = None
|
||||||
if self.group:
|
|
||||||
asyncio.run_coroutine_threadsafe(self.group.cancel_remaining(), self.network.asyncio_loop)
|
|
||||||
self.group = None
|
|
||||||
self.storage.put('stored_height', self.get_local_height())
|
self.storage.put('stored_height', self.get_local_height())
|
||||||
if write_to_disk:
|
if write_to_disk:
|
||||||
self.save_transactions()
|
self.save_transactions()
|
||||||
|
|
|
@ -40,7 +40,7 @@ from .bitcoin import is_address, hash_160, COIN, TYPE_ADDRESS
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
from .transaction import Transaction, multisig_script, TxOutput
|
from .transaction import Transaction, multisig_script, TxOutput
|
||||||
from .paymentrequest import PR_PAID, PR_UNPAID, PR_UNKNOWN, PR_EXPIRED
|
from .paymentrequest import PR_PAID, PR_UNPAID, PR_UNKNOWN, PR_EXPIRED
|
||||||
from .plugin import run_hook
|
from .synchronizer import Notifier
|
||||||
|
|
||||||
known_commands = {}
|
known_commands = {}
|
||||||
|
|
||||||
|
@ -635,21 +635,11 @@ class Commands:
|
||||||
self.wallet.remove_payment_request(k, self.config)
|
self.wallet.remove_payment_request(k, self.config)
|
||||||
|
|
||||||
@command('n')
|
@command('n')
|
||||||
def notify(self, address, URL):
|
def notify(self, address: str, URL: str):
|
||||||
"""Watch an address. Every time the address changes, a http POST is sent to the URL."""
|
"""Watch an address. Every time the address changes, a http POST is sent to the URL."""
|
||||||
raise NotImplementedError() # TODO this method is currently broken
|
if not hasattr(self, "_notifier"):
|
||||||
def callback(x):
|
self._notifier = Notifier(self.network)
|
||||||
import urllib.request
|
self.network.run_from_another_thread(self._notifier.start_watching_queue.put((address, URL)))
|
||||||
headers = {'content-type':'application/json'}
|
|
||||||
data = {'address':address, 'status':x.get('result')}
|
|
||||||
serialized_data = util.to_bytes(json.dumps(data))
|
|
||||||
try:
|
|
||||||
req = urllib.request.Request(URL, serialized_data, headers)
|
|
||||||
response_stream = urllib.request.urlopen(req, timeout=5)
|
|
||||||
util.print_error('Got Response for %s' % address)
|
|
||||||
except BaseException as e:
|
|
||||||
util.print_error(str(e))
|
|
||||||
self.network.subscribe_to_addresses([address], callback)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@command('wn')
|
@command('wn')
|
||||||
|
|
|
@ -28,7 +28,7 @@ import ssl
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union, List
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import aiorpcx
|
import aiorpcx
|
||||||
|
@ -57,7 +57,7 @@ class NotificationSession(ClientSession):
|
||||||
# will catch the exception, count errors, and at some point disconnect
|
# will catch the exception, count errors, and at some point disconnect
|
||||||
if isinstance(request, Notification):
|
if isinstance(request, Notification):
|
||||||
params, result = request.args[:-1], request.args[-1]
|
params, result = request.args[:-1], request.args[-1]
|
||||||
key = self.get_index(request.method, params)
|
key = self.get_hashable_key_for_rpc_call(request.method, params)
|
||||||
if key in self.subscriptions:
|
if key in self.subscriptions:
|
||||||
self.cache[key] = result
|
self.cache[key] = result
|
||||||
for queue in self.subscriptions[key]:
|
for queue in self.subscriptions[key]:
|
||||||
|
@ -78,10 +78,10 @@ class NotificationSession(ClientSession):
|
||||||
except asyncio.TimeoutError as e:
|
except asyncio.TimeoutError as e:
|
||||||
raise RequestTimedOut('request timed out: {}'.format(args)) from e
|
raise RequestTimedOut('request timed out: {}'.format(args)) from e
|
||||||
|
|
||||||
async def subscribe(self, method, params, queue):
|
async def subscribe(self, method: str, params: List, queue: asyncio.Queue):
|
||||||
# note: until the cache is written for the first time,
|
# note: until the cache is written for the first time,
|
||||||
# each 'subscribe' call might make a request on the network.
|
# each 'subscribe' call might make a request on the network.
|
||||||
key = self.get_index(method, params)
|
key = self.get_hashable_key_for_rpc_call(method, params)
|
||||||
self.subscriptions[key].append(queue)
|
self.subscriptions[key].append(queue)
|
||||||
if key in self.cache:
|
if key in self.cache:
|
||||||
result = self.cache[key]
|
result = self.cache[key]
|
||||||
|
@ -99,7 +99,7 @@ class NotificationSession(ClientSession):
|
||||||
v.remove(queue)
|
v.remove(queue)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_index(cls, method, params):
|
def get_hashable_key_for_rpc_call(cls, method, params):
|
||||||
"""Hashable index for subscriptions and cache"""
|
"""Hashable index for subscriptions and cache"""
|
||||||
return str(method) + repr(params)
|
return str(method) + repr(params)
|
||||||
|
|
||||||
|
@ -141,7 +141,7 @@ class Interface(PrintError):
|
||||||
self._requested_chunks = set()
|
self._requested_chunks = set()
|
||||||
self.network = network
|
self.network = network
|
||||||
self._set_proxy(proxy)
|
self._set_proxy(proxy)
|
||||||
self.session = None
|
self.session = None # type: NotificationSession
|
||||||
|
|
||||||
self.tip_header = None
|
self.tip_header = None
|
||||||
self.tip = 0
|
self.tip = 0
|
||||||
|
|
|
@ -852,3 +852,54 @@ class Network(PrintError):
|
||||||
await self.interface.group.spawn(self._request_fee_estimates, self.interface)
|
await self.interface.group.spawn(self._request_fee_estimates, self.interface)
|
||||||
|
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkJobOnDefaultServer(PrintError):
|
||||||
|
"""An abstract base class for a job that runs on the main network
|
||||||
|
interface. Every time the main interface changes, the job is
|
||||||
|
restarted, and some of its internals are reset.
|
||||||
|
"""
|
||||||
|
def __init__(self, network: Network):
|
||||||
|
asyncio.set_event_loop(network.asyncio_loop)
|
||||||
|
self.network = network
|
||||||
|
self.interface = None # type: Interface
|
||||||
|
self._restart_lock = asyncio.Lock()
|
||||||
|
self._reset()
|
||||||
|
asyncio.run_coroutine_threadsafe(self._restart(), network.asyncio_loop)
|
||||||
|
network.register_callback(self._restart, ['default_server_changed'])
|
||||||
|
|
||||||
|
def _reset(self):
|
||||||
|
"""Initialise fields. Called every time the underlying
|
||||||
|
server connection changes.
|
||||||
|
"""
|
||||||
|
self.group = SilentTaskGroup()
|
||||||
|
|
||||||
|
async def _start(self, interface):
|
||||||
|
self.interface = interface
|
||||||
|
await interface.group.spawn(self._start_tasks)
|
||||||
|
|
||||||
|
async def _start_tasks(self):
|
||||||
|
"""Start tasks in self.group. Called every time the underlying
|
||||||
|
server connection changes.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError() # implemented by subclasses
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await self.group.cancel_remaining()
|
||||||
|
|
||||||
|
@aiosafe
|
||||||
|
async def _restart(self, *args):
|
||||||
|
interface = self.network.interface
|
||||||
|
if interface is None:
|
||||||
|
return # we should get called again soon
|
||||||
|
|
||||||
|
async with self._restart_lock:
|
||||||
|
await self.stop()
|
||||||
|
self._reset()
|
||||||
|
await self._start(interface)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
s = self.interface.session
|
||||||
|
assert s is not None
|
||||||
|
return s
|
||||||
|
|
|
@ -24,12 +24,15 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from typing import Dict, List
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from aiorpcx import TaskGroup, run_in_thread
|
from aiorpcx import TaskGroup, run_in_thread
|
||||||
|
|
||||||
from .transaction import Transaction
|
from .transaction import Transaction
|
||||||
from .util import bh2u, PrintError
|
from .util import bh2u, make_aiohttp_session
|
||||||
from .bitcoin import address_to_scripthash
|
from .bitcoin import address_to_scripthash
|
||||||
|
from .network import NetworkJobOnDefaultServer
|
||||||
|
|
||||||
|
|
||||||
def history_status(h):
|
def history_status(h):
|
||||||
|
@ -41,7 +44,68 @@ def history_status(h):
|
||||||
return bh2u(hashlib.sha256(status.encode('ascii')).digest())
|
return bh2u(hashlib.sha256(status.encode('ascii')).digest())
|
||||||
|
|
||||||
|
|
||||||
class Synchronizer(PrintError):
|
class SynchronizerBase(NetworkJobOnDefaultServer):
|
||||||
|
"""Subscribe over the network to a set of addresses, and monitor their statuses.
|
||||||
|
Every time a status changes, run a coroutine provided by the subclass.
|
||||||
|
"""
|
||||||
|
def __init__(self, network):
|
||||||
|
NetworkJobOnDefaultServer.__init__(self, network)
|
||||||
|
self.asyncio_loop = network.asyncio_loop
|
||||||
|
|
||||||
|
def _reset(self):
|
||||||
|
super()._reset()
|
||||||
|
self.requested_addrs = set()
|
||||||
|
self.scripthash_to_address = {}
|
||||||
|
self._processed_some_notifications = False # so that we don't miss them
|
||||||
|
# Queues
|
||||||
|
self.add_queue = asyncio.Queue()
|
||||||
|
self.status_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def _start_tasks(self):
|
||||||
|
try:
|
||||||
|
async with self.group as group:
|
||||||
|
await group.spawn(self.send_subscriptions())
|
||||||
|
await group.spawn(self.handle_status())
|
||||||
|
await group.spawn(self.main())
|
||||||
|
finally:
|
||||||
|
# we are being cancelled now
|
||||||
|
self.session.unsubscribe(self.status_queue)
|
||||||
|
|
||||||
|
def add(self, addr):
|
||||||
|
asyncio.run_coroutine_threadsafe(self._add_address(addr), self.asyncio_loop)
|
||||||
|
|
||||||
|
async def _add_address(self, addr):
|
||||||
|
if addr in self.requested_addrs: return
|
||||||
|
self.requested_addrs.add(addr)
|
||||||
|
await self.add_queue.put(addr)
|
||||||
|
|
||||||
|
async def _on_address_status(self, addr, status):
|
||||||
|
"""Handle the change of the status of an address."""
|
||||||
|
raise NotImplementedError() # implemented by subclasses
|
||||||
|
|
||||||
|
async def send_subscriptions(self):
|
||||||
|
async def subscribe_to_address(addr):
|
||||||
|
h = address_to_scripthash(addr)
|
||||||
|
self.scripthash_to_address[h] = addr
|
||||||
|
await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue)
|
||||||
|
self.requested_addrs.remove(addr)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
addr = await self.add_queue.get()
|
||||||
|
await self.group.spawn(subscribe_to_address, addr)
|
||||||
|
|
||||||
|
async def handle_status(self):
|
||||||
|
while True:
|
||||||
|
h, status = await self.status_queue.get()
|
||||||
|
addr = self.scripthash_to_address[h]
|
||||||
|
await self.group.spawn(self._on_address_status, addr, status)
|
||||||
|
self._processed_some_notifications = True
|
||||||
|
|
||||||
|
async def main(self):
|
||||||
|
raise NotImplementedError() # implemented by subclasses
|
||||||
|
|
||||||
|
|
||||||
|
class Synchronizer(SynchronizerBase):
|
||||||
'''The synchronizer keeps the wallet up-to-date with its set of
|
'''The synchronizer keeps the wallet up-to-date with its set of
|
||||||
addresses and their transactions. It subscribes over the network
|
addresses and their transactions. It subscribes over the network
|
||||||
to wallet addresses, gets the wallet to generate new addresses
|
to wallet addresses, gets the wallet to generate new addresses
|
||||||
|
@ -51,16 +115,12 @@ class Synchronizer(PrintError):
|
||||||
'''
|
'''
|
||||||
def __init__(self, wallet):
|
def __init__(self, wallet):
|
||||||
self.wallet = wallet
|
self.wallet = wallet
|
||||||
self.network = wallet.network
|
SynchronizerBase.__init__(self, wallet.network)
|
||||||
self.asyncio_loop = wallet.network.asyncio_loop
|
|
||||||
|
def _reset(self):
|
||||||
|
super()._reset()
|
||||||
self.requested_tx = {}
|
self.requested_tx = {}
|
||||||
self.requested_histories = {}
|
self.requested_histories = {}
|
||||||
self.requested_addrs = set()
|
|
||||||
self.scripthash_to_address = {}
|
|
||||||
self._processed_some_notifications = False # so that we don't miss them
|
|
||||||
# Queues
|
|
||||||
self.add_queue = asyncio.Queue()
|
|
||||||
self.status_queue = asyncio.Queue()
|
|
||||||
|
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
|
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
|
||||||
|
@ -70,14 +130,6 @@ class Synchronizer(PrintError):
|
||||||
and not self.requested_histories
|
and not self.requested_histories
|
||||||
and not self.requested_tx)
|
and not self.requested_tx)
|
||||||
|
|
||||||
def add(self, addr):
|
|
||||||
asyncio.run_coroutine_threadsafe(self._add(addr), self.asyncio_loop)
|
|
||||||
|
|
||||||
async def _add(self, addr):
|
|
||||||
if addr in self.requested_addrs: return
|
|
||||||
self.requested_addrs.add(addr)
|
|
||||||
await self.add_queue.put(addr)
|
|
||||||
|
|
||||||
async def _on_address_status(self, addr, status):
|
async def _on_address_status(self, addr, status):
|
||||||
history = self.wallet.history.get(addr, [])
|
history = self.wallet.history.get(addr, [])
|
||||||
if history_status(history) == status:
|
if history_status(history) == status:
|
||||||
|
@ -144,30 +196,6 @@ class Synchronizer(PrintError):
|
||||||
# callbacks
|
# callbacks
|
||||||
self.wallet.network.trigger_callback('new_transaction', self.wallet, tx)
|
self.wallet.network.trigger_callback('new_transaction', self.wallet, tx)
|
||||||
|
|
||||||
async def send_subscriptions(self, group: TaskGroup):
|
|
||||||
async def subscribe_to_address(addr):
|
|
||||||
h = address_to_scripthash(addr)
|
|
||||||
self.scripthash_to_address[h] = addr
|
|
||||||
await self.session.subscribe('blockchain.scripthash.subscribe', [h], self.status_queue)
|
|
||||||
self.requested_addrs.remove(addr)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
addr = await self.add_queue.get()
|
|
||||||
await group.spawn(subscribe_to_address, addr)
|
|
||||||
|
|
||||||
async def handle_status(self, group: TaskGroup):
|
|
||||||
while True:
|
|
||||||
h, status = await self.status_queue.get()
|
|
||||||
addr = self.scripthash_to_address[h]
|
|
||||||
await group.spawn(self._on_address_status, addr, status)
|
|
||||||
self._processed_some_notifications = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def session(self):
|
|
||||||
s = self.wallet.network.interface.session
|
|
||||||
assert s is not None
|
|
||||||
return s
|
|
||||||
|
|
||||||
async def main(self):
|
async def main(self):
|
||||||
self.wallet.set_up_to_date(False)
|
self.wallet.set_up_to_date(False)
|
||||||
# request missing txns, if any
|
# request missing txns, if any
|
||||||
|
@ -178,7 +206,7 @@ class Synchronizer(PrintError):
|
||||||
await self._request_missing_txs(history)
|
await self._request_missing_txs(history)
|
||||||
# add addresses to bootstrap
|
# add addresses to bootstrap
|
||||||
for addr in self.wallet.get_addresses():
|
for addr in self.wallet.get_addresses():
|
||||||
await self._add(addr)
|
await self._add_address(addr)
|
||||||
# main loop
|
# main loop
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
@ -189,3 +217,37 @@ class Synchronizer(PrintError):
|
||||||
self._processed_some_notifications = False
|
self._processed_some_notifications = False
|
||||||
self.wallet.set_up_to_date(up_to_date)
|
self.wallet.set_up_to_date(up_to_date)
|
||||||
self.wallet.network.trigger_callback('wallet_updated', self.wallet)
|
self.wallet.network.trigger_callback('wallet_updated', self.wallet)
|
||||||
|
|
||||||
|
|
||||||
|
class Notifier(SynchronizerBase):
|
||||||
|
"""Watch addresses. Every time the status of an address changes,
|
||||||
|
an HTTP POST is sent to the corresponding URL.
|
||||||
|
"""
|
||||||
|
def __init__(self, network):
|
||||||
|
SynchronizerBase.__init__(self, network)
|
||||||
|
self.watched_addresses = defaultdict(list) # type: Dict[str, List[str]]
|
||||||
|
self.start_watching_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def main(self):
|
||||||
|
# resend existing subscriptions if we were restarted
|
||||||
|
for addr in self.watched_addresses:
|
||||||
|
await self._add_address(addr)
|
||||||
|
# main loop
|
||||||
|
while True:
|
||||||
|
addr, url = await self.start_watching_queue.get()
|
||||||
|
self.watched_addresses[addr].append(url)
|
||||||
|
await self._add_address(addr)
|
||||||
|
|
||||||
|
async def _on_address_status(self, addr, status):
|
||||||
|
self.print_error('new status for addr {}'.format(addr))
|
||||||
|
headers = {'content-type': 'application/json'}
|
||||||
|
data = {'address': addr, 'status': status}
|
||||||
|
for url in self.watched_addresses[addr]:
|
||||||
|
try:
|
||||||
|
async with make_aiohttp_session(proxy=self.network.proxy, headers=headers) as session:
|
||||||
|
async with session.post(url, json=data, headers=headers) as resp:
|
||||||
|
await resp.text()
|
||||||
|
except Exception as e:
|
||||||
|
self.print_error(str(e))
|
||||||
|
else:
|
||||||
|
self.print_error('Got Response for {}'.format(addr))
|
||||||
|
|
|
@ -869,7 +869,12 @@ VerifiedTxInfo = NamedTuple("VerifiedTxInfo", [("height", int),
|
||||||
("txpos", int),
|
("txpos", int),
|
||||||
("header_hash", str)])
|
("header_hash", str)])
|
||||||
|
|
||||||
def make_aiohttp_session(proxy):
|
|
||||||
|
def make_aiohttp_session(proxy: dict, headers=None, timeout=None):
|
||||||
|
if headers is None:
|
||||||
|
headers = {'User-Agent': 'Electrum'}
|
||||||
|
if timeout is None:
|
||||||
|
timeout = aiohttp.ClientTimeout(total=10)
|
||||||
if proxy:
|
if proxy:
|
||||||
connector = SocksConnector(
|
connector = SocksConnector(
|
||||||
socks_ver=SocksVer.SOCKS5 if proxy['mode'] == 'socks5' else SocksVer.SOCKS4,
|
socks_ver=SocksVer.SOCKS5 if proxy['mode'] == 'socks5' else SocksVer.SOCKS4,
|
||||||
|
@ -879,9 +884,9 @@ def make_aiohttp_session(proxy):
|
||||||
password=proxy.get('password', None),
|
password=proxy.get('password', None),
|
||||||
rdns=True
|
rdns=True
|
||||||
)
|
)
|
||||||
return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10), connector=connector)
|
return aiohttp.ClientSession(headers=headers, timeout=timeout, connector=connector)
|
||||||
else:
|
else:
|
||||||
return aiohttp.ClientSession(headers={'User-Agent' : 'Electrum'}, timeout=aiohttp.ClientTimeout(total=10))
|
return aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
class SilentTaskGroup(TaskGroup):
|
class SilentTaskGroup(TaskGroup):
|
||||||
|
|
|
@ -25,14 +25,14 @@ import asyncio
|
||||||
from typing import Sequence, Optional
|
from typing import Sequence, Optional
|
||||||
|
|
||||||
import aiorpcx
|
import aiorpcx
|
||||||
from aiorpcx import TaskGroup
|
|
||||||
|
|
||||||
from .util import PrintError, bh2u, VerifiedTxInfo
|
from .util import bh2u, VerifiedTxInfo
|
||||||
from .bitcoin import Hash, hash_decode, hash_encode
|
from .bitcoin import Hash, hash_decode, hash_encode
|
||||||
from .transaction import Transaction
|
from .transaction import Transaction
|
||||||
from .blockchain import hash_header
|
from .blockchain import hash_header
|
||||||
from .interface import GracefulDisconnect
|
from .interface import GracefulDisconnect
|
||||||
from . import constants
|
from . import constants
|
||||||
|
from .network import NetworkJobOnDefaultServer
|
||||||
|
|
||||||
|
|
||||||
class MerkleVerificationFailure(Exception): pass
|
class MerkleVerificationFailure(Exception): pass
|
||||||
|
@ -41,26 +41,33 @@ class MerkleRootMismatch(MerkleVerificationFailure): pass
|
||||||
class InnerNodeOfSpvProofIsValidTx(MerkleVerificationFailure): pass
|
class InnerNodeOfSpvProofIsValidTx(MerkleVerificationFailure): pass
|
||||||
|
|
||||||
|
|
||||||
class SPV(PrintError):
|
class SPV(NetworkJobOnDefaultServer):
|
||||||
""" Simple Payment Verification """
|
""" Simple Payment Verification """
|
||||||
|
|
||||||
def __init__(self, network, wallet):
|
def __init__(self, network, wallet):
|
||||||
|
NetworkJobOnDefaultServer.__init__(self, network)
|
||||||
self.wallet = wallet
|
self.wallet = wallet
|
||||||
self.network = network
|
|
||||||
|
def _reset(self):
|
||||||
|
super()._reset()
|
||||||
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
|
self.merkle_roots = {} # txid -> merkle root (once it has been verified)
|
||||||
self.requested_merkle = set() # txid set of pending requests
|
self.requested_merkle = set() # txid set of pending requests
|
||||||
|
|
||||||
|
async def _start_tasks(self):
|
||||||
|
async with self.group as group:
|
||||||
|
await group.spawn(self.main)
|
||||||
|
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
|
return '{}:{}'.format(self.__class__.__name__, self.wallet.diagnostic_name())
|
||||||
|
|
||||||
async def main(self, group: TaskGroup):
|
async def main(self):
|
||||||
self.blockchain = self.network.blockchain()
|
self.blockchain = self.network.blockchain()
|
||||||
while True:
|
while True:
|
||||||
await self._maybe_undo_verifications()
|
await self._maybe_undo_verifications()
|
||||||
await self._request_proofs(group)
|
await self._request_proofs()
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
async def _request_proofs(self, group: TaskGroup):
|
async def _request_proofs(self):
|
||||||
local_height = self.blockchain.height()
|
local_height = self.blockchain.height()
|
||||||
unverified = self.wallet.get_unverified_txs()
|
unverified = self.wallet.get_unverified_txs()
|
||||||
|
|
||||||
|
@ -75,12 +82,12 @@ class SPV(PrintError):
|
||||||
header = self.blockchain.read_header(tx_height)
|
header = self.blockchain.read_header(tx_height)
|
||||||
if header is None:
|
if header is None:
|
||||||
if tx_height < constants.net.max_checkpoint():
|
if tx_height < constants.net.max_checkpoint():
|
||||||
await group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
|
await self.group.spawn(self.network.request_chunk(tx_height, None, can_return_early=True))
|
||||||
continue
|
continue
|
||||||
# request now
|
# request now
|
||||||
self.print_error('requested merkle', tx_hash)
|
self.print_error('requested merkle', tx_hash)
|
||||||
self.requested_merkle.add(tx_hash)
|
self.requested_merkle.add(tx_hash)
|
||||||
await group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
|
await self.group.spawn(self._request_and_verify_single_proof, tx_hash, tx_height)
|
||||||
|
|
||||||
async def _request_and_verify_single_proof(self, tx_hash, tx_height):
|
async def _request_and_verify_single_proof(self, tx_hash, tx_height):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -22,44 +22,49 @@
|
||||||
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import queue
|
import threading
|
||||||
import threading, os, json
|
import os
|
||||||
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, List
|
||||||
|
import traceback
|
||||||
|
import sys
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from SimpleWebSocketServer import WebSocket, SimpleSSLWebSocketServer
|
from SimpleWebSocketServer import WebSocket, SimpleSSLWebSocketServer
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import sys
|
|
||||||
sys.exit("install SimpleWebSocketServer")
|
sys.exit("install SimpleWebSocketServer")
|
||||||
|
|
||||||
from . import util
|
from .util import PrintError
|
||||||
from . import bitcoin
|
from . import bitcoin
|
||||||
|
from .synchronizer import SynchronizerBase
|
||||||
|
|
||||||
request_queue = queue.Queue()
|
request_queue = asyncio.Queue()
|
||||||
|
|
||||||
class ElectrumWebSocket(WebSocket):
|
|
||||||
|
class ElectrumWebSocket(WebSocket, PrintError):
|
||||||
|
|
||||||
def handleMessage(self):
|
def handleMessage(self):
|
||||||
assert self.data[0:3] == 'id:'
|
assert self.data[0:3] == 'id:'
|
||||||
util.print_error("message received", self.data)
|
self.print_error("message received", self.data)
|
||||||
request_id = self.data[3:]
|
request_id = self.data[3:]
|
||||||
request_queue.put((self, request_id))
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
request_queue.put((self, request_id)), asyncio.get_event_loop())
|
||||||
|
|
||||||
def handleConnected(self):
|
def handleConnected(self):
|
||||||
util.print_error("connected", self.address)
|
self.print_error("connected", self.address)
|
||||||
|
|
||||||
def handleClose(self):
|
def handleClose(self):
|
||||||
util.print_error("closed", self.address)
|
self.print_error("closed", self.address)
|
||||||
|
|
||||||
|
|
||||||
|
class BalanceMonitor(SynchronizerBase):
|
||||||
class WsClientThread(util.DaemonThread):
|
|
||||||
|
|
||||||
def __init__(self, config, network):
|
def __init__(self, config, network):
|
||||||
util.DaemonThread.__init__(self)
|
SynchronizerBase.__init__(self, network)
|
||||||
self.network = network
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.response_queue = queue.Queue()
|
self.expected_payments = defaultdict(list) # type: Dict[str, List[WebSocket, int]]
|
||||||
self.subscriptions = defaultdict(list)
|
|
||||||
|
|
||||||
def make_request(self, request_id):
|
def make_request(self, request_id):
|
||||||
# read json file
|
# read json file
|
||||||
|
@ -72,69 +77,47 @@ class WsClientThread(util.DaemonThread):
|
||||||
amount = d.get('amount')
|
amount = d.get('amount')
|
||||||
return addr, amount
|
return addr, amount
|
||||||
|
|
||||||
def reading_thread(self):
|
async def main(self):
|
||||||
while self.is_running():
|
# resend existing subscriptions if we were restarted
|
||||||
try:
|
for addr in self.expected_payments:
|
||||||
ws, request_id = request_queue.get()
|
await self._add_address(addr)
|
||||||
except queue.Empty:
|
# main loop
|
||||||
continue
|
while True:
|
||||||
|
ws, request_id = await request_queue.get()
|
||||||
try:
|
try:
|
||||||
addr, amount = self.make_request(request_id)
|
addr, amount = self.make_request(request_id)
|
||||||
except:
|
except Exception:
|
||||||
|
traceback.print_exc(file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
l = self.subscriptions.get(addr, [])
|
self.expected_payments[addr].append((ws, amount))
|
||||||
l.append((ws, amount))
|
await self._add_address(addr)
|
||||||
self.subscriptions[addr] = l
|
|
||||||
self.network.subscribe_to_addresses([addr], self.response_queue.put)
|
|
||||||
|
|
||||||
def run(self):
|
async def _on_address_status(self, addr, status):
|
||||||
threading.Thread(target=self.reading_thread).start()
|
self.print_error('new status for addr {}'.format(addr))
|
||||||
while self.is_running():
|
sh = bitcoin.address_to_scripthash(addr)
|
||||||
try:
|
balance = await self.network.get_balance_for_scripthash(sh)
|
||||||
r = self.response_queue.get(timeout=0.1)
|
for ws, amount in self.expected_payments[addr]:
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
util.print_error('response', r)
|
|
||||||
method = r.get('method')
|
|
||||||
result = r.get('result')
|
|
||||||
if result is None:
|
|
||||||
continue
|
|
||||||
if method == 'blockchain.scripthash.subscribe':
|
|
||||||
addr = r.get('params')[0]
|
|
||||||
scripthash = bitcoin.address_to_scripthash(addr)
|
|
||||||
self.network.get_balance_for_scripthash(
|
|
||||||
scripthash, self.response_queue.put)
|
|
||||||
elif method == 'blockchain.scripthash.get_balance':
|
|
||||||
scripthash = r.get('params')[0]
|
|
||||||
addr = self.network.h2addr.get(scripthash, None)
|
|
||||||
if addr is None:
|
|
||||||
util.print_error(
|
|
||||||
"can't find address for scripthash: %s" % scripthash)
|
|
||||||
l = self.subscriptions.get(addr, [])
|
|
||||||
for ws, amount in l:
|
|
||||||
if not ws.closed:
|
if not ws.closed:
|
||||||
if sum(result.values()) >=amount:
|
if sum(balance.values()) >= amount:
|
||||||
ws.sendMessage('paid')
|
ws.sendMessage('paid')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketServer(threading.Thread):
|
class WebSocketServer(threading.Thread):
|
||||||
|
|
||||||
def __init__(self, config, ns):
|
def __init__(self, config, network):
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.net_server = ns
|
self.network = network
|
||||||
|
asyncio.set_event_loop(network.asyncio_loop)
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
self.balance_monitor = BalanceMonitor(self.config, self.network)
|
||||||
|
self.start()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
t = WsClientThread(self.config, self.net_server)
|
asyncio.set_event_loop(self.network.asyncio_loop)
|
||||||
t.start()
|
|
||||||
|
|
||||||
host = self.config.get('websocket_server')
|
host = self.config.get('websocket_server')
|
||||||
port = self.config.get('websocket_port', 9999)
|
port = self.config.get('websocket_port', 9999)
|
||||||
certfile = self.config.get('ssl_chain')
|
certfile = self.config.get('ssl_chain')
|
||||||
keyfile = self.config.get('ssl_privkey')
|
keyfile = self.config.get('ssl_privkey')
|
||||||
self.server = SimpleSSLWebSocketServer(host, port, ElectrumWebSocket, certfile, keyfile)
|
self.server = SimpleSSLWebSocketServer(host, port, ElectrumWebSocket, certfile, keyfile)
|
||||||
self.server.serveforever()
|
self.server.serveforever()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -438,7 +438,7 @@ if __name__ == '__main__':
|
||||||
d = daemon.Daemon(config, fd)
|
d = daemon.Daemon(config, fd)
|
||||||
if config.get('websocket_server'):
|
if config.get('websocket_server'):
|
||||||
from electrum import websockets
|
from electrum import websockets
|
||||||
websockets.WebSocketServer(config, d.network).start()
|
websockets.WebSocketServer(config, d.network)
|
||||||
if config.get('requests_dir'):
|
if config.get('requests_dir'):
|
||||||
path = os.path.join(config.get('requests_dir'), 'index.html')
|
path = os.path.join(config.get('requests_dir'), 'index.html')
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
|
|
Loading…
Add table
Reference in a new issue