diff --git a/lbry/lbry/blob_exchange/server.py b/lbry/lbry/blob_exchange/server.py index 70234acb8..93867e2e7 100644 --- a/lbry/lbry/blob_exchange/server.py +++ b/lbry/lbry/blob_exchange/server.py @@ -12,25 +12,56 @@ if typing.TYPE_CHECKING: log = logging.getLogger(__name__) +# a standard request will be 295 bytes +MAX_REQUEST_SIZE = 1200 + class BlobServerProtocol(asyncio.Protocol): - def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str): + def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str, + idle_timeout: float = 30.0, transfer_timeout: float = 60.0): self.loop = loop self.blob_manager = blob_manager - self.server_task: asyncio.Task = None + self.idle_timeout = idle_timeout + self.transfer_timeout = transfer_timeout + self.server_task: typing.Optional[asyncio.Task] = None self.started_listening = asyncio.Event(loop=self.loop) self.buf = b'' - self.transport = None + self.transport: typing.Optional[asyncio.Transport] = None self.lbrycrd_address = lbrycrd_address self.peer_address_and_port: typing.Optional[str] = None + self.started_transfer = asyncio.Event(loop=self.loop) + self.transfer_finished = asyncio.Event(loop=self.loop) + self.close_on_idle_task: typing.Optional[asyncio.Task] = None + + async def close_on_idle(self): + while self.transport: + try: + await asyncio.wait_for(self.started_transfer.wait(), self.idle_timeout, loop=self.loop) + except asyncio.TimeoutError: + log.debug("closing idle connection from %s", self.peer_address_and_port) + return self.close() + self.started_transfer.clear() + await self.transfer_finished.wait() + self.transfer_finished.clear() + + def close(self): + if self.transport: + self.transport.close() def connection_made(self, transport): self.transport = transport + self.close_on_idle_task = self.loop.create_task(self.close_on_idle()) self.peer_address_and_port = "%s:%i" % self.transport.get_extra_info('peername') self.blob_manager.connection_manager.connection_received(self.peer_address_and_port) + log.debug("received connection from %s", self.peer_address_and_port) def connection_lost(self, exc: typing.Optional[Exception]) -> None: + log.debug("lost connection from %s", self.peer_address_and_port) self.blob_manager.connection_manager.incoming_connection_lost(self.peer_address_and_port) + self.transport = None + if self.close_on_idle_task and not self.close_on_idle_task.done(): + self.close_on_idle_task.cancel() + self.close_on_idle_task = None def send_response(self, responses: typing.List[blob_response_types]): to_send = [] @@ -65,20 +96,30 @@ class BlobServerProtocol(asyncio.Protocol): incoming_blob = {'blob_hash': blob.blob_hash, 'length': blob.length} responses.append(BlobDownloadResponse(incoming_blob=incoming_blob)) self.send_response(responses) - log.debug("send %s to %s:%i", blob.blob_hash[:8], peer_address, peer_port) + bh = blob.blob_hash[:8] + log.debug("send %s to %s:%i", bh, peer_address, peer_port) + self.started_transfer.set() try: - sent = await blob.sendfile(self) - except (ConnectionResetError, BrokenPipeError, RuntimeError, OSError): - if self.transport: - self.transport.close() - return - log.info("sent %s (%i bytes) to %s:%i", blob.blob_hash[:8], sent, peer_address, peer_port) + sent = await asyncio.wait_for(blob.sendfile(self), self.transfer_timeout, loop=self.loop) + self.blob_manager.connection_manager.sent_data(self.peer_address_and_port, sent) + log.debug("sent %s (%i bytes) to %s:%i", bh, sent, peer_address, peer_port) + except (ConnectionResetError, BrokenPipeError, RuntimeError, OSError, asyncio.TimeoutError) as err: + if isinstance(err, asyncio.TimeoutError): + log.debug("timed out sending blob %s to %s", bh, peer_address) + else: + log.debug("stopped sending %s to %s:%i", bh, peer_address, peer_port) + self.close() + finally: + self.transfer_finished.set() if responses: self.send_response(responses) - # self.transport.close() def data_received(self, data): request = None + if len(self.buf) + len(data or b'') >= MAX_REQUEST_SIZE: + log.warning("request from %s is too large", self.peer_address_and_port) + self.close() + return if data: self.blob_manager.connection_manager.received_data(self.peer_address_and_port, len(data)) message, separator, remainder = data.rpartition(b'}') @@ -89,26 +130,28 @@ class BlobServerProtocol(asyncio.Protocol): request = BlobRequest.deserialize(self.buf + data) self.buf = remainder except JSONDecodeError: - addr = self.transport.get_extra_info('peername') - peer_address, peer_port = addr - log.error("failed to decode blob request from %s:%i (%i bytes): %s", peer_address, peer_port, - len(data), '' if not data else binascii.hexlify(data).decode()) - if not request: - addr = self.transport.get_extra_info('peername') - peer_address, peer_port = addr - log.warning("failed to decode blob request from %s:%i", peer_address, peer_port) - self.transport.close() + log.error("request from %s is not valid json (%i bytes): %s", self.peer_address_and_port, + len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode()) + self.close() + return + if not request.requests: + log.error("failed to decode request from %s (%i bytes): %s", self.peer_address_and_port, + len(self.buf + data), '' if not data else binascii.hexlify(self.buf + data).decode()) + self.close() return self.loop.create_task(self.handle_request(request)) class BlobServer: - def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str): + def __init__(self, loop: asyncio.AbstractEventLoop, blob_manager: 'BlobManager', lbrycrd_address: str, + idle_timeout: float = 30.0, transfer_timeout: float = 60.0): self.loop = loop self.blob_manager = blob_manager self.server_task: typing.Optional[asyncio.Task] = None self.started_listening = asyncio.Event(loop=self.loop) self.lbrycrd_address = lbrycrd_address + self.idle_timeout = idle_timeout + self.transfer_timeout = transfer_timeout self.server_protocol_class = BlobServerProtocol def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): @@ -117,7 +160,8 @@ class BlobServer: async def _start_server(): server = await self.loop.create_server( - lambda: self.server_protocol_class(self.loop, self.blob_manager, self.lbrycrd_address), + lambda: self.server_protocol_class(self.loop, self.blob_manager, self.lbrycrd_address, + self.idle_timeout, self.transfer_timeout), interface, port ) self.started_listening.set() diff --git a/lbry/tests/unit/blob_exchange/test_transfer_blob.py b/lbry/tests/unit/blob_exchange/test_transfer_blob.py index 02ecec07b..ae9764939 100644 --- a/lbry/tests/unit/blob_exchange/test_transfer_blob.py +++ b/lbry/tests/unit/blob_exchange/test_transfer_blob.py @@ -225,3 +225,85 @@ class TestBlobExchange(BlobExchangeTestBase): server_protocol.data_received(bytes([byte])) await asyncio.sleep(0.1) # yield execution self.assertTrue(len(received_data.getvalue()) > 0) + + async def test_idle_timeout(self): + self.server.idle_timeout = 1 + + blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" + mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1) + await self._add_blob_to_server(blob_hash, mock_blob_bytes) + client_blob = self.client_blob_manager.get_blob(blob_hash) + + # download the blob + downloaded, transport = await request_blob(self.loop, client_blob, self.server_from_client.address, + self.server_from_client.tcp_port, 2, 3) + self.assertIsNotNone(transport) + self.assertFalse(transport.is_closing()) + await client_blob.verified.wait() + self.assertTrue(client_blob.get_is_verified()) + self.assertTrue(downloaded) + client_blob.delete() + + # wait for less than the idle timeout + await asyncio.sleep(0.5, loop=self.loop) + + # download the blob again + downloaded, transport2 = await request_blob(self.loop, client_blob, self.server_from_client.address, + self.server_from_client.tcp_port, 2, 3, + connected_transport=transport) + self.assertTrue(transport is transport2) + self.assertFalse(transport.is_closing()) + await client_blob.verified.wait() + self.assertTrue(client_blob.get_is_verified()) + self.assertTrue(downloaded) + client_blob.delete() + + # check that the connection times out from the server side + await asyncio.sleep(0.9, loop=self.loop) + self.assertFalse(transport.is_closing()) + self.assertIsNotNone(transport._sock) + await asyncio.sleep(0.1, loop=self.loop) + self.assertIsNone(transport._sock) + self.assertTrue(transport.is_closing()) + + def test_max_request_size(self): + protocol = BlobServerProtocol(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') + called = asyncio.Event() + protocol.close = called.set + protocol.data_received(b'0' * 1199) + self.assertFalse(called.is_set()) + protocol.data_received(b'0') + self.assertTrue(called.is_set()) + + def test_bad_json(self): + protocol = BlobServerProtocol(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') + called = asyncio.Event() + protocol.close = called.set + protocol.data_received(b'{{0}') + self.assertTrue(called.is_set()) + + def test_no_request(self): + protocol = BlobServerProtocol(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') + called = asyncio.Event() + protocol.close = called.set + protocol.data_received(b'{}') + self.assertTrue(called.is_set()) + + async def test_transfer_timeout(self): + self.server.transfer_timeout = 1 + + blob_hash = "7f5ab2def99f0ddd008da71db3a3772135f4002b19b7605840ed1034c8955431bd7079549e65e6b2a3b9c17c773073ed" + mock_blob_bytes = b'1' * ((2 * 2 ** 20) - 1) + await self._add_blob_to_server(blob_hash, mock_blob_bytes) + client_blob = self.client_blob_manager.get_blob(blob_hash) + server_blob = self.server_blob_manager.get_blob(blob_hash) + + async def sendfile(writer): + await asyncio.sleep(2, loop=self.loop) + return 0 + + server_blob.sendfile = sendfile + + with self.assertRaises(asyncio.CancelledError): + await request_blob(self.loop, client_blob, self.server_from_client.address, + self.server_from_client.tcp_port, 2, 3)