diff --git a/lbrynet/blob_exchange/downloader.py b/lbrynet/blob_exchange/downloader.py index 24142f955..7010e54a7 100644 --- a/lbrynet/blob_exchange/downloader.py +++ b/lbrynet/blob_exchange/downloader.py @@ -14,7 +14,7 @@ log = logging.getLogger(__name__) class BlobDownloader: - BAN_TIME = 10.0 # fixme: when connection manager gets implemented, move it out from here + BAN_FACTOR = 2.0 # fixme: when connection manager gets implemented, move it out from here def __init__(self, loop: asyncio.BaseEventLoop, config: 'Config', blob_manager: 'BlobManager', peer_queue: asyncio.Queue): @@ -25,6 +25,7 @@ class BlobDownloader: self.active_connections: typing.Dict['KademliaPeer', asyncio.Task] = {} # active request_blob calls self.ignored: typing.Dict['KademliaPeer', int] = {} self.scores: typing.Dict['KademliaPeer', int] = {} + self.failures: typing.Dict['KademliaPeer', int] = {} self.connections: typing.Dict['KademliaPeer', asyncio.Transport] = {} self.time_since_last_blob = loop.time() @@ -54,27 +55,20 @@ class BlobDownloader: if not transport and peer not in self.ignored: self.ignored[peer] = self.loop.time() log.debug("drop peer %s:%i", peer.address, peer.tcp_port) + self.failures[peer] = self.failures.get(peer, 0) + 1 if peer in self.connections: del self.connections[peer] elif transport: log.debug("keep peer %s:%i", peer.address, peer.tcp_port) + if bytes_received: + self.failures[peer] = 0 self.connections[peer] = transport rough_speed = (bytes_received / (self.loop.time() - start)) if bytes_received else 0 self.scores[peer] = rough_speed - async def new_peer_or_finished(self, blob: 'AbstractBlob'): - async def get_and_re_add_peers(): - try: - new_peers = await asyncio.wait_for(self.peer_queue.get(), timeout=1.0) - self.peer_queue.put_nowait(new_peers) - except asyncio.TimeoutError: - pass - tasks = [self.loop.create_task(get_and_re_add_peers()), self.loop.create_task(blob.verified.wait())] - active_tasks = list(self.active_connections.values()) - try: - await asyncio.wait(tasks + active_tasks, loop=self.loop, return_when='FIRST_COMPLETED') - finally: - drain_tasks(tasks) + async def new_peer_or_finished(self): + active_tasks = list(self.active_connections.values()) + [asyncio.sleep(0.2)] + await asyncio.wait(active_tasks, loop=self.loop, return_when='FIRST_COMPLETED') def cleanup_active(self): to_remove = [peer for (peer, task) in self.active_connections.items() if task.done()] @@ -85,7 +79,9 @@ class BlobDownloader: now = self.loop.time() if now - self.time_since_last_blob > 60.0: return - forgiven = [banned_peer for banned_peer, when in self.ignored.items() if now - when > self.BAN_TIME] + timeout_for = lambda peer: (self.failures.get(peer, 0) ** self.BAN_FACTOR) - 1.0 + forgiven = [banned_peer for banned_peer, when in self.ignored.items() if now - when > timeout_for(banned_peer)] + log.warning([(timeout_for(peer), when) for peer, when in self.ignored.items()]) self.peer_queue.put_nowait(forgiven) for banned_peer in forgiven: self.ignored.pop(banned_peer) @@ -112,7 +108,7 @@ class BlobDownloader: log.debug("request %s from %s:%i", blob_hash[:8], peer.address, peer.tcp_port) t = self.loop.create_task(self.request_blob_from_peer(blob, peer)) self.active_connections[peer] = t - await self.new_peer_or_finished(blob) + await self.new_peer_or_finished() self.cleanup_active() if batch: to_re_add = list(set(batch).difference(self.ignored)) diff --git a/tests/integration/test_file_commands.py b/tests/integration/test_file_commands.py index 5ccde6068..a2a0b0928 100644 --- a/tests/integration/test_file_commands.py +++ b/tests/integration/test_file_commands.py @@ -159,7 +159,7 @@ class FileCommands(CommandTestCase): self.assertEqual('finished', file_info['status']) async def test_unban_recovers_stream(self): - BlobDownloader.BAN_TIME = .5 # fixme: temporary field, will move to connection manager or a conf + BlobDownloader.BAN_FACTOR = .5 # fixme: temporary field, will move to connection manager or a conf tx = await self.stream_create('foo', '0.01', data=bytes([0] * (1 << 23))) sd_hash = tx['outputs'][0]['value']['source']['sd_hash'] missing_blob_hash = (await self.daemon.jsonrpc_blob_list(sd_hash=sd_hash))[-2]