From d4aca89a48e359f7b9b9eda8bd04b18c3673e01a Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Wed, 9 Mar 2022 17:47:23 -0300 Subject: [PATCH] handle multiple results from multiple trackers --- lbry/torrent/tracker.py | 14 +++++----- tests/unit/torrent/test_tracker.py | 41 ++++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lbry/torrent/tracker.py b/lbry/torrent/tracker.py index f38cf17fc..e23a5ba84 100644 --- a/lbry/torrent/tracker.py +++ b/lbry/torrent/tracker.py @@ -138,7 +138,7 @@ class TrackerClient: self.transport, _ = await asyncio.get_running_loop().create_datagram_endpoint( lambda: self.client, local_addr=("0.0.0.0", 0)) self.EVENT_CONTROLLER.stream.listen( - lambda request: self.on_hash(request[1]) if request[0] == 'search' else None) + lambda request: self.on_hash(request[1], request[2]) if request[0] == 'search' else None) def stop(self): if self.transport is not None: @@ -155,18 +155,18 @@ class TrackerClient: log.info("Tracker finished announcing %d files.", self.announced) self.announced = 0 - def on_hash(self, info_hash): + def on_hash(self, info_hash, on_announcement=None): if info_hash not in self.tasks: - task = asyncio.create_task(self.get_peer_list(info_hash)) + task = asyncio.create_task(self.get_peer_list(info_hash, on_announcement=on_announcement)) task.add_done_callback(lambda *_: self.hash_done(info_hash)) self.tasks[info_hash] = task - async def get_peer_list(self, info_hash, stopped=False): + async def get_peer_list(self, info_hash, stopped=False, on_announcement=None): found = [] for done in asyncio.as_completed([self._probe_server(info_hash, *server, stopped) for server in self.servers]): result = await done if result is not None: - self.EVENT_CONTROLLER.add((info_hash, result)) + await asyncio.gather(*filter(asyncio.iscoroutine, [on_announcement(result)] if on_announcement else [])) found.append(result) return found @@ -191,6 +191,4 @@ class TrackerClient: def subscribe_hash(info_hash: bytes, on_data): - TrackerClient.EVENT_CONTROLLER.add(('search', info_hash)) - TrackerClient.EVENT_CONTROLLER.stream.where(lambda request: request[0] == info_hash).add_done_callback( - lambda request: on_data(request.result()[1])) + TrackerClient.EVENT_CONTROLLER.add(('search', info_hash, on_data)) diff --git a/tests/unit/torrent/test_tracker.py b/tests/unit/torrent/test_tracker.py index a6afed06a..e4b3c2f1c 100644 --- a/tests/unit/torrent/test_tracker.py +++ b/tests/unit/torrent/test_tracker.py @@ -16,6 +16,10 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su def connection_made(self, transport: asyncio.DatagramTransport) -> None: self.transport = transport + def add_peer(self, info_hash, ip_address: str, port: int): + self.peers.setdefault(info_hash, []) + self.peers[info_hash].append(encode_peer(ip_address, port)) + def datagram_received(self, data: bytes, address: (str, int)) -> None: if len(data) < 16: return @@ -30,18 +34,21 @@ class UDPTrackerServerProtocol(asyncio.DatagramProtocol): # for testing. Not su if req.connection_id not in self.known_conns: resp = encode(ErrorResponse(3, req.transaction_id, b'Connection ID missmatch.\x00')) else: - self.peers.setdefault(req.info_hash, []) - compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), address[0].split('.'), bytearray()) - compact_address = compact_ip + req.port.to_bytes(2, "big", signed=False) + compact_address = encode_peer(address[0], req.port) if req.event != 3: - self.peers[req.info_hash].append(compact_address) - elif compact_address in self.peers[req.info_hash]: + self.add_peer(req.info_hash, address[0], req.port) + elif compact_address in self.peers.get(req.info_hash, []): self.peers[req.info_hash].remove(compact_address) peers = [decode(CompactIPv4Peer, peer) for peer in self.peers[req.info_hash]] resp = encode(AnnounceResponse(1, req.transaction_id, 1700, 0, len(peers), peers)) return self.transport.sendto(resp, address) +def encode_peer(ip_address: str, port: int): + compact_ip = reduce(lambda buff, x: buff + bytearray([int(x)]), ip_address.split('.'), bytearray()) + return compact_ip + port.to_bytes(2, "big", signed=False) + + class UDPTrackerClientTestCase(AsyncioTestCase): async def asyncSetUp(self): self.servers = {} @@ -70,7 +77,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase): async def test_announce_using_helper_function(self): info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) queue = asyncio.Queue() - subscribe_hash(info_hash, queue.put_nowait) + subscribe_hash(info_hash, queue.put) announcement = await queue.get() peers = announcement.peers self.assertEqual(peers, [CompactIPv4Peer(int.from_bytes(bytes([127, 0, 0, 1]), "big", signed=False), 4444)]) @@ -89,8 +96,7 @@ class UDPTrackerClientTestCase(AsyncioTestCase): info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) await self.client.get_peer_list(info_hash) for server in self.servers.values(): - self.assertEqual(1, len(server.peers)) - self.assertEqual(1, len(server.peers[info_hash])) + self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) async def test_multiple_with_bad_one(self): await asyncio.gather(*[self.add_server() for _ in range(10)]) @@ -98,5 +104,20 @@ class UDPTrackerClientTestCase(AsyncioTestCase): info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) await self.client.get_peer_list(info_hash) for server in self.servers.values(): - self.assertEqual(1, len(server.peers)) - self.assertEqual(1, len(server.peers[info_hash])) + self.assertEqual(server.peers, {info_hash: [encode_peer("127.0.0.1", self.client.announce_port)]}) + + async def test_multiple_with_different_peers_across_helper_function(self): + # this is how the downloader uses it + await asyncio.gather(*[self.add_server() for _ in range(10)]) + info_hash = random.getrandbits(160).to_bytes(20, "big", signed=False) + fake_peers = [] + for server in self.servers.values(): + for _ in range(10): + peer = (f"127.0.0.{random.randint(1, 255)}", random.randint(2000, 65500)) + fake_peers.append(peer) + server.add_peer(info_hash, *peer) + response = [] + subscribe_hash(info_hash, response.append) + await asyncio.sleep(0) + await asyncio.gather(*self.client.tasks.values()) + self.assertEqual(11, len(response))