diff --git a/lbry/stream/reflector/server.py b/lbry/stream/reflector/server.py index cfb8b00ec..da11e5ec5 100644 --- a/lbry/stream/reflector/server.py +++ b/lbry/stream/reflector/server.py @@ -15,7 +15,10 @@ log = logging.getLogger(__name__) class ReflectorServerProtocol(asyncio.Protocol): - def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): + def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000, + stop_event: typing.Optional[asyncio.Event] = None, + incoming_event: typing.Optional[asyncio.Event] = None, + not_incoming_event: typing.Optional[asyncio.Event] = None): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager self.server_task: asyncio.Task = None @@ -27,11 +30,24 @@ class ReflectorServerProtocol(asyncio.Protocol): self.descriptor: typing.Optional['StreamDescriptor'] = None self.sd_blob: typing.Optional['BlobFile'] = None self.received = [] - self.incoming = asyncio.Event(loop=self.loop) + self.incoming = incoming_event or asyncio.Event(loop=self.loop) + self.not_incoming = not_incoming_event or asyncio.Event(loop=self.loop) + self.stop_event = stop_event or asyncio.Event(loop=self.loop) self.chunk_size = response_chunk_size + async def wait_for_stop(self): + await self.stop_event.wait() + if self.transport: + self.transport.close() + def connection_made(self, transport): self.transport = transport + self.wait_for_stop_task = self.loop.create_task(self.wait_for_stop()) + + def connection_lost(self, exc: typing.Optional[Exception]) -> None: + if self.wait_for_stop_task: + self.wait_for_stop_task.cancel() + self.wait_for_stop_task = None def data_received(self, data: bytes): if self.incoming.is_set(): @@ -73,6 +89,7 @@ class ReflectorServerProtocol(asyncio.Protocol): self.sd_blob = self.blob_manager.get_blob(request['sd_blob_hash'], request['sd_blob_size']) if not self.sd_blob.get_is_verified(): self.writer = self.sd_blob.get_blob_writer(self.transport.get_extra_info('peername')) + self.not_incoming.clear() self.incoming.set() self.send_response({"send_sd_blob": True}) try: @@ -86,6 +103,7 @@ class ReflectorServerProtocol(asyncio.Protocol): self.transport.close() finally: self.incoming.clear() + self.not_incoming.set() self.writer.close_handle() self.writer = None else: @@ -93,6 +111,7 @@ class ReflectorServerProtocol(asyncio.Protocol): self.loop, self.blob_manager.blob_dir, self.sd_blob ) self.incoming.clear() + self.not_incoming.set() if self.writer: self.writer.close_handle() self.writer = None @@ -112,6 +131,7 @@ class ReflectorServerProtocol(asyncio.Protocol): blob = self.blob_manager.get_blob(request['blob_hash'], request['blob_size']) if not blob.get_is_verified(): self.writer = blob.get_blob_writer(self.transport.get_extra_info('peername')) + self.not_incoming.clear() self.incoming.set() self.send_response({"send_blob": True}) try: @@ -120,6 +140,7 @@ class ReflectorServerProtocol(asyncio.Protocol): except asyncio.TimeoutError: self.send_response({"received_blob": False}) self.incoming.clear() + self.not_incoming.set() self.writer.close_handle() self.writer = None else: @@ -130,12 +151,19 @@ class ReflectorServerProtocol(asyncio.Protocol): class ReflectorServer: - def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000): + def __init__(self, blob_manager: 'BlobManager', response_chunk_size: int = 10000, + stop_event: typing.Optional[asyncio.Event] = None, + incoming_event: typing.Optional[asyncio.Event] = None, + not_incoming_event: typing.Optional[asyncio.Event] = None): self.loop = asyncio.get_event_loop() self.blob_manager = blob_manager self.server_task: typing.Optional[asyncio.Task] = None self.started_listening = asyncio.Event(loop=self.loop) + self.stopped_listening = asyncio.Event(loop=self.loop) + self.incoming_event = incoming_event or asyncio.Event(loop=self.loop) + self.not_incoming_event = not_incoming_event or asyncio.Event(loop=self.loop) self.response_chunk_size = response_chunk_size + self.stop_event = stop_event def start_server(self, port: int, interface: typing.Optional[str] = '0.0.0.0'): if self.server_task is not None: @@ -143,13 +171,20 @@ class ReflectorServer: async def _start_server(): server = await self.loop.create_server( - lambda: ReflectorServerProtocol(self.blob_manager, self.response_chunk_size), + lambda: ReflectorServerProtocol( + self.blob_manager, self.response_chunk_size, self.stop_event, self.incoming_event, + self.not_incoming_event + ), interface, port ) self.started_listening.set() + self.stopped_listening.clear() log.info("Reflector server listening on TCP %s:%i", interface, port) - async with server: - await server.serve_forever() + try: + async with server: + await server.serve_forever() + finally: + self.stopped_listening.set() self.server_task = self.loop.create_task(_start_server()) diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index b47cf31d0..6616d0d04 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py @@ -10,7 +10,7 @@ from lbry.stream.stream_manager import StreamManager from lbry.stream.reflector.server import ReflectorServer -class TestStreamAssembler(AsyncioTestCase): +class TestReflector(AsyncioTestCase): async def asyncSetUp(self): self.loop = asyncio.get_event_loop() self.key = b'deadbeef' * 4 @@ -22,6 +22,7 @@ class TestStreamAssembler(AsyncioTestCase): self.storage = SQLiteStorage(self.conf, os.path.join(tmp_dir, "lbrynet.sqlite")) await self.storage.open() self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.conf) + self.addCleanup(self.blob_manager.stop) self.stream_manager = StreamManager(self.loop, Config(), self.blob_manager, None, self.storage, None) server_tmp_dir = tempfile.mkdtemp() @@ -30,6 +31,7 @@ class TestStreamAssembler(AsyncioTestCase): self.server_storage = SQLiteStorage(self.server_conf, os.path.join(server_tmp_dir, "lbrynet.sqlite")) await self.server_storage.open() self.server_blob_manager = BlobManager(self.loop, server_tmp_dir, self.server_storage, self.server_conf) + self.addCleanup(self.server_blob_manager.stop) download_dir = tempfile.mkdtemp() self.addCleanup(lambda: shutil.rmtree(download_dir)) @@ -54,6 +56,7 @@ class TestStreamAssembler(AsyncioTestCase): set(map(lambda b: b.blob_hash, self.stream.descriptor.blobs[:-1] + [self.blob_manager.get_blob(self.stream.sd_hash)])) ) + self.assertTrue(self.stream.is_fully_reflected) server_sd_blob = self.server_blob_manager.get_blob(self.stream.sd_hash) self.assertTrue(server_sd_blob.get_is_verified()) self.assertEqual(server_sd_blob.length, server_sd_blob.length) @@ -75,3 +78,87 @@ class TestStreamAssembler(AsyncioTestCase): to_announce = await self.storage.get_blobs_to_announce() self.assertIn(self.stream.sd_hash, to_announce, "sd blob not set to announce") self.assertIn(self.stream.descriptor.blobs[0].blob_hash, to_announce, "head blob not set to announce") + + async def test_result_from_disconnect_mid_sd_transfer(self): + stop = asyncio.Event() + incoming = asyncio.Event() + reflector = ReflectorServer( + self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming + ) + reflector.start_server(5566, '127.0.0.1') + await reflector.started_listening.wait() + self.addCleanup(reflector.stop_server) + self.assertEqual(0, self.stream.reflector_progress) + reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566)) + await incoming.wait() + stop.set() + # this used to raise (and then propagate) a CancelledError + self.assertListEqual(await reflect_task, []) + self.assertFalse(self.stream.is_fully_reflected) + + async def test_result_from_disconnect_after_sd_transfer(self): + stop = asyncio.Event() + incoming = asyncio.Event() + not_incoming = asyncio.Event() + reflector = ReflectorServer( + self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming, + not_incoming_event=not_incoming + ) + reflector.start_server(5566, '127.0.0.1') + await reflector.started_listening.wait() + self.addCleanup(reflector.stop_server) + self.assertEqual(0, self.stream.reflector_progress) + reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566)) + await incoming.wait() + await not_incoming.wait() + stop.set() + self.assertListEqual(await reflect_task, [self.stream.sd_hash]) + self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified()) + self.assertFalse(self.stream.is_fully_reflected) + + async def test_result_from_disconnect_after_data_transfer(self): + stop = asyncio.Event() + incoming = asyncio.Event() + not_incoming = asyncio.Event() + reflector = ReflectorServer( + self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming, + not_incoming_event=not_incoming + ) + reflector.start_server(5566, '127.0.0.1') + await reflector.started_listening.wait() + self.addCleanup(reflector.stop_server) + self.assertEqual(0, self.stream.reflector_progress) + reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566)) + await incoming.wait() + await not_incoming.wait() + await incoming.wait() + await not_incoming.wait() + stop.set() + self.assertListEqual(await reflect_task, [self.stream.sd_hash, self.stream.descriptor.blobs[0].blob_hash]) + self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified()) + self.assertTrue(self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified()) + self.assertFalse(self.stream.is_fully_reflected) + + async def test_result_from_disconnect_mid_data_transfer(self): + stop = asyncio.Event() + incoming = asyncio.Event() + not_incoming = asyncio.Event() + reflector = ReflectorServer( + self.server_blob_manager, response_chunk_size=50, stop_event=stop, incoming_event=incoming, + not_incoming_event=not_incoming + ) + reflector.start_server(5566, '127.0.0.1') + await reflector.started_listening.wait() + self.addCleanup(reflector.stop_server) + self.assertEqual(0, self.stream.reflector_progress) + reflect_task = asyncio.create_task(self.stream.upload_to_reflector('127.0.0.1', 5566)) + await incoming.wait() + await not_incoming.wait() + await incoming.wait() + stop.set() + self.assertListEqual(await reflect_task, [self.stream.sd_hash]) + self.assertTrue(self.server_blob_manager.get_blob(self.stream.sd_hash).get_is_verified()) + self.assertFalse( + self.server_blob_manager.get_blob(self.stream.descriptor.blobs[0].blob_hash).get_is_verified() + ) + self.assertFalse(self.stream.is_fully_reflected)