diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py index fd5a85c61..c67b0dd1d 100644 --- a/lbry/wallet/header.py +++ b/lbry/wallet/header.py @@ -5,7 +5,6 @@ import asyncio import logging import zlib from datetime import date -from concurrent.futures.thread import ThreadPoolExecutor from io import BytesIO from typing import Optional, Iterator, Tuple, Callable @@ -42,23 +41,22 @@ class Headers: validate_difficulty: bool = True def __init__(self, path) -> None: - if path == ':memory:': - self.io = BytesIO() + self.io = None self.path = path self._size: Optional[int] = None self.chunk_getter: Optional[Callable] = None - self.executor = ThreadPoolExecutor(1) self.known_missing_checkpointed_chunks = set() self.check_chunk_lock = asyncio.Lock() async def open(self): - if not self.executor: - self.executor = ThreadPoolExecutor(1) + self.io = BytesIO() if self.path != ':memory:': - if not os.path.exists(self.path): - self.io = open(self.path, 'w+b') - else: - self.io = open(self.path, 'r+b') + def _readit(): + if os.path.exists(self.path): + with open(self.path, 'r+b') as header_file: + self.io.seek(0) + self.io.write(header_file.read()) + await asyncio.get_event_loop().run_in_executor(None, _readit) bytes_size = self.io.seek(0, os.SEEK_END) self._size = bytes_size // self.header_size max_checkpointed_height = max(self.checkpoints.keys() or [-1]) + 1000 @@ -72,10 +70,14 @@ class Headers: await self.get_all_missing_headers() async def close(self): - if self.executor: - self.executor.shutdown() - self.executor = None - self.io.close() + if self.io is not None: + def _close(): + flags = 'r+b' if os.path.exists(self.path) else 'w+b' + with open(self.path, flags) as header_file: + header_file.write(self.io.getbuffer()) + await asyncio.get_event_loop().run_in_executor(None, _close) + self.io.close() + self.io = None @staticmethod def serialize(header): @@ -135,28 +137,30 @@ class Headers: except struct.error: raise IndexError(f"failed to get {height}, at {len(self)}") - def estimated_timestamp(self, height): + def estimated_timestamp(self, height, try_real_headers=True): if height <= 0: return + if try_real_headers and self.has_header(height): + offset = height * self.header_size + return struct.unpack(' bytes: if self.chunk_getter: await self.ensure_chunk_at(height) if not 0 <= height <= self.height: raise IndexError(f"{height} is out of bounds, current height: {self.height}") - return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height) + return self._read(height) def _read(self, height, count=1): - self.io.seek(height * self.header_size, os.SEEK_SET) - return self.io.read(self.header_size * count) + offset = height * self.header_size + return bytes(self.io.getbuffer()[offset: offset + self.header_size * count]) def chunk_hash(self, start, count): - self.io.seek(start * self.header_size, os.SEEK_SET) - return self.hash_header(self.io.read(count * self.header_size)).decode() + return self.hash_header(self._read(start, count)).decode() async def ensure_checkpointed_size(self): max_checkpointed_height = max(self.checkpoints.keys() or [-1]) @@ -165,7 +169,7 @@ class Headers: async def ensure_chunk_at(self, height): async with self.check_chunk_lock: - if await self.has_header(height): + if self.has_header(height): log.debug("has header %s", height) return return await self.fetch_chunk(height) @@ -179,7 +183,7 @@ class Headers: ) chunk_hash = self.hash_header(chunk).decode() if self.checkpoints.get(start) == chunk_hash: - await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk) + self._write(start, chunk) if start in self.known_missing_checkpointed_chunks: self.known_missing_checkpointed_chunks.remove(start) return @@ -189,27 +193,23 @@ class Headers: f"Checkpoint mismatch at height {start}. Expected {self.checkpoints[start]}, but got {chunk_hash} instead." ) - async def has_header(self, height): + def has_header(self, height): normalized_height = (height // 1000) * 1000 if normalized_height in self.checkpoints: return normalized_height not in self.known_missing_checkpointed_chunks - def _has_header(height): - empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' - all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' - return self.chunk_hash(height, 1) not in (empty, all_zeroes) - return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height) + empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' + all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' + return self.chunk_hash(height, 1) not in (empty, all_zeroes) async def get_all_missing_headers(self): # Heavy operation done in one optimized shot - def _io_checkall(): - for chunk_height, expected_hash in reversed(list(self.checkpoints.items())): - if chunk_height in self.known_missing_checkpointed_chunks: - continue - if self.chunk_hash(chunk_height, 1000) != expected_hash: - self.known_missing_checkpointed_chunks.add(chunk_height) - return self.known_missing_checkpointed_chunks - return await asyncio.get_running_loop().run_in_executor(self.executor, _io_checkall) + for chunk_height, expected_hash in reversed(list(self.checkpoints.items())): + if chunk_height in self.known_missing_checkpointed_chunks: + continue + if self.chunk_hash(chunk_height, 1000) != expected_hash: + self.known_missing_checkpointed_chunks.add(chunk_height) + return self.known_missing_checkpointed_chunks @property def height(self) -> int: @@ -241,7 +241,7 @@ class Headers: bail = True chunk = chunk[:(height-e.height)*self.header_size] if chunk: - added += await asyncio.get_running_loop().run_in_executor(self.executor, self._write, height, chunk) + added += self._write(height, chunk) if bail: break return added @@ -306,9 +306,7 @@ class Headers: previous_header_hash = fail = None batch_size = 36 for height in range(start_height, self.height, batch_size): - headers = await asyncio.get_running_loop().run_in_executor( - self.executor, self._read, height, batch_size - ) + headers = self._read(height, batch_size) if len(headers) % self.header_size != 0: headers = headers[:(len(headers) // self.header_size) * self.header_size] for header_hash, header in self._iterate_headers(height, headers): @@ -324,12 +322,11 @@ class Headers: assert start_height > 0 and height == start_height if fail: log.warning("Header file corrupted at height %s, truncating it.", height - 1) - def __truncate(at_height): - self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET) - self.io.truncate() - self.io.flush() - self._size = self.io.seek(0, os.SEEK_END) // self.header_size - return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height) + self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) + self.io.truncate() + self.io.flush() + self._size = self.io.seek(0, os.SEEK_END) // self.header_size + return previous_header_hash = header_hash @classmethod diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 7ceed23d7..b796a9c1a 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -211,6 +211,7 @@ class TestQueries(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + await self.ledger.headers.open() self.wallet = Wallet() await self.ledger.db.open() diff --git a/tests/unit/wallet/test_headers.py b/tests/unit/wallet/test_headers.py index e014f6e46..da433724d 100644 --- a/tests/unit/wallet/test_headers.py +++ b/tests/unit/wallet/test_headers.py @@ -21,8 +21,8 @@ class TestHeaders(AsyncioTestCase): async def test_deserialize(self): self.maxDiff = None h = Headers(':memory:') - h.io.write(HEADERS) await h.open() + await h.connect(0, HEADERS) self.assertEqual(await h.get(0), { 'bits': 520159231, 'block_height': 0, @@ -52,8 +52,11 @@ class TestHeaders(AsyncioTestCase): self.assertEqual(headers.height, 19) async def test_connect_from_middle(self): - h = Headers(':memory:') - h.io.write(HEADERS[:block_bytes(10)]) + headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS[:block_bytes(10)]) + h = Headers(headers_temporary_file) await h.open() self.assertEqual(h.height, 9) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) @@ -115,6 +118,7 @@ class TestHeaders(AsyncioTestCase): async def test_bounds(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS) self.assertEqual(19, headers.height) with self.assertRaises(IndexError): @@ -126,6 +130,7 @@ class TestHeaders(AsyncioTestCase): async def test_repair(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS[:block_bytes(11)]) self.assertEqual(10, headers.height) await headers.repair() @@ -147,24 +152,39 @@ class TestHeaders(AsyncioTestCase): await headers.repair(start_height=10) self.assertEqual(19, headers.height) - def test_do_not_estimate_unconfirmed(self): + async def test_do_not_estimate_unconfirmed(self): headers = Headers(':memory:') + await headers.open() self.assertIsNone(headers.estimated_timestamp(-1)) self.assertIsNone(headers.estimated_timestamp(0)) self.assertIsNotNone(headers.estimated_timestamp(1)) - async def test_misalignment_triggers_repair_on_open(self): + async def test_dont_estimate_whats_there(self): headers = Headers(':memory:') - headers.io.seek(0) - headers.io.write(HEADERS) + await headers.open() + estimated = headers.estimated_timestamp(10) + await headers.connect(0, HEADERS) + real_time = (await headers.get(10))['timestamp'] + after_downloading_header_estimated = headers.estimated_timestamp(10) + self.assertNotEqual(estimated, after_downloading_header_estimated) + self.assertEqual(after_downloading_header_estimated, real_time) + + async def test_misalignment_triggers_repair_on_open(self): + headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS) + headers = Headers(headers_temporary_file) with self.assertLogs(level='WARN') as cm: await headers.open() + await headers.close() self.assertEqual(cm.output, []) - headers.io.seek(0) - headers.io.truncate() - headers.io.write(HEADERS[:block_bytes(10)]) - headers.io.write(b'ops') - headers.io.write(HEADERS[block_bytes(10):]) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.seek(0) + headers_file.truncate() + headers_file.write(HEADERS[:block_bytes(10)]) + headers_file.write(b'ops') + headers_file.write(HEADERS[block_bytes(10):]) await headers.open() self.assertEqual( cm.output, [ @@ -192,6 +212,7 @@ class TestHeaders(AsyncioTestCase): reader_task = asyncio.create_task(reader()) await writer() await reader_task + await headers.close() HEADERS = unhexlify( diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index 0244de987..bfe5cc71b 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -48,6 +48,8 @@ class LedgerTestCase(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + self.ledger.headers.checkpoints = {} + await self.ledger.headers.open() self.account = Account.generate(self.ledger, Wallet(), "lbryum") await self.ledger.db.open()