diff --git a/lbry/tests/integration/test_wallet_server_sessions.py b/lbry/tests/integration/test_wallet_server_sessions.py index b2e9c779d..451006d48 100644 --- a/lbry/tests/integration/test_wallet_server_sessions.py +++ b/lbry/tests/integration/test_wallet_server_sessions.py @@ -22,7 +22,6 @@ class TestSessionBloat(IntegrationTestCase): await self.conductor.start_spv() session = ClientSession(network=None, server=self.ledger.network.client.server, timeout=0.2) await session.create_connection() - session.ping_task.cancel() await session.send_request('server.banner', ()) self.assertEqual(len(self.conductor.spv_node.server.session_mgr.sessions), 1) self.assertFalse(session.is_closing()) diff --git a/torba/torba/client/basenetwork.py b/torba/torba/client/basenetwork.py index 25c858dc4..7ba0fb215 100644 --- a/torba/torba/client/basenetwork.py +++ b/torba/torba/client/basenetwork.py @@ -1,9 +1,7 @@ import logging import asyncio -from asyncio import CancelledError +from typing import Dict from time import time -from typing import List -import socket from torba.rpc import RPCSession as BaseClientSession, Connector, RPCError @@ -24,11 +22,20 @@ class ClientSession(BaseClientSession): self.bw_limit = self.framer.max_size = self.max_errors = 1 << 32 self.timeout = timeout self.max_seconds_idle = timeout * 2 - self.ping_task = None + self.latency = 1 << 32 + + @property + def available(self): + return not self.is_closing() and self._can_send.is_set() async def send_request(self, method, args=()): try: - return await asyncio.wait_for(super().send_request(method, args), timeout=self.timeout) + start = time() + result = await asyncio.wait_for( + super().send_request(method, args), timeout=self.timeout + ) + self.latency = time() - start + return result except RPCError as e: log.warning("Wallet server returned an error. Code: %s Message: %s", *e.args) raise e @@ -36,21 +43,29 @@ class ClientSession(BaseClientSession): self.abort() raise - async def ping_forever(self): + async def ensure_session(self): + # Handles reconnecting and maintaining a session alive # TODO: change to 'ping' on newer protocol (above 1.2) - while not self.is_closing(): - if (time() - self.last_send) > self.max_seconds_idle: - try: + retry_delay = 1.0 + while True: + try: + if self.is_closing(): + await self.create_connection(self.timeout) + await self.ensure_server_version() + if (time() - self.last_send) > self.max_seconds_idle: await self.send_request('server.banner') - except: - self.abort() - raise - await asyncio.sleep(self.max_seconds_idle//3) + retry_delay = 1.0 + except asyncio.TimeoutError: + log.warning("Wallet server timeout (retry in %s seconds): %s:%d", retry_delay, *self.server) + retry_delay = max(60, retry_delay * 2) + await asyncio.sleep(retry_delay) + + def ensure_server_version(self, required='1.2'): + return self.send_request('server.version', [__version__, required]) async def create_connection(self, timeout=6): connector = Connector(lambda: self, *self.server) await asyncio.wait_for(connector.create_connection(), timeout=timeout) - self.ping_task = asyncio.create_task(self.ping_forever()) async def handle_request(self, request): controller = self.network.subscription_controllers[request.method] @@ -58,9 +73,8 @@ class ClientSession(BaseClientSession): def connection_lost(self, exc): super().connection_lost(exc) + self.latency = 1 << 32 self._on_disconnect_controller.add(True) - if self.ping_task: - self.ping_task.cancel() class BaseNetwork: @@ -92,26 +106,20 @@ class BaseNetwork: self.session_pool = SessionPool(network=self, timeout=connect_timeout) self.session_pool.start(self.config['default_servers']) self.on_header.listen(self._update_remote_height) - while True: + while self.running: try: - self.client = await self.pick_fastest_session() - if self.is_connected: - await self.ensure_server_version() - self._update_remote_height((await self.subscribe_headers(),)) - log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) - self._on_connected_controller.add(True) - await self.client.on_disconnected.first - except CancelledError: - self.running = False + self.client = await self.session_pool.wait_for_fastest_session() + self._update_remote_height((await self.subscribe_headers(),)) + log.info("Successfully connected to SPV wallet server: %s:%d", *self.client.server) + self._on_connected_controller.add(True) + await self.client.on_disconnected.first + except asyncio.CancelledError: + await self.stop() + raise except asyncio.TimeoutError: - log.warning("Timed out while trying to find a server!") + pass except Exception: # pylint: disable=broad-except log.exception("Exception while trying to find a server!") - if not self.running: - return - elif self.client: - await self.client.close() - self.client.connection.cancel_pending_requests() async def stop(self): self.running = False @@ -124,25 +132,15 @@ class BaseNetwork: @property def is_connected(self): - return self.client is not None and not self.client.is_closing() + return self.session_pool.online - def rpc(self, list_or_method, args): + async def rpc(self, list_or_method, args): if self.is_connected: - return self.client.send_request(list_or_method, args) + await self.session_pool.wait_for_fastest_session() + return await self.session_pool.fastest_session.send_request(list_or_method, args) else: raise ConnectionError("Attempting to send rpc request when connection is not available.") - async def pick_fastest_session(self): - sessions = await self.session_pool.get_online_sessions() - done, pending = await asyncio.wait([ - self.probe_session(session) - for session in sessions if not session.is_closing() - ], return_when='FIRST_COMPLETED') - for task in pending: - task.cancel() - for session in done: - return await session - async def probe_session(self, session: ClientSession): await session.send_request('server.banner') return session @@ -150,9 +148,6 @@ class BaseNetwork: def _update_remote_height(self, header_args): self.remote_height = header_args[0]["height"] - def ensure_server_version(self, required='1.2'): - return self.rpc('server.version', [__version__, required]) - def broadcast(self, raw_transaction): return self.rpc('blockchain.transaction.broadcast', [raw_transaction]) @@ -172,83 +167,62 @@ class BaseNetwork: return self.rpc('blockchain.block.headers', [height, count]) def subscribe_headers(self): - return self.rpc('blockchain.headers.subscribe', [True]) + return self.client.send_request('blockchain.headers.subscribe', [True]) def subscribe_address(self, address): - return self.rpc('blockchain.address.subscribe', [address]) + return self.client.send_request('blockchain.address.subscribe', [address]) class SessionPool: def __init__(self, network: BaseNetwork, timeout: float): self.network = network - self.sessions: List[ClientSession] = [] - self._dead_servers: List[ClientSession] = [] + self.sessions: Dict[ClientSession, asyncio.Task] = dict() self.maintain_connections_task = None self.timeout = timeout - # triggered when the master server is out, to speed up reconnect - self._lost_master = asyncio.Event() @property def online(self): - for session in self.sessions: - if not session.is_closing(): - return True - return False + return any(not session.is_closing() for session in self.sessions) + + @property + def available_sessions(self): + return [session for session in self.sessions if session.available] + + @property + def fastest_session(self): + if not self.available_sessions: + return None + return min([(session.latency, session) for session in self.available_sessions])[1] def start(self, default_servers): - self.sessions = [ - ClientSession(network=self.network, server=server) - for server in default_servers - ] + for server in default_servers: + session = ClientSession(network=self.network, server=server) + self.sessions[session] = asyncio.create_task(session.ensure_session()) self.maintain_connections_task = asyncio.create_task(self.ensure_connections()) def stop(self): if self.maintain_connections_task: self.maintain_connections_task.cancel() - for session in self.sessions: + self.maintain_connections_task = None + for session, maintenance_task in self.sessions.items(): + maintenance_task.cancel() if not session.is_closing(): session.abort() - self.sessions, self._dead_servers, self.maintain_connections_task = [], [], None + self.sessions.clear() async def ensure_connections(self): while True: - await asyncio.gather(*[ - self.ensure_connection(session) - for session in self.sessions - ], return_exceptions=True) - try: - await asyncio.wait_for(self._lost_master.wait(), timeout=3) - except asyncio.TimeoutError: - pass - self._lost_master.clear() - if not self.sessions: - self.sessions.extend(self._dead_servers) - self._dead_servers = [] + log.info("Checking conns") + for session, task in list(self.sessions.items()): + if task.done(): + self.sessions[session] = asyncio.create_task(session.ensure_session()) + await asyncio.wait(self.sessions.items(), timeout=10) - async def ensure_connection(self, session): - self._dead_servers.append(session) - self.sessions.remove(session) - try: - if session.is_closing(): - await session.create_connection(self.timeout) - await asyncio.wait_for(session.send_request('server.banner'), timeout=self.timeout) - self.sessions.append(session) - self._dead_servers.remove(session) - except asyncio.TimeoutError: - log.warning("Timeout connecting to %s:%d", *session.server) - except asyncio.CancelledError: # pylint: disable=try-except-raise - raise - except socket.gaierror: - log.warning("Could not resolve IP for %s", session.server[0]) - except Exception as err: # pylint: disable=broad-except - if 'Connect call failed' in str(err): - log.warning("Could not connect to %s:%d", *session.server) + async def wait_for_fastest_session(self): + while True: + fastest = self.fastest_session + if fastest: + return fastest else: - log.exception("Connecting to %s:%d raised an exception:", *session.server) - - async def get_online_sessions(self): - while not self.online: - self._lost_master.set() - await asyncio.sleep(0.5) - return self.sessions + await asyncio.sleep(0.5)