diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py index b6de019f1..c5af445aa 100644 --- a/lbry/wallet/header.py +++ b/lbry/wallet/header.py @@ -5,6 +5,7 @@ import asyncio import hashlib import logging import zlib +from concurrent.futures.thread import ThreadPoolExecutor from io import BytesIO from contextlib import asynccontextmanager @@ -47,6 +48,7 @@ class Headers: self.path = path self._size: Optional[int] = None self.chunk_getter: Optional[Callable] = None + self.executor = ThreadPoolExecutor(1) async def open(self): if self.path != ':memory:': @@ -54,8 +56,10 @@ class Headers: self.io = open(self.path, 'w+b') else: self.io = open(self.path, 'r+b') + self._size = self.io.seek(0, os.SEEK_END) // self.header_size async def close(self): + self.executor.shutdown() self.io.close() @staticmethod @@ -103,27 +107,34 @@ class Headers: return new_target def __len__(self) -> int: - if self._size is None: - self._size = self.io.seek(0, os.SEEK_END) // self.header_size return self._size def __bool__(self): return True async def get(self, height) -> dict: + if height < 0: + raise IndexError(f"Height cannot be negative!!") if isinstance(height, slice): raise NotImplementedError("Slicing of header chain has not been implemented yet.") - return self.deserialize(height, await self.get_raw_header(height)) + try: + return self.deserialize(height, await self.get_raw_header(height)) + except struct.error: + raise IndexError(f"failed to get {height}, at {len(self)}") def estimated_timestamp(self, height): return self.first_block_timestamp + (height * self.timestamp_average_offset) async def get_raw_header(self, height) -> bytes: - await self.ensure_chunk_at(height) - self.io.seek(height * self.header_size, os.SEEK_SET) - return self.io.read(self.header_size) + if self.chunk_getter: + await self.ensure_chunk_at(height) + return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height) - async def chunk_hash(self, start, count): + def _read(self, height, count=1): + self.io.seek(height * self.header_size, os.SEEK_SET) + return self.io.read(self.header_size * count) + + def chunk_hash(self, start, count): self.io.seek(start * self.header_size, os.SEEK_SET) return self.hash_header(self.io.read(count * self.header_size)).decode() @@ -141,16 +152,20 @@ class Headers: zlib.decompress(base64.b64decode(headers['base64']), wbits=-15, bufsize=600_000) ) chunk_hash = self.hash_header(chunk).decode() - if HASHES[start] == chunk_hash: - return self._write(start, chunk) + if HASHES.get(start) == chunk_hash: + return await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk) + elif start not in HASHES: + return # todo: fixme raise Exception( f"Checkpoint mismatch at height {start}. Expected {HASHES[start]}, but got {chunk_hash} instead." ) async def has_header(self, height): - empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' - all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' - return await self.chunk_hash(height, 1) not in (empty, all_zeroes) + def _has_header(height): + empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' + all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' + return self.chunk_hash(height, 1) not in (empty, all_zeroes) + return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height) @property def height(self) -> int: @@ -216,11 +231,11 @@ class Headers: def _write(self, height, verified_chunk): self.io.seek(height * self.header_size, os.SEEK_SET) written = self.io.write(verified_chunk) // self.header_size - self.io.truncate() + # self.io.truncate() # .seek()/.write()/.truncate() might also .flush() when needed # the goal here is mainly to ensure we're definitely flush()'ing self.io.flush() - self._size = self.io.tell() // self.header_size + self._size = self.io.seek(0, os.SEEK_END) // self.header_size return written async def validate_chunk(self, height, chunk): @@ -272,8 +287,9 @@ class Headers: previous_header_hash = fail = None batch_size = 36 for start_height in range(0, self.height, batch_size): - self.io.seek(self.header_size * start_height) - headers = self.io.read(self.header_size*batch_size) + headers = await asyncio.get_running_loop().run_in_executor( + self.executor, self._read, start_height, batch_size + ) if len(headers) % self.header_size != 0: headers = headers[:(len(headers) // self.header_size) * self.header_size] for header_hash, header in self._iterate_headers(start_height, headers): @@ -286,11 +302,12 @@ class Headers: fail = True if fail: log.warning("Header file corrupted at height %s, truncating it.", height - 1) - self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) - self.io.truncate() - self.io.flush() - self._size = None - return + def __truncate(at_height): + self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET) + self.io.truncate() + self.io.flush() + self._size = self.io.seek(0, os.SEEK_END) // self.header_size + return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height) previous_header_hash = header_hash @classmethod diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 6577b4e54..0d1284dc2 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -316,9 +316,6 @@ class Ledger(metaclass=LedgerRegistry): first_connection = self.network.on_connected.first asyncio.ensure_future(self.network.start()) await first_connection - async with self._header_processing_lock: - await self._update_tasks.add(self.initial_headers_sync()) - await self._on_ready_controller.stream.first await asyncio.gather(*(a.maybe_migrate_certificates() for a in self.accounts)) await asyncio.gather(*(a.save_max_gap() for a in self.accounts)) if len(self.accounts) > 10: @@ -329,8 +326,10 @@ class Ledger(metaclass=LedgerRegistry): async def join_network(self, *_): log.info("Subscribing and updating accounts.") - #async with self._header_processing_lock: - # await self.update_headers() + self._update_tasks.add(self.initial_headers_sync()) + async with self._header_processing_lock: + await self.headers.ensure_tip() + await self.update_headers() await self.subscribe_accounts() await self._update_tasks.done.wait() self._on_ready_controller.add(True) @@ -348,16 +347,12 @@ class Ledger(metaclass=LedgerRegistry): async def initial_headers_sync(self): target = self.network.remote_height + 1 - current = len(self.headers) get_chunk = partial(self.network.retriable_call, self.network.get_headers, count=1000, b64=True) self.headers.chunk_getter = get_chunk - await self.headers.ensure_tip() async def doit(): - for height in range(current, target, 1000): + for height in reversed(range(0, target, 1000)): await self.headers.ensure_chunk_at(height) - self._download_height = height - log.info("Headers sync: %s / %s", self._download_height, target) asyncio.ensure_future(doit()) return @@ -598,7 +593,7 @@ class Ledger(metaclass=LedgerRegistry): async def maybe_verify_transaction(self, tx, remote_height): tx.height = remote_height - if 0 < remote_height < self.network.remote_height: + if 0 < remote_height < len(self.headers): merkle = await self.network.retriable_call(self.network.get_merkle, tx.id, remote_height) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) header = await self.headers.get(remote_height) diff --git a/tests/unit/wallet/test_headers.py b/tests/unit/wallet/test_headers.py index 8499a2b99..968766725 100644 --- a/tests/unit/wallet/test_headers.py +++ b/tests/unit/wallet/test_headers.py @@ -42,6 +42,7 @@ class TestHeaders(AsyncioTestCase): async def test_connect_from_genesis(self): headers = Headers(':memory:') + await headers.open() self.assertEqual(headers.height, -1) await headers.connect(0, HEADERS) self.assertEqual(headers.height, 19) @@ -49,6 +50,7 @@ class TestHeaders(AsyncioTestCase): async def test_connect_from_middle(self): h = Headers(':memory:') h.io.write(HEADERS[:block_bytes(10)]) + await h.open() self.assertEqual(h.height, 9) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) self.assertEqual(h.height, 19) @@ -140,6 +142,7 @@ class TestHeaders(AsyncioTestCase): async def test_checkpointed_writer(self): headers = Headers(':memory:') + await headers.open() getblocks = lambda start, end: HEADERS[block_bytes(start):block_bytes(end)] headers.checkpoint = 10, hexlify(sha256(getblocks(10, 11))) async with headers.checkpointed_connector() as buff: @@ -149,6 +152,7 @@ class TestHeaders(AsyncioTestCase): buff.write(getblocks(10, 19)) self.assertEqual(len(headers), 19) headers = Headers(':memory:') + await headers.open() async with headers.checkpointed_connector() as buff: buff.write(getblocks(0, 19)) self.assertEqual(len(headers), 19) diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index 13957d459..dc51ca240 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -67,7 +67,7 @@ class LedgerTestCase(AsyncioTestCase): serialized = self.make_header(**kwargs) self.ledger.headers.io.seek(0, os.SEEK_END) self.ledger.headers.io.write(serialized) - self.ledger.headers._size = None + self.ledger.headers._size = self.ledger.headers.io.seek(0, os.SEEK_END) // self.ledger.headers.header_size class TestSynchronization(LedgerTestCase):