diff --git a/lbry/lbry/extras/daemon/storage.py b/lbry/lbry/extras/daemon/storage.py
index 820663ed9..2b0bf96d7 100644
--- a/lbry/lbry/extras/daemon/storage.py
+++ b/lbry/lbry/extras/daemon/storage.py
@@ -101,6 +101,9 @@ def _batched_select(transaction, query, parameters, batch_size=900):
 def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]:
     files = []
     signed_claims = {}
+    stream_hashes = tuple(
+        stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file").fetchall()
+    )
     for (rowid, stream_hash, file_name, download_dir, data_rate, status, saved_file, raw_content_fee, _,
          sd_hash, stream_key, stream_name, suggested_file_name, *claim_args) in _batched_select(
             transaction, "select file.rowid, file.*, stream.*, c.* "
@@ -108,9 +111,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
                          "inner join content_claim cc on file.stream_hash=cc.stream_hash "
                          "inner join claim c on cc.claim_outpoint=c.claim_outpoint "
                          "where file.stream_hash in {} "
-                         "order by c.rowid desc", [
-                stream_hash for (stream_hash,) in transaction.execute("select stream_hash from file")]):
-
+                         "order by c.rowid desc", stream_hashes):
         claim = StoredStreamClaim(stream_hash, *claim_args)
         if claim.channel_claim_id:
             if claim.channel_claim_id not in signed_claims:
@@ -137,7 +138,7 @@ def get_all_lbry_files(transaction: sqlite3.Connection) -> typing.List[typing.Di
         )
     for claim_name, claim_id in _batched_select(
             transaction, "select c.claim_name, c.claim_id from claim c where c.claim_id in {}",
-            list(signed_claims.keys())):
+            tuple(signed_claims.keys())):
         for claim in signed_claims[claim_id]:
             claim.channel_name = claim_name
     return files
@@ -147,35 +148,35 @@ def store_stream(transaction: sqlite3.Connection, sd_blob: 'BlobFile', descripto
     # add all blobs, except the last one, which is empty
     transaction.executemany(
         "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
-        [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
-         for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob]]
-    )
+        ((blob.blob_hash, blob.length, 0, 0, "pending", 0, 0)
+         for blob in (descriptor.blobs[:-1] if len(descriptor.blobs) > 1 else descriptor.blobs) + [sd_blob])
+    ).fetchall()
     # associate the blobs to the stream
     transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)",
                         (descriptor.stream_hash, sd_blob.blob_hash, descriptor.key,
                          binascii.hexlify(descriptor.stream_name.encode()).decode(),
-                         binascii.hexlify(descriptor.suggested_file_name.encode()).decode()))
+                         binascii.hexlify(descriptor.suggested_file_name.encode()).decode())).fetchall()
     # add the stream
     transaction.executemany(
         "insert or ignore into stream_blob values (?, ?, ?, ?)",
-        [(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
-         for blob in descriptor.blobs]
-    )
+        ((descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv)
+         for blob in descriptor.blobs)
+    ).fetchall()
     # ensure should_announce is set regardless if insert was ignored
     transaction.execute(
         "update blob set should_announce=1 where blob_hash in (?, ?)",
         (sd_blob.blob_hash, descriptor.blobs[0].blob_hash,)
-    )
+    ).fetchall()


 def delete_stream(transaction: sqlite3.Connection, descriptor: 'StreamDescriptor'):
     blob_hashes = [(blob.blob_hash, ) for blob in descriptor.blobs[:-1]]
     blob_hashes.append((descriptor.sd_hash, ))
-    transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,))
-    transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,))
-    transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,))
-    transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,))
-    transaction.executemany("delete from blob where blob_hash=?", blob_hashes)
+    transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
+    transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
+    transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash,)).fetchall()
+    transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash,)).fetchall()
+    transaction.executemany("delete from blob where blob_hash=?", blob_hashes).fetchall()


 def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typing.Optional[str],
@@ -191,7 +192,7 @@ def store_file(transaction: sqlite3.Connection, stream_hash: str, file_name: typ
         (stream_hash, encoded_file_name, encoded_download_dir, data_payment_rate, status,
          1 if (file_name and download_directory and os.path.isfile(os.path.join(download_directory, file_name))) else 0,
          None if not content_fee else binascii.hexlify(content_fee.raw).decode())
-    )
+    ).fetchall()

     return transaction.execute("select rowid from file where stream_hash=?", (stream_hash, )).fetchone()[0]
@@ -293,17 +294,17 @@ class SQLiteStorage(SQLiteMixin):
         def _add_blobs(transaction: sqlite3.Connection):
             transaction.executemany(
                 "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)",
-                [
+                (
                     (blob_hash, length, 0, 0, "pending" if not finished else "finished", 0, 0)
                     for blob_hash, length in blob_hashes_and_lengths
-                ]
-            )
+                )
+            ).fetchall()
             if finished:
                 transaction.executemany(
-                    "update blob set status='finished' where blob.blob_hash=?", [
+                    "update blob set status='finished' where blob.blob_hash=?", (
                         (blob_hash, ) for blob_hash, _ in blob_hashes_and_lengths
-                    ]
-                )
+                    )
+                ).fetchall()
         return await self.db.run(_add_blobs)

     def get_blob_status(self, blob_hash: str):
@@ -317,9 +318,9 @@ class SQLiteStorage(SQLiteMixin):
             return transaction.executemany(
                 "update blob set next_announce_time=?, last_announced_time=?, single_announce=0 "
                 "where blob_hash=?",
-                [(int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
-                 for blob_hash in blob_hashes]
-            )
+                ((int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash)
+                 for blob_hash in blob_hashes)
+            ).fetchall()
         return self.db.run(_update_last_announced_blobs)

     def should_single_announce_blobs(self, blob_hashes, immediate=False):
@@ -330,11 +331,11 @@ class SQLiteStorage(SQLiteMixin):
                     transaction.execute(
                         "update blob set single_announce=1, next_announce_time=? "
                         "where blob_hash=? and status='finished'", (int(now), blob_hash)
-                    )
+                    ).fetchall()
                 else:
                     transaction.execute(
                         "update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash,)
-                    )
+                    ).fetchall()
         return self.db.run(set_single_announce)

     def get_blobs_to_announce(self):
@@ -347,22 +348,22 @@
                 "(should_announce=1 or single_announce=1) and next_announce_time<?

     def sync_missing_blobs(self, blob_files: typing.Set[str]) -> typing.Awaitable[typing.Set[str]]:
         def _sync_blobs(transaction: sqlite3.Connection) -> typing.Set[str]:
-            to_update = [
-                (blob_hash, )
-                for (blob_hash, ) in transaction.execute("select blob_hash from blob where status='finished'")
-                if blob_hash not in blob_files
-            ]
+            finished_blob_hashes = tuple(
+                blob_hash for (blob_hash, ) in transaction.execute(
+                    "select blob_hash from blob where status='finished'"
+                ).fetchall()
+            )
+            finished_blobs_set = set(finished_blob_hashes)
+            to_update_set = finished_blobs_set.difference(blob_files)
             transaction.executemany(
                 "update blob set status='pending' where blob_hash=?",
-                to_update
-            )
-            return {
-                blob_hash
-                for blob_hash, in _batched_select(
-                    transaction, "select blob_hash from blob where status='finished' and blob_hash in {}",
-                    list(blob_files)
-                )
-            }
+                ((blob_hash, ) for blob_hash in to_update_set)
+            ).fetchall()
+            return blob_files.intersection(finished_blobs_set)
         return self.db.run(_sync_blobs)

     # # # # # # # # # stream functions # # # # # # # # #
@@ -484,7 +481,7 @@ class SQLiteStorage(SQLiteMixin):
             transaction.executemany(
                 "update file set file_name=null, download_directory=null, saved_file=0 where stream_hash=?",
                 removed
-            )
+            ).fetchall()
         return await self.db.run(update_manually_removed_files)

     def get_all_lbry_files(self) -> typing.Awaitable[typing.List[typing.Dict]]:
@@ -492,7 +489,7 @@ class SQLiteStorage(SQLiteMixin):

     def change_file_status(self, stream_hash: str, new_status: str):
         log.debug("update file status %s -> %s", stream_hash, new_status)
-        return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash))
+        return self.db.execute_fetchall("update file set status=? where stream_hash=?", (new_status, stream_hash))

     async def change_file_download_dir_and_file_name(self, stream_hash: str, download_dir: typing.Optional[str],
                                                      file_name: typing.Optional[str]):
@@ -501,22 +498,22 @@ class SQLiteStorage(SQLiteMixin):
         else:
             encoded_file_name = binascii.hexlify(file_name.encode()).decode()
             encoded_download_dir = binascii.hexlify(download_dir.encode()).decode()
-        return await self.db.execute("update file set download_directory=?, file_name=? where stream_hash=?", (
+        return await self.db.execute_fetchall("update file set download_directory=?, file_name=? where stream_hash=?", (
             encoded_download_dir, encoded_file_name, stream_hash,
         ))

     async def save_content_fee(self, stream_hash: str, content_fee: Transaction):
-        return await self.db.execute("update file set content_fee=? where stream_hash=?", (
+        return await self.db.execute_fetchall("update file set content_fee=? where stream_hash=?", (
             binascii.hexlify(content_fee.raw), stream_hash,
         ))

     async def set_saved_file(self, stream_hash: str):
-        return await self.db.execute("update file set saved_file=1 where stream_hash=?", (
+        return await self.db.execute_fetchall("update file set saved_file=1 where stream_hash=?", (
             stream_hash,
         ))

     async def clear_saved_file(self, stream_hash: str):
-        return await self.db.execute("update file set saved_file=0 where stream_hash=?", (
+        return await self.db.execute_fetchall("update file set saved_file=0 where stream_hash=?", (
             stream_hash,
         ))
@@ -537,13 +534,13 @@ class SQLiteStorage(SQLiteMixin):
             transaction.execute("insert or ignore into content_claim values (?, ?)", content_claim)
             transaction.executemany(
                 "update file set status='stopped' where stream_hash=?",
-                [(stream_hash, ) for stream_hash in stream_hashes]
-            )
+                ((stream_hash, ) for stream_hash in stream_hashes)
+            ).fetchall()
             download_dir = binascii.hexlify(self.conf.download_dir.encode()).decode()
             transaction.executemany(
                 f"update file set download_directory=? where stream_hash=?",
-                [(download_dir, stream_hash) for stream_hash in stream_hashes]
-            )
+                ((download_dir, stream_hash) for stream_hash in stream_hashes)
+            ).fetchall()
         await self.db.run_with_foreign_keys_disabled(_recover)

     def get_all_stream_hashes(self):
@@ -555,14 +552,16 @@ class SQLiteStorage(SQLiteMixin):
         # TODO: add 'address' to support items returned for a claim from lbrycrdd and lbryum-server
        def _save_support(transaction):
             bind = "({})".format(','.join(['?'] * len(claim_id_to_supports)))
-            transaction.execute(f"delete from support where claim_id in {bind}", list(claim_id_to_supports.keys()))
+            transaction.execute(
+                f"delete from support where claim_id in {bind}", tuple(claim_id_to_supports.keys())
+            ).fetchall()
             for claim_id, supports in claim_id_to_supports.items():
                 for support in supports:
                     transaction.execute(
                         "insert into support values (?, ?, ?, ?)",
                         ("%s:%i" % (support['txid'], support['nout']), claim_id,
                          lbc_to_dewies(support['amount']), support.get('address', ""))
-                    )
+                    ).fetchall()
         return self.db.run(_save_support)

     def get_supports(self, *claim_ids):
@@ -581,7 +580,7 @@ class SQLiteStorage(SQLiteMixin):
                 for support_info in _batched_select(
                     transaction,
                     "select * from support where claim_id in {}",
-                    tuple(claim_ids)
+                    claim_ids
                 )
             ]
@@ -612,7 +611,7 @@ class SQLiteStorage(SQLiteMixin):
             transaction.execute(
                 "insert or replace into claim values (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                 (outpoint, claim_id, name, amount, height, serialized, certificate_id, address, sequence)
-            )
+            ).fetchall()
             # if this response doesn't have support info don't overwrite the existing
             # support info
             if 'supports' in claim_info:
@@ -699,7 +698,9 @@ class SQLiteStorage(SQLiteMixin):
         )

         # update the claim associated to the file
-        transaction.execute("insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint))
+        transaction.execute(
+            "insert or replace into content_claim values (?, ?)", (stream_hash, claim_outpoint)
+        ).fetchall()

     async def save_content_claim(self, stream_hash, claim_outpoint):
         await self.db.run(self._save_content_claim, claim_outpoint, stream_hash)
@@ -722,11 +723,11 @@ class SQLiteStorage(SQLiteMixin):

     def update_reflected_stream(self, sd_hash, reflector_address, success=True):
         if success:
-            return self.db.execute(
+            return self.db.execute_fetchall(
                 "insert or replace into reflected_stream values (?, ?, ?)",
                 (sd_hash, reflector_address, self.time_getter())
             )
-        return self.db.execute(
+        return self.db.execute_fetchall(
             "delete from reflected_stream where sd_hash=? and reflector_address=?",
             (sd_hash, reflector_address)
         )
diff --git a/lbry/lbry/wallet/database.py b/lbry/lbry/wallet/database.py
index 8828ee973..4c7d70008 100644
--- a/lbry/lbry/wallet/database.py
+++ b/lbry/lbry/wallet/database.py
@@ -134,11 +134,11 @@ class WalletDatabase(BaseDatabase):
             return self.get_utxo_count(**constraints)

     async def release_all_outputs(self, account):
-        await self.db.execute(
+        await self.db.execute_fetchall(
             "UPDATE txo SET is_reserved = 0 WHERE"
             "  is_reserved = 1 AND txo.address IN ("
             "    SELECT address from pubkey_address WHERE account = ?"
-            "  )", [account.public_key.address]
+            "  )", (account.public_key.address, )
         )

     def get_supports_summary(self, account_id):
diff --git a/lbry/tests/integration/test_account_commands.py b/lbry/tests/integration/test_account_commands.py
index d0be31342..fa1270842 100644
--- a/lbry/tests/integration/test_account_commands.py
+++ b/lbry/tests/integration/test_account_commands.py
@@ -11,14 +11,6 @@ def extract(d, keys):


 class AccountManagement(CommandTestCase):
-    async def test_sqlite_binding_error(self):
-        tasks = [
-            self.loop.create_task(self.daemon.jsonrpc_account_create('second account' + str(x))) for x in range(100)
-        ]
-        await asyncio.wait(tasks)
-        for result in tasks:
-            self.assertFalse(isinstance(result.result(), Exception))
-
     async def test_account_list_set_create_remove_add(self):
         # check initial account
         response = await self.daemon.jsonrpc_account_list()
diff --git a/torba/tests/client_tests/unit/test_database.py b/torba/tests/client_tests/unit/test_database.py
index b2f88231b..c862f573d 100644
--- a/torba/tests/client_tests/unit/test_database.py
+++ b/torba/tests/client_tests/unit/test_database.py
@@ -1,7 +1,10 @@
+import sys
+import os
 import unittest
 import sqlite3
 import tempfile
-import os
+import asyncio
+from concurrent.futures.thread import ThreadPoolExecutor

 from torba.client.wallet import Wallet
 from torba.client.constants import COIN
@@ -431,3 +434,98 @@
         self.assertEqual(self.get_tables(), ['foo', 'pubkey_address', 'tx', 'txi', 'txo', 'version'])
         self.assertEqual(self.get_addresses(), [])  # all tables got reset
         await self.ledger.db.close()
+
+
+class TestSQLiteRace(AsyncioTestCase):
+    max_misuse_attempts = 40000
+
+    def setup_db(self):
+        self.db = sqlite3.connect(":memory:", isolation_level=None)
+        self.db.executescript(
+            "create table test1 (id text primary key not null, val text);\n" +
+            "create table test2 (id text primary key not null, val text);\n" +
+            "\n".join(f"insert into test1 values ({v}, NULL);" for v in range(1000))
+        )
+
+    async def asyncSetUp(self):
+        self.executor = ThreadPoolExecutor(1)
+        await self.loop.run_in_executor(self.executor, self.setup_db)
+
+    async def asyncTearDown(self):
+        await self.loop.run_in_executor(self.executor, self.db.close)
+        self.executor.shutdown()
+
+    async def test_binding_param_0_error(self):
+        # test real param 0 binding errors
+
+        for supported_type in [str, int, bytes]:
+            await self.loop.run_in_executor(
+                self.executor, self.db.executemany, "insert into test2 values (?, NULL)",
+                [(supported_type(1), ), (supported_type(2), )]
+            )
+            await self.loop.run_in_executor(
+                self.executor, self.db.execute, "delete from test2 where id in (1, 2)"
+            )
+        for unsupported_type in [lambda x: (x, ), lambda x: [x], lambda x: {x}]:
+            try:
+                await self.loop.run_in_executor(
+                    self.executor, self.db.executemany, "insert into test2 (id, val) values (?, NULL)",
+                    [(unsupported_type(1), ), (unsupported_type(2), )]
+                )
+                self.assertTrue(False)
+            except sqlite3.InterfaceError as err:
+                self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
+
+    async def test_unhandled_sqlite_misuse(self):
+        # test SQLITE_MISUSE being incorrectly raised as a param 0 binding error
+        attempts = 0
+        python_version = sys.version.split('\n')[0].rstrip(' ')
+
+        try:
+            while attempts < self.max_misuse_attempts:
+                f1 = asyncio.wrap_future(
+                    self.loop.run_in_executor(
+                        self.executor, self.db.executemany, "update test1 set val='derp' where id=?",
+                        ((str(i),) for i in range(2))
+                    )
+                )
+                f2 = asyncio.wrap_future(
+                    self.loop.run_in_executor(
+                        self.executor, self.db.executemany, "update test2 set val='derp' where id=?",
+                        ((str(i),) for i in range(2))
+                    )
+                )
+                attempts += 1
+                await asyncio.gather(f1, f2)
+            print(f"\nsqlite3 {sqlite3.version}/python {python_version} "
+                  f"did not raise SQLITE_MISUSE within {attempts} attempts of the race condition")
+            self.assertTrue(False, 'this test failing means either the sqlite race conditions '
+                                   'have been fixed in cpython or the test max_attempts needs to be increased')
+        except sqlite3.InterfaceError as err:
+            self.assertEqual(str(err), "Error binding parameter 0 - probably unsupported type.")
+            print(f"\nsqlite3 {sqlite3.version}/python {python_version} raised SQLITE_MISUSE "
+                  f"after {attempts} attempts of the race condition")
+
+    @unittest.SkipTest
+    async def test_fetchall_prevents_sqlite_misuse(self):
+        # test that calling fetchall sufficiently avoids the race
+        attempts = 0
+
+        def executemany_fetchall(query, params):
+            self.db.executemany(query, params).fetchall()
+
+        while attempts < self.max_misuse_attempts:
+            f1 = asyncio.wrap_future(
+                self.loop.run_in_executor(
+                    self.executor, executemany_fetchall, "update test1 set val='derp' where id=?",
+                    ((str(i),) for i in range(2))
+                )
+            )
+            f2 = asyncio.wrap_future(
+                self.loop.run_in_executor(
+                    self.executor, executemany_fetchall, "update test2 set val='derp' where id=?",
+                    ((str(i),) for i in range(2))
+                )
+            )
+            attempts += 1
+            await asyncio.gather(f1, f2)
\ No newline at end of file
diff --git a/torba/torba/client/basedatabase.py b/torba/torba/client/basedatabase.py
index 5fd9c9b47..b92965fa1 100644
--- a/torba/torba/client/basedatabase.py
+++ b/torba/torba/client/basedatabase.py
@@ -3,7 +3,7 @@
 import asyncio
 from binascii import hexlify
 from concurrent.futures.thread import ThreadPoolExecutor
-from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Optional
+from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional

 import sqlite3
@@ -19,6 +19,7 @@ class AIOSQLite:
         # has to be single threaded as there is no mapping of thread:connection
         self.executor = ThreadPoolExecutor(max_workers=1)
         self.connection: sqlite3.Connection = None
+        self._closing = False

     @classmethod
     async def connect(cls, path: Union[bytes, str], *args, **kwargs):
@@ -29,14 +30,12 @@ class AIOSQLite:
         return db

     async def close(self):
-        def __close(conn):
-            self.executor.submit(conn.close)
-            self.executor.shutdown(wait=True)
-        conn = self.connection
-        if not conn:
+        if self._closing:
             return
+        self._closing = True
+        await asyncio.get_event_loop().run_in_executor(self.executor, self.connection.close)
+        self.executor.shutdown(wait=True)
         self.connection = None
-        return asyncio.get_event_loop_policy().get_event_loop().call_later(0.01, __close, conn)

     def executemany(self, sql: str, params: Iterable):
         params = params if params is not None else []
@@ -87,10 +86,10 @@
         if not foreign_keys_enabled:
             raise sqlite3.IntegrityError("foreign keys are disabled, use `AIOSQLite.run` instead")
         try:
-            self.connection.execute('pragma foreign_keys=off')
+            self.connection.execute('pragma foreign_keys=off').fetchone()
             return self.__run_transaction(fun, *args, **kwargs)
         finally:
-            self.connection.execute('pragma foreign_keys=on')
+            self.connection.execute('pragma foreign_keys=on').fetchone()


 def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
@@ -160,7 +159,7 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
     return joiner.join(sql) if sql else '', values


-def query(select, **constraints):
+def query(select, **constraints) -> Tuple[str, Dict[str, Any]]:
     sql = [select]
     limit = constraints.pop('limit', None)
     offset = constraints.pop('offset', None)
@@ -377,10 +376,10 @@ class BaseDatabase(SQLiteMixin):
         }

     async def insert_transaction(self, tx):
-        await self.db.execute(*self._insert_sql('tx', self.tx_to_row(tx)))
+        await self.db.execute_fetchall(*self._insert_sql('tx', self.tx_to_row(tx)))

     async def update_transaction(self, tx):
-        await self.db.execute(*self._update_sql("tx", {
+        await self.db.execute_fetchall(*self._update_sql("tx", {
             'height': tx.height, 'position': tx.position, 'is_verified': tx.is_verified
         }, 'txid = ?', (tx.id,)))
@@ -391,7 +390,7 @@ class BaseDatabase(SQLiteMixin):
                 if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == txhash:
                     conn.execute(*self._insert_sql(
                         "txo", self.txo_to_row(tx, address, txo), ignore_duplicate=True
-                    ))
+                    )).fetchall()
                 elif txo.script.is_pay_script_hash:
                     # TODO: implement script hash payments
                     log.warning('Database.save_transaction_io: pay script hash is not implemented!')
@@ -404,7 +403,7 @@ class BaseDatabase(SQLiteMixin):
                         'txid': tx.id,
                         'txoid': txo.id,
                         'address': address,
-                    }, ignore_duplicate=True))
+                    }, ignore_duplicate=True)).fetchall()

             conn.execute(
                 "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
@@ -619,7 +618,7 @@ class BaseDatabase(SQLiteMixin):
         )

     async def _set_address_history(self, address, history):
-        await self.db.execute(
+        await self.db.execute_fetchall(
             "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?",
             (history, history.count(':')//2, address)
         )
diff --git a/torba/tox.ini b/torba/tox.ini
index 3e2c2e424..8ec8dfcdf 100644
--- a/torba/tox.ini
+++ b/torba/tox.ini
@@ -13,6 +13,6 @@ changedir = {toxinidir}/tests
 setenv =
     integration: TORBA_LEDGER={envname}
 commands =
-    unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.unit
+    unit: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.unit
     integration: orchstr8 download
-    integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -t . client_tests.integration
+    integration: coverage run -p --source={envsitepackagesdir}/torba -m unittest discover -vv -t . client_tests.integration
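
The patch applies one pattern throughout: every cursor returned by execute/executemany
is drained with fetchall() (or fetchone() for pragmas) inside the executor thread, so
each statement runs to completion before the shared connection is handed to the next
caller. An un-drained cursor can leave its statement pending, and a second statement
issued on the same connection may then hit SQLITE_MISUSE, which the sqlite3 module
misreports as "Error binding parameter 0 - probably unsupported type." (the race that
TestSQLiteRace above reproduces). The body of the new AIOSQLite.execute_fetchall is not
shown in this diff; a minimal sketch of the idea as a standalone helper (an assumed
implementation, not the patch's actual code):

    import sqlite3
    from typing import Iterable, List, Tuple

    def execute_fetchall(conn: sqlite3.Connection, sql: str,
                         parameters: Iterable = ()) -> List[Tuple]:
        # fetchall() steps the statement to completion and resets it before
        # returning, so nothing half-executed is left on the shared connection
        # for a concurrent statement to trip over.
        return conn.execute(sql, parameters).fetchall()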