diff --git a/lbry/tests/integration/test_wallet_server_sessions.py b/lbry/tests/integration/test_wallet_server_sessions.py index b07595b3f..618d71b57 100644 --- a/lbry/tests/integration/test_wallet_server_sessions.py +++ b/lbry/tests/integration/test_wallet_server_sessions.py @@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase): await self.conductor.start_spv() session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) await session.create_connection() - session.ping_task.cancel() await session.send_request('server.banner', ()) self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) self.assertFalse(session.is_closing()) diff --git a/torba/tests/client_tests/integration/test_network.py b/torba/tests/client_tests/integration/test_network.py index b5b4643f1..8d0faed2a 100644 --- a/torba/tests/client_tests/integration/test_network.py +++ b/torba/tests/client_tests/integration/test_network.py @@ -22,31 +22,40 @@ class ReconnectTests(IntegrationTestCase): async def test_connection_drop_still_receives_events_after_reconnected(self): address1 = await self.account.receiving.get_or_create_usable_address() + # disconnect and send a new tx, should reconnect and get it self.ledger.network.client.connection_lost(Exception()) + self.assertFalse(self.ledger.network.is_connected) sendtxid = await self.blockchain.send_to_address(address1, 1.1337) - await self.on_transaction_id(sendtxid) # mempool + await asyncio.wait_for(self.on_transaction_id(sendtxid), 1.0) # mempool await self.blockchain.generate(1) await self.on_transaction_id(sendtxid) # confirmed + self.assertLess(self.ledger.network.client.response_time, 1) # response time properly set lower, we are fine await self.assertBalance(self.account, '1.1337') # is it real? are we rich!? let me see this tx... d = self.ledger.network.get_transaction(sendtxid) # what's that smoke on my ethernet cable? oh no! self.ledger.network.client.connection_lost(Exception()) - with self.assertRaises(asyncio.CancelledError): + with self.assertRaises(asyncio.TimeoutError): await d + self.assertIsNone(self.ledger.network.client.response_time) # response time unknown as it failed # rich but offline? no way, no water, let's retry with self.assertRaisesRegex(ConnectionError, 'connection is not available'): await self.ledger.network.get_transaction(sendtxid) # * goes to pick some water outside... * time passes by and another donation comes in sendtxid = await self.blockchain.send_to_address(address1, 42) await self.blockchain.generate(1) + # (this is just so the test doesnt hang forever if it doesnt reconnect) + if not self.ledger.network.is_connected: + await asyncio.wait_for(self.ledger.network.on_connected.first, timeout=1.0) # omg, the burned cable still works! torba is fire proof! await self.ledger.network.get_transaction(sendtxid) async def test_timeout_then_reconnect(self): + # tests that it connects back after some failed attempts await self.conductor.spv_node.stop() self.assertFalse(self.ledger.network.is_connected) + await asyncio.sleep(0.2) # let it retry and fail once await self.conductor.spv_node.start(self.conductor.blockchain_node) await self.ledger.network.on_connected.first self.assertTrue(self.ledger.network.is_connected) @@ -79,9 +88,9 @@ class ServerPickingTestCase(AsyncioTestCase): await self._make_bad_server(), ('localhost', 1), ('example.that.doesnt.resolve', 9000), - await self._make_fake_server(latency=1.2, port=1340), - await self._make_fake_server(latency=0.5, port=1337), - await self._make_fake_server(latency=0.7, port=1339), + await self._make_fake_server(latency=1.0, port=1340), + await self._make_fake_server(latency=0.1, port=1337), + await self._make_fake_server(latency=0.4, port=1339), ], 'connect_timeout': 3 }) @@ -89,9 +98,10 @@ class ServerPickingTestCase(AsyncioTestCase): network = BaseNetwork(ledger) self.addCleanup(network.stop) asyncio.ensure_future(network.start()) - await asyncio.wait_for(network.on_connected.first, timeout=3) + await asyncio.wait_for(network.on_connected.first, timeout=1) self.assertTrue(network.is_connected) self.assertEqual(network.client.server, ('127.0.0.1', 1337)) - # ensure we are connected to all of them - self.assertTrue(all([not session.is_closing() for session in network.session_pool.sessions])) - self.assertEqual(len(network.session_pool.sessions), 3) + self.assertTrue(all([not session.is_closing() for session in network.session_pool.available_sessions])) + # ensure we are connected to all of them after a while + await asyncio.sleep(1) + self.assertEqual(len(network.session_pool.available_sessions), 3) diff --git a/torba/tests/client_tests/unit/test_stream_controller.py b/torba/tests/client_tests/unit/test_stream_controller.py new file mode 100644 index 000000000..70c20fd3a --- /dev/null +++ b/torba/tests/client_tests/unit/test_stream_controller.py @@ -0,0 +1,20 @@ +from torba.stream import StreamController +from torba.testcase import AsyncioTestCase + + +class StreamControllerTestCase(AsyncioTestCase): + def test_non_unique_events(self): + events = [] + controller = StreamController() + controller.stream.listen(on_data=events.append) + controller.add("yo") + controller.add("yo") + self.assertEqual(events, ["yo", "yo"]) + + def test_unique_events(self): + events = [] + controller = StreamController(merge_repeated_events=True) + controller.stream.listen(on_data=events.append) + controller.add("yo") + controller.add("yo") + self.assertEqual(events, ["yo"]) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 25c858dc4..39ed1fbdf 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -1,9 +1,8 @@ import logging import asyncio -from asyncio import CancelledError -from time import time -from typing import List -import socket +from operator import itemgetter +from typing import Dict, Optional +from time import time, perf_counter from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -15,7 +14,7 @@ log = logging.getLogger(__name__) class ClientSession(BaseClientSession): - def __init__(self, *args, network, server, timeout=30, **kwargs): + def __init__(self, *args, network, server, timeout=30, on_connect_callback=None, **kwargs): self.network = network self.server = server super().__init__(*args, **kwargs) @@ -24,61 +23,88 @@ class ClientSession(BaseClientSession): self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.timeout = timeout self.max_seconds_idle = timeout * 2 - self.ping_task = None + self.response_time: Optional[float] = None + self._on_connect_cb = on_connect_callback or (lambda: None) + self.trigger_urgent_reconnect = asyncio.Event() + + @property + def available(self): + return not self.is_closing() and self._can_send.is_set() and self.response_time is not None async def send_request(self, method, args=()): try: - return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) + start = perf_counter() + result = await asyncio.wait_for( + super().send_request(method, args), timeout=self.timeout + ) + self.response_time = perf_counter() - start + return result except RPCError as e: log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) raise e - except asyncio.TimeoutError: - self.abort() + except TimeoutError: + self.response_time = None raise - async def ping_forever(self): + async def ensure_session(self): + # Handles reconnecting and maintaining a session alive # TODO: change to 'ping' on newer protocol (above 1.2) - while not self.is_closing(): - if (time() - self.last_send) > self.max_seconds_idle: - try: + retry_delay = default_delay = 0.1 + while True: + try: + if self.is_closing(): + await self.create_connection(self.timeout) + await self.ensure_server_version() + self._on_connect_cb() + if (time() - self.last_send) > self.max_seconds_idle or self.response_time is None: await self.send_request('server.banner') - except: - self.abort() - raise - await asyncio.sleep(self.max_seconds_idle//3) + retry_delay = default_delay + except (asyncio.TimeoutError, OSError): + await self.close() + retry_delay = min(60, retry_delay * 2) + log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) + try: + await asyncio.wait_for(self.trigger_urgent_reconnect.wait(), timeout=retry_delay) + except asyncio.TimeoutError: + pass + finally: + self.trigger_urgent_reconnect.clear() + + def ensure_server_version(self, required='1.2'): + return self.send_request('server.version', [__version__, required]) async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) await asyncio.wait_for(connector.create_connection(), timeout=timeout) - self.ping_task = asyncio.create_task(self.ping_forever()) async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] controller.add(request.args) def connection_lost(self, exc): + log.debug("Connection lost: %s:%d", *self.server) super().connection_lost(exc) + self.response_time = None self._on_disconnect_controller.add(True) - if self.ping_task: - self.ping_task.cancel() class BaseNetwork: def __init__(self, ledger): + self.switch_event = asyncio.Event() self.config = ledger.config - self.client: ClientSession = None - self.session_pool: SessionPool = None + self.session_pool = SessionPool(network=self, timeout=self.config.get('connect_timeout', 6)) + self.client: Optional[ClientSession] = None self.running = False self.remote_height: int = 0 self._on_connected_controller = StreamController() self.on_connected = self._on_connected_controller.stream - self._on_header_controller = StreamController() + self._on_header_controller = StreamController(merge_repeated_events=True) self.on_header = self._on_header_controller.stream - self._on_status_controller = StreamController() + self._on_status_controller = StreamController(merge_repeated_events=True) self.on_status = self._on_status_controller.stream self.subscription_controllers = { @@ -88,30 +114,22 @@ class BaseNetwork: async def start(self): self.running = True - connect_timeout = self.config.get('connect_timeout', 6) - self.session_pool = SessionPool(network=self, timeout=connect_timeout) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) - while True: + while self.running: try: - self.client = await self.pick_fastest_session() - if self.is_connected: - await self.ensure_server_version() - self._update_remote_height((await self.subscribe_headers(),)) - log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) - self._on_connected_controller.add(True) - await self.client.on_disconnected.first - except CancelledError: - self.running = False + self.client = await self.session_pool.wait_for_fastest_session() + self._update_remote_height((await self.subscribe_headers(),)) + log.info("Switching to SPV wallet server: %s:%d", *self.client.server) + self._on_connected_controller.add(True) + self.client.on_disconnected.listen(lambda _: self.switch_event.set()) + await self.switch_event.wait() + self.switch_event.clear() + except asyncio.CancelledError: + await self.stop() + raise except asyncio.TimeoutError: - log.warning("Timed out while trying to find a server!") - except Exception: # pylint: disable=broad-except - log.exception("Exception while trying to find a server!") - if not self.running: - return - elif self.client: - await self.client.close() - self.client.connection.cancel_pending_requests() + pass async def stop(self): self.running = False @@ -124,35 +142,21 @@ class BaseNetwork: @property def is_connected(self): - return self.client is not None and not self.client.is_closing() + return self.client and not self.client.is_closing() def rpc(self, list_or_method, args): + fastest = self.session_pool.fastest_session + if fastest is not None and self.client != fastest: + self.switch_event.set() if self.is_connected: return self.client.send_request(list_or_method, args) else: + self.session_pool.trigger_nodelay_connect() raise ConnectionError("Attempting to send rpc request when connection is not available.") - async def pick_fastest_session(self): - sessions = await self.session_pool.get_online_sessions() - done, pending = await asyncio.wait([ - self.probe_session(session) - for session in sessions if not session.is_closing() - ], return_when='FIRST_COMPLETED') - for task in pending: - task.cancel() - for session in done: - return await session - - async def probe_session(self, session: ClientSession): - await session.send_request('server.banner') - return session - def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def ensure_server_version(self, required='1.2'): - return self.rpc('server.version', [__version__, required]) - def broadcast(self, raw_transaction): return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) @@ -182,73 +186,57 @@ class SessionPool: def __init__(self, network: BaseNetwork, timeout: float): self.network = network - self.sessions: List[ClientSession] = [] - self._dead_servers: List[ClientSession] = [] - self.maintain_connections_task = None + self.sessions: Dict[ClientSession, Optional[asyncio.Task]] = dict() self.timeout = timeout - # triggered when the master server is out, to speed up reconnect - self._lost_master = asyncio.Event() + self.new_connection_event = asyncio.Event() @property def online(self): - for session in self.sessions: - if not session.is_closing(): - return True - return False + return any(not session.is_closing() for session in self.sessions) + + @property + def available_sessions(self): + return [session for session in self.sessions if session.available] + + @property + def fastest_session(self): + if not self.available_sessions: + return None + return min( + [(session.response_time, session) for session in self.available_sessions], key=itemgetter(0) + )[1] def start(self, default_servers): - self.sessions = [ - ClientSession(network=self.network, server=server) - for server in default_servers - ] - self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) + callback = self.new_connection_event.set + self.sessions = { + ClientSession( + network=self.network, server=server, on_connect_callback=callback + ): None for server in default_servers + } + self.ensure_connections() def stop(self): - if self.maintain_connections_task: - self.maintain_connections_task.cancel() + for session, task in self.sessions.items(): + task.cancel() + session.abort() + self.sessions.clear() + + def ensure_connections(self): + for session, task in list(self.sessions.items()): + if not task or task.done(): + task = asyncio.create_task(session.ensure_session()) + task.add_done_callback(lambda _: self.ensure_connections()) + self.sessions[session] = task + + def trigger_nodelay_connect(self): + # used when other parts of the system sees we might have internet back + # bypasses the retry interval for session in self.sessions: - if not session.is_closing(): - session.abort() - self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None + session.trigger_urgent_reconnect.set() - async def ensure_connections(self): - while True: - await asyncio.gather(*[ - self.ensure_connection(session) - for session in self.sessions - ], return_exceptions=True) - try: - await asyncio.wait_for(self._lost_master.wait(), timeout=3) - except asyncio.TimeoutError: - pass - self._lost_master.clear() - if not self.sessions: - self.sessions.extend(self._dead_servers) - self._dead_servers = [] - - async def ensure_connection(self, session): - self._dead_servers.append(session) - self.sessions.remove(session) - try: - if session.is_closing(): - await session.create_connection(self.timeout) - await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout) - self.sessions.append(session) - self._dead_servers.remove(session) - except asyncio.TimeoutError: - log.warning("Timeout connecting to %s:%d", *session.server) - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except socket.gaierror: - log.warning("Could not resolve IP for %s", session.server[0]) - except Exception as err: # pylint: disable=broad-except - if 'Connect call failed' in str(err): - log.warning("Could not connect to %s:%d", *session.server) - else: - log.exception("Connecting to %s:%d raised an exception:", *session.server) - - async def get_online_sessions(self): - while not self.online: - self._lost_master.set() - await asyncio.sleep(0.5) - return self.sessions + async def wait_for_fastest_session(self): + while not self.fastest_session: + self.trigger_nodelay_connect() + self.new_connection_event.clear() + await self.new_connection_event.wait() + return self.fastest_session diff --git a/torba/torba/rpc/jsonrpc.py b/torba/torba/rpc/jsonrpc.py index 5e908cd02..4e5cca8ca 100644 --- a/torba/torba/rpc/jsonrpc.py +++ b/torba/torba/rpc/jsonrpc.py @@ -33,6 +33,7 @@ __all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose', import itertools import json import typing +import asyncio from functools import partial from numbers import Number @@ -745,9 +746,10 @@ class JSONRPCConnection(object): self._protocol = item return self.receive_message(message) - def cancel_pending_requests(self): - """Cancel all pending requests.""" - exception = CancelledError() + def time_out_pending_requests(self): + """Times out all pending requests.""" + # this used to be CancelledError, but thats confusing as in are we closing the whole sdk or failing? + exception = asyncio.TimeoutError() for request, event in self._requests.values(): event.result = exception event.set() diff --git a/torba/torba/rpc/session.py b/torba/torba/rpc/session.py index c8b8c6945..e16b6bbb4 100644 --- a/torba/torba/rpc/session.py +++ b/torba/torba/rpc/session.py @@ -456,7 +456,7 @@ class RPCSession(SessionBase): def connection_lost(self, exc): # Cancel pending requests and message processing - self.connection.cancel_pending_requests() + self.connection.time_out_pending_requests() super().connection_lost(exc) # External API @@ -473,6 +473,8 @@ class RPCSession(SessionBase): async def send_request(self, method, args=()): """Send an RPC request over the network.""" + if self.is_closing(): + raise asyncio.TimeoutError("Trying to send request on a recently dropped connection.") message, event = self.connection.send_request(Request(method, args)) await self._send_message(message) await event.wait() diff --git a/torba/torba/stream.py b/torba/torba/stream.py index 40589ade0..412a94525 100644 --- a/torba/torba/stream.py +++ b/torba/torba/stream.py @@ -45,10 +45,12 @@ class BroadcastSubscription: class StreamController: - def __init__(self): + def __init__(self, merge_repeated_events=False): self.stream = Stream(self) self._first_subscription = None self._last_subscription = None + self._last_event = None + self._merge_repeated = merge_repeated_events @property def has_listener(self): @@ -76,8 +78,10 @@ class StreamController: return f def add(self, event): + skip = self._merge_repeated and event == self._last_event + self._last_event = event return self._notify_and_ensure_future( - lambda subscription: subscription._add(event) + lambda subscription: None if skip else subscription._add(event) ) def add_error(self, exception):