diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index bed88e3a6..571a27823 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -60,6 +60,9 @@ class AIOSQLite: self.writer_connection: Optional[sqlite3.Connection] = None self._closing = False self.query_count = 0 + self.write_lock = asyncio.Lock() + self.writers = 0 + self.read_ready = asyncio.Event() @classmethod async def connect(cls, path: Union[bytes, str], *args, **kwargs): @@ -74,6 +77,7 @@ class AIOSQLite: max_workers=readers, initializer=initializer, initargs=(path, ) ) await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) + db.read_ready.set() return db async def close(self): @@ -83,6 +87,7 @@ class AIOSQLite: await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) self.writer_executor.shutdown(wait=True) self.reader_executor.shutdown(wait=True) + self.read_ready.clear() self.writer_connection = None def executemany(self, sql: str, params: Iterable): @@ -93,32 +98,44 @@ class AIOSQLite: def executescript(self, script: str) -> Awaitable: return self.run(lambda conn: conn.executescript(script)) + async def _execute_fetch(self, sql: str, parameters: Iterable = None, + read_only: bool = False, fetch_all: bool = False) -> Iterable[sqlite3.Row]: + read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone + parameters = parameters if parameters is not None else [] + if read_only: + while self.writers: + await self.read_ready.wait() + return await asyncio.get_event_loop().run_in_executor( + self.reader_executor, read_only_fn, sql, parameters + ) + if fetch_all: + return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + return await self.run(lambda conn: conn.execute(sql, parameters).fetchone()) + async def execute_fetchall(self, sql: str, parameters: Iterable = None, read_only: bool = False) -> Iterable[sqlite3.Row]: - parameters = parameters if parameters is not None else [] - if read_only: - return await asyncio.get_event_loop().run_in_executor( - self.reader_executor, run_read_only_fetchall, sql, parameters - ) - return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) + return await self._execute_fetch(sql, parameters, read_only, fetch_all=True) - def execute_fetchone(self, sql: str, parameters: Iterable = None, - read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]: - parameters = parameters if parameters is not None else [] - if read_only: - return asyncio.get_event_loop().run_in_executor( - self.reader_executor, run_read_only_fetchone, sql, parameters - ) - return self.run(lambda conn: conn.execute(sql, parameters).fetchone()) + async def execute_fetchone(self, sql: str, parameters: Iterable = None, + read_only: bool = False) -> Iterable[sqlite3.Row]: + return await self._execute_fetch(sql, parameters, read_only, fetch_all=False) def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: parameters = parameters if parameters is not None else [] return self.run(lambda conn: conn.execute(sql, parameters)) - def run(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) - ) + async def run(self, fun, *args, **kwargs): + self.writers += 1 + self.read_ready.clear() + async with self.write_lock: + try: + return await asyncio.get_event_loop().run_in_executor( + self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) + ) + finally: + self.writers -= 1 + if not self.writers: + self.read_ready.set() def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): self.writer_connection.execute('begin')