From 899a6f0d4a93b59feacab994159b5d26a4878b04 Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Wed, 5 Dec 2018 00:40:06 -0500 Subject: [PATCH] merged aiorpcx into torba.rpc --- setup.py | 1 - torba/rpc/__init__.py | 13 + torba/rpc/curio.py | 411 ++++++++++++++++ torba/rpc/framing.py | 239 ++++++++++ torba/rpc/jsonrpc.py | 801 ++++++++++++++++++++++++++++++++ torba/rpc/session.py | 549 ++++++++++++++++++++++ torba/rpc/socks.py | 439 +++++++++++++++++ torba/rpc/util.py | 120 +++++ torba/server/block_processor.py | 2 +- torba/server/daemon.py | 2 +- torba/server/db.py | 2 +- torba/server/mempool.py | 2 +- torba/server/merkle.py | 3 +- torba/server/peers.py | 3 +- torba/server/session.py | 9 +- 15 files changed, 2582 insertions(+), 14 deletions(-) create mode 100644 torba/rpc/__init__.py create mode 100644 torba/rpc/curio.py create mode 100644 torba/rpc/framing.py create mode 100644 torba/rpc/jsonrpc.py create mode 100644 torba/rpc/session.py create mode 100644 torba/rpc/socks.py create mode 100644 torba/rpc/util.py diff --git a/setup.py b/setup.py index 2ba288f98..075bd0c8f 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,6 @@ with open(os.path.join(BASE, 'README.md'), encoding='utf-8') as fh: REQUIRES = [ 'aiohttp', - 'aiorpcx==0.9.0', 'coincurve', 'pbkdf2', 'cryptography', diff --git a/torba/rpc/__init__.py b/torba/rpc/__init__.py new file mode 100644 index 000000000..ce96bd517 --- /dev/null +++ b/torba/rpc/__init__.py @@ -0,0 +1,13 @@ +from .curio import * +from .framing import * +from .jsonrpc import * +from .socks import * +from .session import * +from .util import * + +__all__ = (curio.__all__ + + framing.__all__ + + jsonrpc.__all__ + + socks.__all__ + + session.__all__ + + util.__all__) diff --git a/torba/rpc/curio.py b/torba/rpc/curio.py new file mode 100644 index 000000000..2e9d53bff --- /dev/null +++ b/torba/rpc/curio.py @@ -0,0 +1,411 @@ +# The code below is mostly my own but based on the interfaces of the +# curio library by David Beazley. I'm considering switching to using +# curio. In the mean-time this is an attempt to provide a similar +# clean, pure-async interface and move away from direct +# framework-specific dependencies. As asyncio differs in its design +# it is not possible to provide identical semantics. +# +# The curio library is distributed under the following licence: +# +# Copyright (C) 2015-2017 +# David Beazley (Dabeaz LLC) +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the David Beazley or Dabeaz LLC may be used to +# endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import logging +import asyncio +from asyncio import ( + CancelledError, get_event_loop, Queue, Event, Lock, Semaphore, + sleep, Task +) +from collections import deque +from contextlib import suppress +from functools import partial + +from .util import normalize_corofunc, check_task + + +__all__ = ( + 'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError', + 'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup', + 'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError', + 'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at', +) + + +async def run_in_thread(func, *args): + '''Run a function in a separate thread, and await its completion.''' + return await get_event_loop().run_in_executor(None, func, *args) + + +async def spawn(coro, *args, loop=None, report_crash=True): + return spawn_sync(coro, *args, loop=loop, report_crash=report_crash) + + +def spawn_sync(coro, *args, loop=None, report_crash=True): + coro = normalize_corofunc(coro, args) + loop = loop or get_event_loop() + task = loop.create_task(coro) + if report_crash: + task.add_done_callback(partial(check_task, logging)) + return task + + +class TaskGroup(object): + '''A class representing a group of executing tasks. tasks is an + optional set of existing tasks to put into the group. New tasks + can later be added using the spawn() method below. wait specifies + the policy used for waiting for tasks. See the join() method + below. Each TaskGroup is an independent entity. Task groups do not + form a hierarchy or any kind of relationship to other previously + created task groups or tasks. Moreover, Tasks created by the top + level spawn() function are not placed into any task group. To + create a task in a group, it should be created using + TaskGroup.spawn() or explicitly added using TaskGroup.add_task(). + + completed attribute: the first task that completed with a result + in the group. Takes into account the wait option used in the + TaskGroup constructor (but not in the join method)`. + ''' + + def __init__(self, tasks=(), *, wait=all): + if wait not in (any, all, object): + raise ValueError('invalid wait argument') + self._done = deque() + self._pending = set() + self._wait = wait + self._done_event = Event() + self._logger = logging.getLogger(self.__class__.__name__) + self._closed = False + self.completed = None + for task in tasks: + self._add_task(task) + + def _add_task(self, task): + '''Add an already existing task to the task group.''' + if hasattr(task, '_task_group'): + raise RuntimeError('task is already part of a group') + if self._closed: + raise RuntimeError('task group is closed') + task._task_group = self + if task.done(): + self._done.append(task) + else: + self._pending.add(task) + task.add_done_callback(self._on_done) + + def _on_done(self, task): + task._task_group = None + self._pending.remove(task) + self._done.append(task) + self._done_event.set() + if self.completed is None: + if not task.cancelled() and not task.exception(): + if self._wait is object and task.result() is None: + pass + else: + self.completed = task + + async def spawn(self, coro, *args): + '''Create a new task that’s part of the group. Returns a Task + instance. + ''' + task = await spawn(coro, *args, report_crash=False) + self._add_task(task) + return task + + async def add_task(self, task): + '''Add an already existing task to the task group.''' + self._add_task(task) + + async def next_done(self): + '''Returns the next completed task. Returns None if no more tasks + remain. A TaskGroup may also be used as an asynchronous iterator. + ''' + if not self._done and self._pending: + self._done_event.clear() + await self._done_event.wait() + if self._done: + return self._done.popleft() + return None + + async def next_result(self): + '''Returns the result of the next completed task. If the task failed + with an exception, that exception is raised. A RuntimeError + exception is raised if this is called when no remaining tasks + are available.''' + task = await self.next_done() + if not task: + raise RuntimeError('no tasks remain') + return task.result() + + async def join(self): + '''Wait for tasks in the group to terminate according to the wait + policy for the group. + + If the join() operation itself is cancelled, all remaining + tasks in the group are also cancelled. + + If a TaskGroup is used as a context manager, the join() method + is called on context-exit. + + Once join() returns, no more tasks may be added to the task + group. Tasks can be added while join() is running. + ''' + def errored(task): + return not task.cancelled() and task.exception() + + try: + if self._wait in (all, object): + while True: + task = await self.next_done() + if task is None: + return + if errored(task): + break + if self._wait is object: + if task.cancelled() or task.result() is not None: + return + else: # any + task = await self.next_done() + if task is None or not errored(task): + return + finally: + await self.cancel_remaining() + + if errored(task): + raise task.exception() + + async def cancel_remaining(self): + '''Cancel all remaining tasks.''' + self._closed = True + for task in list(self._pending): + task.cancel() + with suppress(CancelledError): + await task + + def closed(self): + return self._closed + + def __aiter__(self): + return self + + async def __anext__(self): + task = await self.next_done() + if task: + return task + raise StopAsyncIteration + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type: + await self.cancel_remaining() + else: + await self.join() + + +class TaskTimeout(CancelledError): + + def __init__(self, secs): + self.secs = secs + + def __str__(self): + return f'task timed out after {self.args[0]}s' + + +class TimeoutCancellationError(CancelledError): + pass + + +class UncaughtTimeoutError(Exception): + pass + + +def _set_new_deadline(task, deadline): + def timeout_task(): + # Unfortunately task.cancel is all we can do with asyncio + task.cancel() + task._timed_out = deadline + task._deadline_handle = task._loop.call_at(deadline, timeout_task) + + +def _set_task_deadline(task, deadline): + deadlines = getattr(task, '_deadlines', []) + if deadlines: + if deadline < min(deadlines): + task._deadline_handle.cancel() + _set_new_deadline(task, deadline) + else: + _set_new_deadline(task, deadline) + deadlines.append(deadline) + task._deadlines = deadlines + task._timed_out = None + + +def _unset_task_deadline(task): + deadlines = task._deadlines + timed_out_deadline = task._timed_out + uncaught = timed_out_deadline not in deadlines + task._deadline_handle.cancel() + deadlines.pop() + if deadlines: + _set_new_deadline(task, min(deadlines)) + return timed_out_deadline, uncaught + + +class TimeoutAfter(object): + + def __init__(self, deadline, *, ignore=False, absolute=False): + self._deadline = deadline + self._ignore = ignore + self._absolute = absolute + self.expired = False + + async def __aenter__(self): + task = asyncio.current_task() + loop_time = task._loop.time() + if self._absolute: + self._secs = self._deadline - loop_time + else: + self._secs = self._deadline + self._deadline += loop_time + _set_task_deadline(task, self._deadline) + self.expired = False + self._task = task + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + timed_out_deadline, uncaught = _unset_task_deadline(self._task) + if exc_type not in (CancelledError, TaskTimeout, + TimeoutCancellationError): + return False + if timed_out_deadline == self._deadline: + self.expired = True + if self._ignore: + return True + raise TaskTimeout(self._secs) from None + if timed_out_deadline is None: + assert exc_type is CancelledError + return False + if uncaught: + raise UncaughtTimeoutError('uncaught timeout received') + if exc_type is TimeoutCancellationError: + return False + raise TimeoutCancellationError(timed_out_deadline) from None + + +async def _timeout_after_func(seconds, absolute, coro, args): + coro = normalize_corofunc(coro, args) + async with TimeoutAfter(seconds, absolute=absolute): + return await coro + + +def timeout_after(seconds, coro=None, *args): + '''Execute the specified coroutine and return its result. However, + issue a cancellation request to the calling task after seconds + have elapsed. When this happens, a TaskTimeout exception is + raised. If coro is None, the result of this function serves + as an asynchronous context manager that applies a timeout to a + block of statements. + + timeout_after() may be composed with other timeout_after() + operations (i.e., nested timeouts). If an outer timeout expires + first, then TimeoutCancellationError is raised instead of + TaskTimeout. If an inner timeout expires and fails to properly + TaskTimeout, a UncaughtTimeoutError is raised in the outer + timeout. + + ''' + if coro: + return _timeout_after_func(seconds, False, coro, args) + + return TimeoutAfter(seconds) + + +def timeout_at(clock, coro=None, *args): + '''Execute the specified coroutine and return its result. However, + issue a cancellation request to the calling task after seconds + have elapsed. When this happens, a TaskTimeout exception is + raised. If coro is None, the result of this function serves + as an asynchronous context manager that applies a timeout to a + block of statements. + + timeout_after() may be composed with other timeout_after() + operations (i.e., nested timeouts). If an outer timeout expires + first, then TimeoutCancellationError is raised instead of + TaskTimeout. If an inner timeout expires and fails to properly + TaskTimeout, a UncaughtTimeoutError is raised in the outer + timeout. + + ''' + if coro: + return _timeout_after_func(clock, True, coro, args) + + return TimeoutAfter(clock, absolute=True) + + +async def _ignore_after_func(seconds, absolute, coro, args, timeout_result): + coro = normalize_corofunc(coro, args) + async with TimeoutAfter(seconds, absolute=absolute, ignore=True): + return await coro + + return timeout_result + + +def ignore_after(seconds, coro=None, *args, timeout_result=None): + '''Execute the specified coroutine and return its result. Issue a + cancellation request after seconds have elapsed. When a timeout + occurs, no exception is raised. Instead, timeout_result is + returned. + + If coro is None, the result is an asynchronous context manager + that applies a timeout to a block of statements. For the context + manager case, the resulting context manager object has an expired + attribute set to True if time expired. + + Note: ignore_after() may also be composed with other timeout + operations. TimeoutCancellationError and UncaughtTimeoutError + exceptions might be raised according to the same rules as for + timeout_after(). + ''' + if coro: + return _ignore_after_func(seconds, False, coro, args, timeout_result) + + return TimeoutAfter(seconds, ignore=True) + + +def ignore_at(clock, coro=None, *args, timeout_result=None): + ''' + Stop the enclosed task or block of code at an absolute + clock value. Same usage as ignore_after(). + ''' + if coro: + return _ignore_after_func(clock, True, coro, args, timeout_result) + + return TimeoutAfter(clock, absolute=True, ignore=True) diff --git a/torba/rpc/framing.py b/torba/rpc/framing.py new file mode 100644 index 000000000..6a5c2b9be --- /dev/null +++ b/torba/rpc/framing.py @@ -0,0 +1,239 @@ +# Copyright (c) 2018, Neil Booth +# +# All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +'''RPC message framing in a byte stream.''' + +__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer', + 'OversizedPayloadError', 'BadChecksumError', 'BadMagicError') + +from hashlib import sha256 as _sha256 +from struct import Struct +from asyncio import Queue + + +class FramerBase(object): + '''Abstract base class for a framer. + + A framer breaks an incoming byte stream into protocol messages, + buffering if necesary. It also frames outgoing messages into + a byte stream. + ''' + + def frame(self, message): + '''Return the framed message.''' + raise NotImplementedError + + def received_bytes(self, data): + '''Pass incoming network bytes.''' + raise NotImplementedError + + async def receive_message(self): + '''Wait for a complete unframed message to arrive, and return it.''' + raise NotImplementedError + + +class NewlineFramer(FramerBase): + '''A framer for a protocol where messages are separated by newlines.''' + + # The default max_size value is motivated by JSONRPC, where a + # normal request will be 250 bytes or less, and a reasonable + # batch may contain 4000 requests. + def __init__(self, max_size=250 * 4000): + '''max_size - an anti-DoS measure. If, after processing an incoming + message, buffered data would exceed max_size bytes, that + buffered data is dropped entirely and the framer waits for a + newline character to re-synchronize the stream. + ''' + self.max_size = max_size + self.queue = Queue() + self.received_bytes = self.queue.put_nowait + self.synchronizing = False + self.residual = b'' + + def frame(self, message): + return message + b'\n' + + async def receive_message(self): + parts = [] + buffer_size = 0 + while True: + part = self.residual + self.residual = b'' + if not part: + part = await self.queue.get() + + npos = part.find(b'\n') + if npos == -1: + parts.append(part) + buffer_size += len(part) + # Ignore over-sized messages; re-synchronize + if buffer_size <= self.max_size: + continue + self.synchronizing = True + raise MemoryError(f'dropping message over {self.max_size:,d} ' + f'bytes and re-synchronizing') + + tail, self.residual = part[:npos], part[npos + 1:] + if self.synchronizing: + self.synchronizing = False + return await self.receive_message() + else: + parts.append(tail) + return b''.join(parts) + + +class ByteQueue(object): + '''A producer-comsumer queue. Incoming network data is put as it + arrives, and the consumer calls an async method waiting for data of + a specific length.''' + + def __init__(self): + self.queue = Queue() + self.parts = [] + self.parts_len = 0 + self.put_nowait = self.queue.put_nowait + + async def receive(self, size): + while self.parts_len < size: + part = await self.queue.get() + self.parts.append(part) + self.parts_len += len(part) + self.parts_len -= size + whole = b''.join(self.parts) + self.parts = [whole[size:]] + return whole[:size] + + +class BinaryFramer(object): + '''A framer for binary messaging protocols.''' + + def __init__(self): + self.byte_queue = ByteQueue() + self.message_queue = Queue() + self.received_bytes = self.byte_queue.put_nowait + + def frame(self, message): + command, payload = message + return b''.join(( + self._build_header(command, payload), + payload + )) + + async def receive_message(self): + command, payload_len, checksum = await self._receive_header() + payload = await self.byte_queue.receive(payload_len) + payload_checksum = self._checksum(payload) + if payload_checksum != checksum: + raise BadChecksumError(payload_checksum, checksum) + return command, payload + + def _checksum(self, payload): + raise NotImplementedError + + def _build_header(self, command, payload): + raise NotImplementedError + + async def _receive_header(self): + raise NotImplementedError + + +# Helpers +struct_le_I = Struct(' 1024 * 1024: + if command != b'block' or payload_len > self._max_block_size: + raise OversizedPayloadError(command, payload_len) + return command, payload_len, checksum diff --git a/torba/rpc/jsonrpc.py b/torba/rpc/jsonrpc.py new file mode 100644 index 000000000..6cbd5f11a --- /dev/null +++ b/torba/rpc/jsonrpc.py @@ -0,0 +1,801 @@ +# Copyright (c) 2018, Neil Booth +# +# All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.''' + +__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose', + 'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch', + 'RPCError', 'ProtocolError', + 'JSONRPCConnection', 'handler_invocation') + +import itertools +import json +from functools import partial +from numbers import Number + +import attr +from asyncio import Queue, Event, CancelledError +from .util import signature_info + + +class SingleRequest(object): + __slots__ = ('method', 'args') + + def __init__(self, method, args): + if not isinstance(method, str): + raise ProtocolError(JSONRPC.METHOD_NOT_FOUND, + 'method must be a string') + if not isinstance(args, (list, tuple, dict)): + raise ProtocolError.invalid_args('request arguments must be a ' + 'list or a dictionary') + self.args = args + self.method = method + + def __repr__(self): + return f'{self.__class__.__name__}({self.method!r}, {self.args!r})' + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.method == other.method and self.args == other.args) + + +class Request(SingleRequest): + def send_result(self, response): + return None + + +class Notification(SingleRequest): + pass + + +class Batch(object): + __slots__ = ('items', ) + + def __init__(self, items): + if not isinstance(items, (list, tuple)): + raise ProtocolError.invalid_request('items must be a list') + if not items: + raise ProtocolError.empty_batch() + if not (all(isinstance(item, SingleRequest) for item in items) or + all(isinstance(item, Response) for item in items)): + raise ProtocolError.invalid_request('batch must be homogeneous') + self.items = items + + def __len__(self): + return len(self.items) + + def __getitem__(self, item): + return self.items[item] + + def __iter__(self): + return iter(self.items) + + def __repr__(self): + return f'Batch({len(self.items)} items)' + + +class Response(object): + __slots__ = ('result', ) + + def __init__(self, result): + # Type checking happens when converting to a message + self.result = result + + +class CodeMessageError(Exception): + + def __init__(self, code, message): + super().__init__(code, message) + + @property + def code(self): + return self.args[0] + + @property + def message(self): + return self.args[1] + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.code == other.code and self.message == other.message) + + def __hash__(self): + # overridden to make the exception hashable + # see https://bugs.python.org/issue28603 + return hash((self.code, self.message)) + + @classmethod + def invalid_args(cls, message): + return cls(JSONRPC.INVALID_ARGS, message) + + @classmethod + def invalid_request(cls, message): + return cls(JSONRPC.INVALID_REQUEST, message) + + @classmethod + def empty_batch(cls): + return cls.invalid_request('batch is empty') + + +class RPCError(CodeMessageError): + pass + + +class ProtocolError(CodeMessageError): + + def __init__(self, code, message): + super().__init__(code, message) + # If not None send this unframed message over the network + self.error_message = None + # If the error was in a JSON response message; its message ID. + # Since None can be a response message ID, "id" means the + # error was not sent in a JSON response + self.response_msg_id = id + + +class JSONRPC(object): + '''Abstract base class that interprets and constructs JSON RPC messages.''' + + # Error codes. See http://www.jsonrpc.org/specification + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_ARGS = -32602 + INTERNAL_ERROR = -32603 + # Codes specific to this library + ERROR_CODE_UNAVAILABLE = -100 + + # Can be overridden by derived classes + allow_batches = True + + @classmethod + def _message_id(cls, message, require_id): + '''Validate the message is a dictionary and return its ID. + + Raise an error if the message is invalid or the ID is of an + invalid type. If it has no ID, raise an error if require_id + is True, otherwise return None. + ''' + raise NotImplementedError + + @classmethod + def _validate_message(cls, message): + '''Validate other parts of the message other than those + done in _message_id.''' + pass + + @classmethod + def _request_args(cls, request): + '''Validate the existence and type of the arguments passed + in the request dictionary.''' + raise NotImplementedError + + @classmethod + def _process_request(cls, payload): + request_id = None + try: + request_id = cls._message_id(payload, False) + cls._validate_message(payload) + method = payload.get('method') + if request_id is None: + item = Notification(method, cls._request_args(payload)) + else: + item = Request(method, cls._request_args(payload)) + return item, request_id + except ProtocolError as error: + code, message = error.code, error.message + raise cls._error(code, message, True, request_id) + + @classmethod + def _process_response(cls, payload): + request_id = None + try: + request_id = cls._message_id(payload, True) + cls._validate_message(payload) + return Response(cls.response_value(payload)), request_id + except ProtocolError as error: + code, message = error.code, error.message + raise cls._error(code, message, False, request_id) + + @classmethod + def _message_to_payload(cls, message): + '''Returns a Python object or a ProtocolError.''' + try: + return json.loads(message.decode()) + except UnicodeDecodeError: + message = 'messages must be encoded in UTF-8' + except json.JSONDecodeError: + message = 'invalid JSON' + raise cls._error(cls.PARSE_ERROR, message, True, None) + + @classmethod + def _error(cls, code, message, send, msg_id): + error = ProtocolError(code, message) + if send: + error.error_message = cls.response_message(error, msg_id) + else: + error.response_msg_id = msg_id + return error + + # + # External API + # + + @classmethod + def message_to_item(cls, message): + '''Translate an unframed received message and return an + (item, request_id) pair. + + The item can be a Request, Notification, Response or a list. + + A JSON RPC error response is returned as an RPCError inside a + Response object. + + If a Batch is returned, request_id is an iterable of request + ids, one per batch member. + + If the message violates the protocol in some way a + ProtocolError is returned, except if the message was + determined to be a response, in which case the ProtocolError + is placed inside a Response object. This is so that client + code can mark a request as having been responded to even if + the response was bad. + + raises: ProtocolError + ''' + payload = cls._message_to_payload(message) + if isinstance(payload, dict): + if 'method' in payload: + return cls._process_request(payload) + else: + return cls._process_response(payload) + elif isinstance(payload, list) and cls.allow_batches: + if not payload: + raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty', + True, None) + return payload, None + raise cls._error(cls.INVALID_REQUEST, + 'request object must be a dictionary', True, None) + + # Message formation + @classmethod + def request_message(cls, item, request_id): + '''Convert an RPCRequest item to a message.''' + assert isinstance(item, Request) + return cls.encode_payload(cls.request_payload(item, request_id)) + + @classmethod + def notification_message(cls, item): + '''Convert an RPCRequest item to a message.''' + assert isinstance(item, Notification) + return cls.encode_payload(cls.request_payload(item, None)) + + @classmethod + def response_message(cls, result, request_id): + '''Convert a response result (or RPCError) to a message.''' + if isinstance(result, CodeMessageError): + payload = cls.error_payload(result, request_id) + else: + payload = cls.response_payload(result, request_id) + return cls.encode_payload(payload) + + @classmethod + def batch_message(cls, batch, request_ids): + '''Convert a request Batch to a message.''' + assert isinstance(batch, Batch) + if not cls.allow_batches: + raise ProtocolError.invalid_request( + 'protocol does not permit batches') + id_iter = iter(request_ids) + rm = cls.request_message + nm = cls.notification_message + parts = (rm(request, next(id_iter)) if isinstance(request, Request) + else nm(request) for request in batch) + return cls.batch_message_from_parts(parts) + + @classmethod + def batch_message_from_parts(cls, messages): + '''Convert messages, one per batch item, into a batch message. At + least one message must be passed. + ''' + # Comma-separate the messages and wrap the lot in square brackets + middle = b', '.join(messages) + if not middle: + raise ProtocolError.empty_batch() + return b''.join([b'[', middle, b']']) + + @classmethod + def encode_payload(cls, payload): + '''Encode a Python object as JSON and convert it to bytes.''' + try: + return json.dumps(payload).encode() + except TypeError: + msg = f'JSON payload encoding error: {payload}' + raise ProtocolError(cls.INTERNAL_ERROR, msg) from None + + +class JSONRPCv1(JSONRPC): + '''JSON RPC version 1.0.''' + + allow_batches = False + + @classmethod + def _message_id(cls, message, require_id): + # JSONv1 requires an ID always, but without constraint on its type + # No need to test for a dictionary here as we don't handle batches. + if 'id' not in message: + raise ProtocolError.invalid_request('request has no "id"') + return message['id'] + + @classmethod + def _request_args(cls, request): + args = request.get('params') + if not isinstance(args, list): + raise ProtocolError.invalid_args( + f'invalid request arguments: {args}') + return args + + @classmethod + def _best_effort_error(cls, error): + # Do our best to interpret the error + code = cls.ERROR_CODE_UNAVAILABLE + message = 'no error message provided' + if isinstance(error, str): + message = error + elif isinstance(error, int): + code = error + elif isinstance(error, dict): + if isinstance(error.get('message'), str): + message = error['message'] + if isinstance(error.get('code'), int): + code = error['code'] + + return RPCError(code, message) + + @classmethod + def response_value(cls, payload): + if 'result' not in payload or 'error' not in payload: + raise ProtocolError.invalid_request( + 'response must contain both "result" and "error"') + + result = payload['result'] + error = payload['error'] + if error is None: + return result # It seems None can be a valid result + if result is not None: + raise ProtocolError.invalid_request( + 'response has a "result" and an "error"') + + return cls._best_effort_error(error) + + @classmethod + def request_payload(cls, request, request_id): + '''JSON v1 request (or notification) payload.''' + if isinstance(request.args, dict): + raise ProtocolError.invalid_args( + 'JSONRPCv1 does not support named arguments') + return { + 'method': request.method, + 'params': request.args, + 'id': request_id + } + + @classmethod + def response_payload(cls, result, request_id): + '''JSON v1 response payload.''' + return { + 'result': result, + 'error': None, + 'id': request_id + } + + @classmethod + def error_payload(cls, error, request_id): + return { + 'result': None, + 'error': {'code': error.code, 'message': error.message}, + 'id': request_id + } + + +class JSONRPCv2(JSONRPC): + '''JSON RPC version 2.0.''' + + @classmethod + def _message_id(cls, message, require_id): + if not isinstance(message, dict): + raise ProtocolError.invalid_request( + 'request object must be a dictionary') + if 'id' in message: + request_id = message['id'] + if not isinstance(request_id, (Number, str, type(None))): + raise ProtocolError.invalid_request( + f'invalid "id": {request_id}') + return request_id + else: + if require_id: + raise ProtocolError.invalid_request('request has no "id"') + return None + + @classmethod + def _validate_message(cls, message): + if message.get('jsonrpc') != '2.0': + raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"') + + @classmethod + def _request_args(cls, request): + args = request.get('params', []) + if not isinstance(args, (dict, list)): + raise ProtocolError.invalid_args( + f'invalid request arguments: {args}') + return args + + @classmethod + def response_value(cls, payload): + if 'result' in payload: + if 'error' in payload: + raise ProtocolError.invalid_request( + 'response contains both "result" and "error"') + return payload['result'] + + if 'error' not in payload: + raise ProtocolError.invalid_request( + 'response contains neither "result" nor "error"') + + # Return an RPCError object + error = payload['error'] + if isinstance(error, dict): + code = error.get('code') + message = error.get('message') + if isinstance(code, int) and isinstance(message, str): + return RPCError(code, message) + + raise ProtocolError.invalid_request( + f'ill-formed response error object: {error}') + + @classmethod + def request_payload(cls, request, request_id): + '''JSON v2 request (or notification) payload.''' + payload = { + 'jsonrpc': '2.0', + 'method': request.method, + } + # A notification? + if request_id is not None: + payload['id'] = request_id + # Preserve empty dicts as missing params is read as an array + if request.args or request.args == {}: + payload['params'] = request.args + return payload + + @classmethod + def response_payload(cls, result, request_id): + '''JSON v2 response payload.''' + return { + 'jsonrpc': '2.0', + 'result': result, + 'id': request_id + } + + @classmethod + def error_payload(cls, error, request_id): + return { + 'jsonrpc': '2.0', + 'error': {'code': error.code, 'message': error.message}, + 'id': request_id + } + + +class JSONRPCLoose(JSONRPC): + '''A relaxed versin of JSON RPC.''' + + # Don't be so loose we accept any old message ID + _message_id = JSONRPCv2._message_id + _validate_message = JSONRPC._validate_message + _request_args = JSONRPCv2._request_args + # Outoing messages are JSONRPCv2 so we give the other side the + # best chance to assume / detect JSONRPCv2 as default protocol. + error_payload = JSONRPCv2.error_payload + request_payload = JSONRPCv2.request_payload + response_payload = JSONRPCv2.response_payload + + @classmethod + def response_value(cls, payload): + # Return result, unless it is None and there is an error + if payload.get('error') is not None: + if payload.get('result') is not None: + raise ProtocolError.invalid_request( + 'response contains both "result" and "error"') + return JSONRPCv1._best_effort_error(payload['error']) + + if 'result' not in payload: + raise ProtocolError.invalid_request( + 'response contains neither "result" nor "error"') + + # Can be None + return payload['result'] + + +class JSONRPCAutoDetect(JSONRPCv2): + + @classmethod + def message_to_item(cls, message): + return cls.detect_protocol(message), None + + @classmethod + def detect_protocol(cls, message): + '''Attempt to detect the protocol from the message.''' + main = cls._message_to_payload(message) + + def protocol_for_payload(payload): + if not isinstance(payload, dict): + return JSONRPCLoose # Will error + # Obey an explicit "jsonrpc" + version = payload.get('jsonrpc') + if version == '2.0': + return JSONRPCv2 + if version == '1.0': + return JSONRPCv1 + + # Now to decide between JSONRPCLoose and JSONRPCv1 if possible + if 'result' in payload and 'error' in payload: + return JSONRPCv1 + return JSONRPCLoose + + if isinstance(main, list): + parts = set(protocol_for_payload(payload) for payload in main) + # If all same protocol, return it + if len(parts) == 1: + return parts.pop() + # If strict protocol detected, return it, preferring JSONRPCv2. + # This means a batch of JSONRPCv1 will fail + for protocol in (JSONRPCv2, JSONRPCv1): + if protocol in parts: + return protocol + # Will error if no parts + return JSONRPCLoose + + return protocol_for_payload(main) + + +class JSONRPCConnection(object): + '''Maintains state of a JSON RPC connection, in particular + encapsulating the handling of request IDs. + + protocol - the JSON RPC protocol to follow + max_response_size - responses over this size send an error response + instead. + ''' + + _id_counter = itertools.count() + + def __init__(self, protocol): + self._protocol = protocol + # Sent Requests and Batches that have not received a response. + # The key is its request ID; for a batch it is sorted tuple + # of request IDs + self._requests = {} + # A public attribute intended to be settable dynamically + self.max_response_size = 0 + + def _oversized_response_message(self, request_id): + text = f'response too large (over {self.max_response_size:,d} bytes' + error = RPCError.invalid_request(text) + return self._protocol.response_message(error, request_id) + + def _receive_response(self, result, request_id): + if request_id not in self._requests: + if request_id is None and isinstance(result, RPCError): + message = f'diagnostic error received: {result}' + else: + message = f'response to unsent request (ID: {request_id})' + raise ProtocolError.invalid_request(message) from None + request, event = self._requests.pop(request_id) + event.result = result + event.set() + return [] + + def _receive_request_batch(self, payloads): + def item_send_result(request_id, result): + nonlocal size + part = protocol.response_message(result, request_id) + size += len(part) + 2 + if size > self.max_response_size > 0: + part = self._oversized_response_message(request_id) + parts.append(part) + if len(parts) == count: + return protocol.batch_message_from_parts(parts) + return None + + parts = [] + items = [] + size = 0 + count = 0 + protocol = self._protocol + for payload in payloads: + try: + item, request_id = protocol._process_request(payload) + items.append(item) + if isinstance(item, Request): + count += 1 + item.send_result = partial(item_send_result, request_id) + except ProtocolError as error: + count += 1 + parts.append(error.error_message) + + if not items and parts: + error = ProtocolError(0, "") + error.error_message = protocol.batch_message_from_parts(parts) + raise error + return items + + def _receive_response_batch(self, payloads): + request_ids = [] + results = [] + for payload in payloads: + # Let ProtocolError exceptions through + item, request_id = self._protocol._process_response(payload) + request_ids.append(request_id) + results.append(item.result) + + ordered = sorted(zip(request_ids, results), key=lambda t: t[0]) + ordered_ids, ordered_results = zip(*ordered) + if ordered_ids not in self._requests: + raise ProtocolError.invalid_request('response to unsent batch') + request_batch, event = self._requests.pop(ordered_ids) + event.result = ordered_results + event.set() + return [] + + def _send_result(self, request_id, result): + message = self._protocol.response_message(result, request_id) + if len(message) > self.max_response_size > 0: + message = self._oversized_response_message(request_id) + return message + + def _event(self, request, request_id): + event = Event() + self._requests[request_id] = (request, event) + return event + + # + # External API + # + def send_request(self, request): + '''Send a Request. Return a (message, event) pair. + + The message is an unframed message to send over the network. + Wait on the event for the response; which will be in the + "result" attribute. + + Raises: ProtocolError if the request violates the protocol + in some way.. + ''' + request_id = next(self._id_counter) + message = self._protocol.request_message(request, request_id) + return message, self._event(request, request_id) + + def send_notification(self, notification): + return self._protocol.notification_message(notification) + + def send_batch(self, batch): + ids = tuple(next(self._id_counter) + for request in batch if isinstance(request, Request)) + message = self._protocol.batch_message(batch, ids) + event = self._event(batch, ids) if ids else None + return message, event + + def receive_message(self, message): + '''Call with an unframed message received from the network. + + Raises: ProtocolError if the message violates the protocol in + some way. However, if it happened in a response that can be + paired with a request, the ProtocolError is instead set in the + result attribute of the send_request() that caused the error. + ''' + try: + item, request_id = self._protocol.message_to_item(message) + except ProtocolError as e: + if e.response_msg_id is not id: + return self._receive_response(e, e.response_msg_id) + raise + + if isinstance(item, Request): + item.send_result = partial(self._send_result, request_id) + return [item] + if isinstance(item, Notification): + return [item] + if isinstance(item, Response): + return self._receive_response(item.result, request_id) + if isinstance(item, list): + if all(isinstance(payload, dict) + and ('result' in payload or 'error' in payload) + for payload in item): + return self._receive_response_batch(item) + else: + return self._receive_request_batch(item) + else: + # Protocol auto-detection hack + assert issubclass(item, JSONRPC) + self._protocol = item + return self.receive_message(message) + + def cancel_pending_requests(self): + '''Cancel all pending requests.''' + exception = CancelledError() + for request, event in self._requests.values(): + event.result = exception + event.set() + self._requests.clear() + + def pending_requests(self): + '''All sent requests that have not received a response.''' + return [request for request, event in self._requests.values()] + + +def handler_invocation(handler, request): + method, args = request.method, request.args + if handler is None: + raise RPCError(JSONRPC.METHOD_NOT_FOUND, + f'unknown method "{method}"') + + # We must test for too few and too many arguments. How + # depends on whether the arguments were passed as a list or as + # a dictionary. + info = signature_info(handler) + if isinstance(args, (tuple, list)): + if len(args) < info.min_args: + s = '' if len(args) == 1 else 's' + raise RPCError.invalid_args( + f'{len(args)} argument{s} passed to method ' + f'"{method}" but it requires {info.min_args}') + if info.max_args is not None and len(args) > info.max_args: + s = '' if len(args) == 1 else 's' + raise RPCError.invalid_args( + f'{len(args)} argument{s} passed to method ' + f'{method} taking at most {info.max_args}') + return partial(handler, *args) + + # Arguments passed by name + if info.other_names is None: + raise RPCError.invalid_args(f'method "{method}" cannot ' + f'be called with named arguments') + + missing = set(info.required_names).difference(args) + if missing: + s = '' if len(missing) == 1 else 's' + missing = ', '.join(sorted(f'"{name}"' for name in missing)) + raise RPCError.invalid_args(f'method "{method}" requires ' + f'parameter{s} {missing}') + + if info.other_names is not any: + excess = set(args).difference(info.required_names) + excess = excess.difference(info.other_names) + if excess: + s = '' if len(excess) == 1 else 's' + excess = ', '.join(sorted(f'"{name}"' for name in excess)) + raise RPCError.invalid_args(f'method "{method}" does not ' + f'take parameter{s} {excess}') + return partial(handler, **args) diff --git a/torba/rpc/session.py b/torba/rpc/session.py new file mode 100644 index 000000000..144cc2f02 --- /dev/null +++ b/torba/rpc/session.py @@ -0,0 +1,549 @@ +# Copyright (c) 2018, Neil Booth +# +# All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +__all__ = ('Connector', 'RPCSession', 'MessageSession', 'Server', + 'BatchError') + + +import asyncio +import logging +import time +from contextlib import suppress + +from . import * +from .util import Concurrency + + +class Connector(object): + + def __init__(self, session_factory, host=None, port=None, proxy=None, + **kwargs): + self.session_factory = session_factory + self.host = host + self.port = port + self.proxy = proxy + self.loop = kwargs.get('loop', asyncio.get_event_loop()) + self.kwargs = kwargs + + async def create_connection(self): + '''Initiate a connection.''' + connector = self.proxy or self.loop + return await connector.create_connection( + self.session_factory, self.host, self.port, **self.kwargs) + + async def __aenter__(self): + transport, self.protocol = await self.create_connection() + # By default, do not limit outgoing connections + self.protocol.bw_limit = 0 + return self.protocol + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.protocol.close() + + +class SessionBase(asyncio.Protocol): + '''Base class of networking sessions. + + There is no client / server distinction other than who initiated + the connection. + + To initiate a connection to a remote server pass host, port and + proxy to the constructor, and then call create_connection(). Each + successful call should have a corresponding call to close(). + + Alternatively if used in a with statement, the connection is made + on entry to the block, and closed on exit from the block. + ''' + + max_errors = 10 + + def __init__(self, *, framer=None, loop=None): + self.framer = framer or self.default_framer() + self.loop = loop or asyncio.get_event_loop() + self.logger = logging.getLogger(self.__class__.__name__) + self.transport = None + # Set when a connection is made + self._address = None + self._proxy_address = None + # For logger.debug messsages + self.verbosity = 0 + # Cleared when the send socket is full + self._can_send = Event() + self._can_send.set() + self._pm_task = None + self._task_group = TaskGroup() + # Force-close a connection if a send doesn't succeed in this time + self.max_send_delay = 60 + # Statistics. The RPC object also keeps its own statistics. + self.start_time = time.time() + self.errors = 0 + self.send_count = 0 + self.send_size = 0 + self.last_send = self.start_time + self.recv_count = 0 + self.recv_size = 0 + self.last_recv = self.start_time + # Bandwidth usage per hour before throttling starts + self.bw_limit = 2000000 + self.bw_time = self.start_time + self.bw_charge = 0 + # Concurrency control + self.max_concurrent = 6 + self._concurrency = Concurrency(self.max_concurrent) + + async def _update_concurrency(self): + # A non-positive value means not to limit concurrency + if self.bw_limit <= 0: + return + now = time.time() + # Reduce the recorded usage in proportion to the elapsed time + refund = (now - self.bw_time) * (self.bw_limit / 3600) + self.bw_charge = max(0, self.bw_charge - int(refund)) + self.bw_time = now + # Reduce concurrency allocation by 1 for each whole bw_limit used + throttle = int(self.bw_charge / self.bw_limit) + target = max(1, self.max_concurrent - throttle) + current = self._concurrency.max_concurrent + if target != current: + self.logger.info(f'changing task concurrency from {current} ' + f'to {target}') + await self._concurrency.set_max_concurrent(target) + + def _using_bandwidth(self, size): + '''Called when sending or receiving size bytes.''' + self.bw_charge += size + + async def _process_messages(self): + '''Process incoming messages asynchronously and consume the + results. + ''' + async def collect_tasks(): + next_done = task_group.next_done + while True: + await next_done() + + task_group = self._task_group + async with task_group: + await self.spawn(self._receive_messages) + await self.spawn(collect_tasks) + + async def _limited_wait(self, secs): + # Wait at most secs seconds to send, otherwise abort the connection + try: + async with timeout_after(secs): + await self._can_send.wait() + except TaskTimeout: + self.abort() + raise + + async def _send_message(self, message): + if not self._can_send.is_set(): + await self._limited_wait(self.max_send_delay) + if not self.is_closing(): + framed_message = self.framer.frame(message) + self.send_size += len(framed_message) + self._using_bandwidth(len(framed_message)) + self.send_count += 1 + self.last_send = time.time() + if self.verbosity >= 4: + self.logger.debug(f'Sending framed message {framed_message}') + self.transport.write(framed_message) + + def _bump_errors(self): + self.errors += 1 + if self.errors >= self.max_errors: + # Don't await self.close() because that is self-cancelling + self._close() + + def _close(self): + if self.transport: + self.transport.close() + + # asyncio framework + def data_received(self, framed_message): + '''Called by asyncio when a message comes in.''' + if self.verbosity >= 4: + self.logger.debug(f'Received framed message {framed_message}') + self.recv_size += len(framed_message) + self._using_bandwidth(len(framed_message)) + self.framer.received_bytes(framed_message) + + def pause_writing(self): + '''Transport calls when the send buffer is full.''' + if not self.is_closing(): + self._can_send.clear() + self.transport.pause_reading() + + def resume_writing(self): + '''Transport calls when the send buffer has room.''' + if not self._can_send.is_set(): + self._can_send.set() + self.transport.resume_reading() + + def connection_made(self, transport): + '''Called by asyncio when a connection is established. + + Derived classes overriding this method must call this first.''' + self.transport = transport + # This would throw if called on a closed SSL transport. Fixed + # in asyncio in Python 3.6.1 and 3.5.4 + peer_address = transport.get_extra_info('peername') + # If the Socks proxy was used then _address is already set to + # the remote address + if self._address: + self._proxy_address = peer_address + else: + self._address = peer_address + self._pm_task = spawn_sync(self._process_messages(), loop=self.loop) + + def connection_lost(self, exc): + '''Called by asyncio when the connection closes. + + Tear down things done in connection_made.''' + self._address = None + self.transport = None + self._pm_task.cancel() + # Release waiting tasks + self._can_send.set() + + # External API + def default_framer(self): + '''Return a default framer.''' + raise NotImplementedError + + def peer_address(self): + '''Returns the peer's address (Python networking address), or None if + no connection or an error. + + This is the result of socket.getpeername() when the connection + was made. + ''' + return self._address + + def peer_address_str(self): + '''Returns the peer's IP address and port as a human-readable + string.''' + if not self._address: + return 'unknown' + ip_addr_str, port = self._address[:2] + if ':' in ip_addr_str: + return f'[{ip_addr_str}]:{port}' + else: + return f'{ip_addr_str}:{port}' + + async def spawn(self, coro, *args): + '''If the session is connected, spawn a task that is cancelled + on disconnect, and return it. Otherwise return None.''' + group = self._task_group + if not group.closed(): + return await group.spawn(coro, *args) + else: + return None + + def is_closing(self): + '''Return True if the connection is closing.''' + return not self.transport or self.transport.is_closing() + + def abort(self): + '''Forcefully close the connection.''' + if self.transport: + self.transport.abort() + + async def close(self, *, force_after=30): + '''Close the connection and return when closed.''' + self._close() + if self._pm_task: + with suppress(CancelledError): + async with ignore_after(force_after): + await self._pm_task + self.abort() + await self._pm_task + + +class MessageSession(SessionBase): + '''Session class for protocols where messages are not tied to responses, + such as the Bitcoin protocol. + + To use as a client (connection-opening) session, pass host, port + and perhaps a proxy. + ''' + async def _receive_messages(self): + while not self.is_closing(): + try: + message = await self.framer.receive_message() + except BadMagicError as e: + magic, expected = e.args + self.logger.error( + f'bad network magic: got {magic} expected {expected}, ' + f'disconnecting' + ) + self._close() + except OversizedPayloadError as e: + command, payload_len = e.args + self.logger.error( + f'oversized payload of {payload_len:,d} bytes to command ' + f'{command}, disconnecting' + ) + self._close() + except BadChecksumError as e: + payload_checksum, claimed_checksum = e.args + self.logger.warning( + f'checksum mismatch: actual {payload_checksum.hex()} ' + f'vs claimed {claimed_checksum.hex()}' + ) + self._bump_errors() + else: + self.last_recv = time.time() + self.recv_count += 1 + if self.recv_count % 10 == 0: + await self._update_concurrency() + await self.spawn(self._throttled_message(message)) + + async def _throttled_message(self, message): + '''Process a single request, respecting the concurrency limit.''' + async with self._concurrency.semaphore: + try: + await self.handle_message(message) + except ProtocolError as e: + self.logger.error(f'{e}') + self._bump_errors() + except CancelledError: + raise + except Exception: + self.logger.exception(f'exception handling {message}') + self._bump_errors() + + # External API + def default_framer(self): + '''Return a bitcoin framer.''' + return BitcoinFramer(bytes.fromhex('e3e1f3e8'), 128_000_000) + + async def handle_message(self, message): + '''message is a (command, payload) pair.''' + pass + + async def send_message(self, message): + '''Send a message (command, payload) over the network.''' + await self._send_message(message) + + +class BatchError(Exception): + + def __init__(self, request): + self.request = request # BatchRequest object + + +class BatchRequest(object): + '''Used to build a batch request to send to the server. Stores + the + + Attributes batch and results are initially None. + + Adding an invalid request or notification immediately raises a + ProtocolError. + + On exiting the with clause, it will: + + 1) create a Batch object for the requests in the order they were + added. If the batch is empty this raises a ProtocolError. + + 2) set the "batch" attribute to be that batch + + 3) send the batch request and wait for a response + + 4) raise a ProtocolError if the protocol was violated by the + server. Currently this only happens if it gave more than one + response to any request + + 5) otherwise there is precisely one response to each Request. Set + the "results" attribute to the tuple of results; the responses + are ordered to match the Requests in the batch. Notifications + do not get a response. + + 6) if raise_errors is True and any individual response was a JSON + RPC error response, or violated the protocol in some way, a + BatchError exception is raised. Otherwise the caller can be + certain each request returned a standard result. + ''' + + def __init__(self, session, raise_errors): + self._session = session + self._raise_errors = raise_errors + self._requests = [] + self.batch = None + self.results = None + + def add_request(self, method, args=()): + self._requests.append(Request(method, args)) + + def add_notification(self, method, args=()): + self._requests.append(Notification(method, args)) + + def __len__(self): + return len(self._requests) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if exc_type is None: + self.batch = Batch(self._requests) + message, event = self._session.connection.send_batch(self.batch) + await self._session._send_message(message) + await event.wait() + self.results = event.result + if self._raise_errors: + if any(isinstance(item, Exception) for item in event.result): + raise BatchError(self) + + +class RPCSession(SessionBase): + '''Base class for protocols where a message can lead to a response, + for example JSON RPC.''' + + def __init__(self, *, framer=None, loop=None, connection=None): + super().__init__(framer=framer, loop=loop) + self.connection = connection or self.default_connection() + + async def _receive_messages(self): + while not self.is_closing(): + try: + message = await self.framer.receive_message() + except MemoryError as e: + self.logger.warning(f'{e!r}') + continue + + self.last_recv = time.time() + self.recv_count += 1 + if self.recv_count % 10 == 0: + await self._update_concurrency() + + try: + requests = self.connection.receive_message(message) + except ProtocolError as e: + self.logger.debug(f'{e}') + if e.error_message: + await self._send_message(e.error_message) + if e.code == JSONRPC.PARSE_ERROR: + self.max_errors = 0 + self._bump_errors() + else: + for request in requests: + await self.spawn(self._throttled_request(request)) + + async def _throttled_request(self, request): + '''Process a single request, respecting the concurrency limit.''' + async with self._concurrency.semaphore: + try: + result = await self.handle_request(request) + except (ProtocolError, RPCError) as e: + result = e + except CancelledError: + raise + except Exception: + self.logger.exception(f'exception handling {request}') + result = RPCError(JSONRPC.INTERNAL_ERROR, + 'internal server error') + if isinstance(request, Request): + message = request.send_result(result) + if message: + await self._send_message(message) + if isinstance(result, Exception): + self._bump_errors() + + def connection_lost(self, exc): + # Cancel pending requests and message processing + self.connection.cancel_pending_requests() + super().connection_lost(exc) + + # External API + def default_connection(self): + '''Return a default connection if the user provides none.''' + return JSONRPCConnection(JSONRPCv2) + + def default_framer(self): + '''Return a default framer.''' + return NewlineFramer() + + async def handle_request(self, request): + pass + + async def send_request(self, method, args=()): + '''Send an RPC request over the network.''' + message, event = self.connection.send_request(Request(method, args)) + await self._send_message(message) + await event.wait() + result = event.result + if isinstance(result, Exception): + raise result + return result + + async def send_notification(self, method, args=()): + '''Send an RPC notification over the network.''' + message = self.connection.send_notification(Notification(method, args)) + await self._send_message(message) + + def send_batch(self, raise_errors=False): + '''Return a BatchRequest. Intended to be used like so: + + async with session.send_batch() as batch: + batch.add_request("method1") + batch.add_request("sum", (x, y)) + batch.add_notification("updated") + + for result in batch.results: + ... + + Note that in some circumstances exceptions can be raised; see + BatchRequest doc string. + ''' + return BatchRequest(self, raise_errors) + + +class Server(object): + '''A simple wrapper around an asyncio.Server object.''' + + def __init__(self, session_factory, host=None, port=None, *, + loop=None, **kwargs): + self.host = host + self.port = port + self.loop = loop or asyncio.get_event_loop() + self.server = None + self._session_factory = session_factory + self._kwargs = kwargs + + async def listen(self): + self.server = await self.loop.create_server( + self._session_factory, self.host, self.port, **self._kwargs) + + async def close(self): + '''Close the listening socket. This does not close any ServerSession + objects created to handle incoming connections. + ''' + if self.server: + self.server.close() + await self.server.wait_closed() + self.server = None diff --git a/torba/rpc/socks.py b/torba/rpc/socks.py new file mode 100644 index 000000000..cc4b63f13 --- /dev/null +++ b/torba/rpc/socks.py @@ -0,0 +1,439 @@ +# Copyright (c) 2018, Neil Booth +# +# All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +'''SOCKS proxying.''' + +import sys +import asyncio +import collections +import ipaddress +import socket +import struct +from functools import partial + + +__all__ = ('SOCKSUserAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy', + 'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure') + + +SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password") + + +class SOCKSError(Exception): + '''Base class for SOCKS exceptions. Each raised exception will be + an instance of a derived class.''' + + +class SOCKSProtocolError(SOCKSError): + '''Raised when the proxy does not follow the SOCKS protocol''' + + +class SOCKSFailure(SOCKSError): + '''Raised when the proxy refuses or fails to make a connection''' + + +class NeedData(Exception): + pass + + +class SOCKSBase(object): + + @classmethod + def name(cls): + return cls.__name__ + + def __init__(self): + self._buffer = bytes() + self._state = self._start + + def _read(self, size): + if len(self._buffer) < size: + raise NeedData(size - len(self._buffer)) + result = self._buffer[:size] + self._buffer = self._buffer[size:] + return result + + def receive_data(self, data): + self._buffer += data + + def next_message(self): + return self._state() + + +class SOCKS4(SOCKSBase): + '''SOCKS4 protocol wrapper.''' + + # See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol + REPLY_CODES = { + 90: 'request granted', + 91: 'request rejected or failed', + 92: ('request rejected because SOCKS server cannot connect ' + 'to identd on the client'), + 93: ('request rejected because the client program and identd ' + 'report different user-ids') + } + + def __init__(self, dst_host, dst_port, auth): + super().__init__() + self._dst_host = self._check_host(dst_host) + self._dst_port = dst_port + self._auth = auth + + @classmethod + def _check_host(cls, host): + if not isinstance(host, ipaddress.IPv4Address): + try: + host = ipaddress.IPv4Address(host) + except ValueError: + raise SOCKSProtocolError( + f'SOCKS4 requires an IPv4 address: {host}') from None + return host + + def _start(self): + self._state = self._first_response + + if isinstance(self._dst_host, ipaddress.IPv4Address): + # SOCKS4 + dst_ip_packed = self._dst_host.packed + host_bytes = b'' + else: + # SOCKS4a + dst_ip_packed = b'\0\0\0\1' + host_bytes = self._dst_host.encode() + b'\0' + + if isinstance(self._auth, SOCKSUserAuth): + user_id = self._auth.username.encode() + else: + user_id = b'' + + # Send TCP/IP stream CONNECT request + return b''.join([b'\4\1', struct.pack('>H', self._dst_port), + dst_ip_packed, user_id, b'\0', host_bytes]) + + def _first_response(self): + # Wait for 8-byte response + data = self._read(8) + if data[0] != 0: + raise SOCKSProtocolError(f'invalid {self.name()} proxy ' + f'response: {data}') + reply_code = data[1] + if reply_code != 90: + msg = self.REPLY_CODES.get( + reply_code, f'unknown {self.name()} reply code {reply_code}') + raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}') + + # Other fields ignored + return None + + +class SOCKS4a(SOCKS4): + + @classmethod + def _check_host(cls, host): + if not isinstance(host, (str, ipaddress.IPv4Address)): + raise SOCKSProtocolError( + f'SOCKS4a requires an IPv4 address or host name: {host}') + return host + + +class SOCKS5(SOCKSBase): + '''SOCKS protocol wrapper.''' + + # See https://tools.ietf.org/html/rfc1928 + ERROR_CODES = { + 1: 'general SOCKS server failure', + 2: 'connection not allowed by ruleset', + 3: 'network unreachable', + 4: 'host unreachable', + 5: 'connection refused', + 6: 'TTL expired', + 7: 'command not supported', + 8: 'address type not supported', + } + + def __init__(self, dst_host, dst_port, auth): + super().__init__() + self._dst_bytes = self._destination_bytes(dst_host, dst_port) + self._auth_bytes, self._auth_methods = self._authentication(auth) + + def _destination_bytes(self, host, port): + if isinstance(host, ipaddress.IPv4Address): + addr_bytes = b'\1' + host.packed + elif isinstance(host, ipaddress.IPv6Address): + addr_bytes = b'\4' + host.packed + elif isinstance(host, str): + host = host.encode() + if len(host) > 255: + raise SOCKSProtocolError(f'hostname too long: ' + f'{len(host)} bytes') + addr_bytes = b'\3' + bytes([len(host)]) + host + else: + raise SOCKSProtocolError(f'SOCKS5 requires an IPv4 address, IPv6 ' + f'address, or host name: {host}') + return addr_bytes + struct.pack('>H', port) + + def _authentication(self, auth): + if isinstance(auth, SOCKSUserAuth): + user_bytes = auth.username.encode() + if not 0 < len(user_bytes) < 256: + raise SOCKSProtocolError(f'username {auth.username} has ' + f'invalid length {len(user_bytes)}') + pwd_bytes = auth.password.encode() + if not 0 < len(pwd_bytes) < 256: + raise SOCKSProtocolError(f'password has invalid length ' + f'{len(pwd_bytes)}') + return b''.join([bytes([1, len(user_bytes)]), user_bytes, + bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2] + return b'', [0] + + def _start(self): + self._state = self._first_response + return (b'\5' + bytes([len(self._auth_methods)]) + + bytes(m for m in self._auth_methods)) + + def _first_response(self): + # Wait for 2-byte response + data = self._read(2) + if data[0] != 5: + raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') + if data[1] not in self._auth_methods: + raise SOCKSFailure('SOCKS5 proxy rejected authentication methods') + + # Authenticate if user-password authentication + if data[1] == 2: + self._state = self._auth_response + return self._auth_bytes + return self._request_connection() + + def _auth_response(self): + data = self._read(2) + if data[0] != 1: + raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth ' + f'response: {data}') + if data[1] != 0: + raise SOCKSFailure(f'SOCKS5 proxy auth failure code: ' + f'{data[1]}') + + return self._request_connection() + + def _request_connection(self): + # Send connection request + self._state = self._connect_response + return b'\5\1\0' + self._dst_bytes + + def _connect_response(self): + data = self._read(5) + if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4): + raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') + if data[1] != 0: + raise SOCKSFailure(self.ERROR_CODES.get( + data[1], f'unknown SOCKS5 error code: {data[1]}')) + + if data[3] == 1: + addr_len = 3 # IPv4 + elif data[3] == 3: + addr_len = data[4] # Hostname + else: + addr_len = 15 # IPv6 + + self._state = partial(self._connect_response_rest, addr_len) + return self.next_message() + + def _connect_response_rest(self, addr_len): + self._read(addr_len + 2) + return None + + +class SOCKSProxy(object): + + def __init__(self, address, protocol, auth): + '''A SOCKS proxy at an address following a SOCKS protocol. auth is an + authentication method to use when connecting, or None. + + address is a (host, port) pair; for IPv6 it can instead be a + (host, port, flowinfo, scopeid) 4-tuple. + ''' + self.address = address + self.protocol = protocol + self.auth = auth + # Set on each successful connection via the proxy to the + # result of socket.getpeername() + self.peername = None + + def __str__(self): + auth = 'username' if self.auth else 'none' + return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}' + + async def _handshake(self, client, sock, loop): + while True: + count = 0 + try: + message = client.next_message() + except NeedData as e: + count = e.args[0] + else: + if message is None: + return + await loop.sock_sendall(sock, message) + + if count: + data = await loop.sock_recv(sock, count) + if not data: + raise SOCKSProtocolError("EOF received") + client.receive_data(data) + + async def _connect_one(self, host, port): + '''Connect to the proxy and perform a handshake requesting a + connection to (host, port). + + Return the open socket on success, or the exception on failure. + ''' + client = self.protocol(host, port, self.auth) + sock = socket.socket() + loop = asyncio.get_event_loop() + try: + # A non-blocking socket is required by loop socket methods + sock.setblocking(False) + await loop.sock_connect(sock, self.address) + await self._handshake(client, sock, loop) + self.peername = sock.getpeername() + return sock + except Exception as e: + # Don't close - see https://github.com/kyuupichan/aiorpcX/issues/8 + if sys.platform.startswith('linux'): + sock.close() + return e + + async def _connect(self, addresses): + '''Connect to the proxy and perform a handshake requesting a + connection to each address in addresses. + + Return an (open_socket, address) pair on success. + ''' + assert len(addresses) > 0 + + exceptions = [] + for address in addresses: + host, port = address[:2] + sock = await self._connect_one(host, port) + if isinstance(sock, socket.socket): + return sock, address + exceptions.append(sock) + + strings = set(f'{exc!r}' for exc in exceptions) + raise (exceptions[0] if len(strings) == 1 else + OSError(f'multiple exceptions: {", ".join(strings)}')) + + async def _detect_proxy(self): + '''Return True if it appears we can connect to a SOCKS proxy, + otherwise False. + ''' + if self.protocol is SOCKS4a: + host, port = 'www.apple.com', 80 + else: + host, port = ipaddress.IPv4Address('8.8.8.8'), 53 + + sock = await self._connect_one(host, port) + if isinstance(sock, socket.socket): + sock.close() + return True + + # SOCKSFailure indicates something failed, but that we are + # likely talking to a proxy + return isinstance(sock, SOCKSFailure) + + @classmethod + async def auto_detect_address(cls, address, auth): + '''Try to detect a SOCKS proxy at address using the authentication + method (or None). SOCKS5, SOCKS4a and SOCKS are tried in + order. If a SOCKS proxy is detected a SOCKSProxy object is + returned. + + Returning a SOCKSProxy does not mean it is functioning - for + example, it may have no network connectivity. + + If no proxy is detected return None. + ''' + for protocol in (SOCKS5, SOCKS4a, SOCKS4): + proxy = cls(address, protocol, auth) + if await proxy._detect_proxy(): + return proxy + return None + + @classmethod + async def auto_detect_host(cls, host, ports, auth): + '''Try to detect a SOCKS proxy on a host on one of the ports. + + Calls auto_detect for the ports in order. Returns SOCKS are + tried in order; a SOCKSProxy object for the first detected + proxy is returned. + + Returning a SOCKSProxy does not mean it is functioning - for + example, it may have no network connectivity. + + If no proxy is detected return None. + ''' + for port in ports: + address = (host, port) + proxy = await cls.auto_detect_address(address, auth) + if proxy: + return proxy + + return None + + async def create_connection(self, protocol_factory, host, port, *, + resolve=False, ssl=None, + family=0, proto=0, flags=0): + '''Set up a connection to (host, port) through the proxy. + + If resolve is True then host is resolved locally with + getaddrinfo using family, proto and flags, otherwise the proxy + is asked to resolve host. + + The function signature is similar to loop.create_connection() + with the same result. The attribute _address is set on the + protocol to the address of the successful remote connection. + Additionally raises SOCKSError if something goes wrong with + the proxy handshake. + ''' + loop = asyncio.get_event_loop() + if resolve: + infos = await loop.getaddrinfo(host, port, family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + addresses = [info[4] for info in infos] + else: + addresses = [(host, port)] + + sock, address = await self._connect(addresses) + + def set_address(): + protocol = protocol_factory() + protocol._address = address + return protocol + + return await loop.create_connection( + set_address, sock=sock, ssl=ssl, + server_hostname=host if ssl else None) diff --git a/torba/rpc/util.py b/torba/rpc/util.py new file mode 100644 index 000000000..07388e7a2 --- /dev/null +++ b/torba/rpc/util.py @@ -0,0 +1,120 @@ +# Copyright (c) 2018, Neil Booth +# +# All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sublicense, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +__all__ = () + + +import asyncio +from collections import namedtuple +from functools import partial +import inspect + + +def normalize_corofunc(corofunc, args): + if asyncio.iscoroutine(corofunc): + if args != (): + raise ValueError('args cannot be passed with a coroutine') + return corofunc + return corofunc(*args) + + +def is_async_call(func): + '''inspect.iscoroutinefunction that looks through partials.''' + while isinstance(func, partial): + func = func.func + return inspect.iscoroutinefunction(func) + + +# other_params: None means cannot be called with keyword arguments only +# any means any name is good +SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args ' + 'required_names other_names') + + +def signature_info(func): + params = inspect.signature(func).parameters + min_args = max_args = 0 + required_names = [] + other_names = [] + no_names = False + for p in params.values(): + if p.kind == p.POSITIONAL_OR_KEYWORD: + max_args += 1 + if p.default is p.empty: + min_args += 1 + required_names.append(p.name) + else: + other_names.append(p.name) + elif p.kind == p.KEYWORD_ONLY: + other_names.append(p.name) + elif p.kind == p.VAR_POSITIONAL: + max_args = None + elif p.kind == p.VAR_KEYWORD: + other_names = any + elif p.kind == p.POSITIONAL_ONLY: + max_args += 1 + if p.default is p.empty: + min_args += 1 + no_names = True + + if no_names: + other_names = None + + return SignatureInfo(min_args, max_args, required_names, other_names) + + +class Concurrency(object): + + def __init__(self, max_concurrent): + self._require_non_negative(max_concurrent) + self._max_concurrent = max_concurrent + self.semaphore = asyncio.Semaphore(max_concurrent) + + def _require_non_negative(self, value): + if not isinstance(value, int) or value < 0: + raise RuntimeError('concurrency must be a natural number') + + @property + def max_concurrent(self): + return self._max_concurrent + + async def set_max_concurrent(self, value): + self._require_non_negative(value) + diff = value - self._max_concurrent + self._max_concurrent = value + if diff >= 0: + for _ in range(diff): + self.semaphore.release() + else: + for _ in range(-diff): + await self.semaphore.acquire() + + +def check_task(logger, task): + if not task.cancelled(): + try: + task.result() + except Exception: + logger.error('task crashed: %r', task, exc_info=True) diff --git a/torba/server/block_processor.py b/torba/server/block_processor.py index 2cbafe041..fe2b44bea 100644 --- a/torba/server/block_processor.py +++ b/torba/server/block_processor.py @@ -15,7 +15,7 @@ from struct import pack, unpack import time from functools import partial -from aiorpcx import TaskGroup, run_in_thread +from torba.rpc import TaskGroup, run_in_thread import torba from torba.server.daemon import DaemonError diff --git a/torba/server/daemon.py b/torba/server/daemon.py index dd53ae63c..dba43abd8 100644 --- a/torba/server/daemon.py +++ b/torba/server/daemon.py @@ -22,7 +22,7 @@ from torba.server.util import hex_to_bytes, class_logger,\ unpack_le_uint16_from, pack_varint from torba.server.hash import hex_str_to_hash, hash_to_hex_str from torba.server.tx import DeserializerDecred -from aiorpcx import JSONRPC +from torba.rpc import JSONRPC class DaemonError(Exception): diff --git a/torba/server/db.py b/torba/server/db.py index d1af32592..2e843ee62 100644 --- a/torba/server/db.py +++ b/torba/server/db.py @@ -19,8 +19,8 @@ from glob import glob from struct import pack, unpack import attr -from aiorpcx import run_in_thread, sleep +from torba.rpc import run_in_thread, sleep from torba.server import util from torba.server.hash import hash_to_hex_str, HASHX_LEN from torba.server.merkle import Merkle, MerkleCache diff --git a/torba/server/mempool.py b/torba/server/mempool.py index 1f725aa5b..d5ecd6804 100644 --- a/torba/server/mempool.py +++ b/torba/server/mempool.py @@ -14,8 +14,8 @@ from asyncio import Lock from collections import defaultdict import attr -from aiorpcx import TaskGroup, run_in_thread, sleep +from torba.rpc import TaskGroup, run_in_thread, sleep from torba.server.hash import hash_to_hex_str, hex_str_to_hash from torba.server.util import class_logger, chunks from torba.server.db import UTXO diff --git a/torba/server/merkle.py b/torba/server/merkle.py index a96921936..f0c4d9a7a 100644 --- a/torba/server/merkle.py +++ b/torba/server/merkle.py @@ -28,8 +28,7 @@ from math import ceil, log -from aiorpcx import Event - +from torba.rpc import Event from torba.server.hash import double_sha256 diff --git a/torba/server/peers.py b/torba/server/peers.py index d24810fde..d77d0de39 100644 --- a/torba/server/peers.py +++ b/torba/server/peers.py @@ -14,11 +14,10 @@ import ssl import time from collections import defaultdict, Counter -from aiorpcx import (Connector, RPCSession, SOCKSProxy, +from torba.rpc import (Connector, RPCSession, SOCKSProxy, Notification, handler_invocation, SOCKSError, RPCError, TaskTimeout, TaskGroup, Event, sleep, run_in_thread, ignore_after, timeout_after) - from torba.server.peer import Peer from torba.server.util import class_logger, protocol_tuple diff --git a/torba/server/session.py b/torba/server/session.py index 14e4668ef..af38d7725 100644 --- a/torba/server/session.py +++ b/torba/server/session.py @@ -19,13 +19,12 @@ import time from collections import defaultdict from functools import partial -from aiorpcx import ( +import torba +from torba.rpc import ( RPCSession, JSONRPCAutoDetect, JSONRPCConnection, TaskGroup, handler_invocation, RPCError, Request, ignore_after, sleep, Event ) - -import torba from torba.server import text from torba.server import util from torba.server.hash import (sha256, hash_to_hex_str, hex_str_to_hash, @@ -664,9 +663,9 @@ class SessionBase(RPCSession): super().connection_lost(exc) self.session_mgr.remove_session(self) msg = '' - if not self.can_send.is_set(): + if not self._can_send.is_set(): msg += ' whilst paused' - if self.concurrency.max_concurrent != self.max_concurrent: + if self._concurrency.max_concurrent != self.max_concurrent: msg += ' whilst throttled' if self.send_size >= 1024*1024: msg += ('. Sent {:,d} bytes in {:,d} messages'