write lock

This commit is contained in:
Jack Robison 2020-02-25 14:15:27 -05:00
parent a26cfc639c
commit 61603ccfce
No known key found for this signature in database
GPG key ID: DF25C68FE0239BB2

View file

@ -60,6 +60,9 @@ class AIOSQLite:
self.writer_connection: Optional[sqlite3.Connection] = None self.writer_connection: Optional[sqlite3.Connection] = None
self._closing = False self._closing = False
self.query_count = 0 self.query_count = 0
self.write_lock = asyncio.Lock()
self.writers = 0
self.read_ready = asyncio.Event()
@classmethod @classmethod
async def connect(cls, path: Union[bytes, str], *args, **kwargs): async def connect(cls, path: Union[bytes, str], *args, **kwargs):
@ -74,6 +77,7 @@ class AIOSQLite:
max_workers=readers, initializer=initializer, initargs=(path, ) max_workers=readers, initializer=initializer, initargs=(path, )
) )
await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer)
db.read_ready.set()
return db return db
async def close(self): async def close(self):
@ -83,6 +87,7 @@ class AIOSQLite:
await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close) await asyncio.get_event_loop().run_in_executor(self.writer_executor, self.writer_connection.close)
self.writer_executor.shutdown(wait=True) self.writer_executor.shutdown(wait=True)
self.reader_executor.shutdown(wait=True) self.reader_executor.shutdown(wait=True)
self.read_ready.clear()
self.writer_connection = None self.writer_connection = None
def executemany(self, sql: str, params: Iterable): def executemany(self, sql: str, params: Iterable):
@ -93,32 +98,44 @@ class AIOSQLite:
def executescript(self, script: str) -> Awaitable: def executescript(self, script: str) -> Awaitable:
return self.run(lambda conn: conn.executescript(script)) return self.run(lambda conn: conn.executescript(script))
async def _execute_fetch(self, sql: str, parameters: Iterable = None,
read_only: bool = False, fetch_all: bool = False) -> Iterable[sqlite3.Row]:
read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone
parameters = parameters if parameters is not None else []
if read_only:
while self.writers:
await self.read_ready.wait()
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, read_only_fn, sql, parameters
)
if fetch_all:
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
return await self.run(lambda conn: conn.execute(sql, parameters).fetchone())
async def execute_fetchall(self, sql: str, parameters: Iterable = None, async def execute_fetchall(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Iterable[sqlite3.Row]: read_only: bool = False) -> Iterable[sqlite3.Row]:
parameters = parameters if parameters is not None else [] return await self._execute_fetch(sql, parameters, read_only, fetch_all=True)
if read_only:
return await asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchall, sql, parameters
)
return await self.run(lambda conn: conn.execute(sql, parameters).fetchall())
def execute_fetchone(self, sql: str, parameters: Iterable = None, async def execute_fetchone(self, sql: str, parameters: Iterable = None,
read_only: bool = False) -> Awaitable[Iterable[sqlite3.Row]]: read_only: bool = False) -> Iterable[sqlite3.Row]:
parameters = parameters if parameters is not None else [] return await self._execute_fetch(sql, parameters, read_only, fetch_all=False)
if read_only:
return asyncio.get_event_loop().run_in_executor(
self.reader_executor, run_read_only_fetchone, sql, parameters
)
return self.run(lambda conn: conn.execute(sql, parameters).fetchone())
def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]: def execute(self, sql: str, parameters: Iterable = None) -> Awaitable[sqlite3.Cursor]:
parameters = parameters if parameters is not None else [] parameters = parameters if parameters is not None else []
return self.run(lambda conn: conn.execute(sql, parameters)) return self.run(lambda conn: conn.execute(sql, parameters))
def run(self, fun, *args, **kwargs) -> Awaitable: async def run(self, fun, *args, **kwargs):
return asyncio.get_event_loop().run_in_executor( self.writers += 1
self.read_ready.clear()
async with self.write_lock:
try:
return await asyncio.get_event_loop().run_in_executor(
self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs)
) )
finally:
self.writers -= 1
if not self.writers:
self.read_ready.set()
def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs):
self.writer_connection.execute('begin') self.writer_connection.execute('begin')