create parent class for sql databases

This commit is contained in:
ThomasV 2019-03-06 09:56:22 +01:00
parent b861e2e955
commit d8e9a9a49e
3 changed files with 71 additions and 81 deletions

View file

@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
import binascii import binascii
import base64 import base64
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_ from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session from .sql_db import SqlDB, sql
from . import constants from . import constants
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
@ -212,50 +210,25 @@ class Address(Base):
last_connected_date = Column(DateTime(), nullable=False) last_connected_date = Column(DateTime(), nullable=False)
class ChannelDB(PrintError):
class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
self.network = network path = os.path.join(get_headers_dir(network.config), 'channel_db')
super().__init__(network, path, Base)
print(Base)
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
self.db_requests = queue.Queue() self.update_counts()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self): @sql
self.sql_thread = threading.currentThread() def update_counts(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
self._update_counts() self._update_counts()
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
def _update_counts(self): def _update_counts(self):
self.num_channels = self.DBSession.query(ChannelInfo).count() self.num_channels = self.DBSession.query(ChannelInfo).count()

View file

@ -11,9 +11,14 @@ from collections import defaultdict
import asyncio import asyncio
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import NamedTuple, Dict from typing import NamedTuple, Dict
import jsonrpclib import jsonrpclib
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from .sql_db import SqlDB, sql
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet from . import wallet
from .storage import WalletStorage from .storage import WalletStorage
@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
FREE = auto() FREE = auto()
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_
from sqlalchemy.orm import scoped_session
Base = declarative_base() Base = declarative_base()
class SweepTx(Base): class SweepTx(Base):
@ -60,42 +57,11 @@ class ChannelInfo(Base):
outpoint = Column(String(34)) outpoint = Column(String(34))
class SweepStore(PrintError):
class SweepStore(SqlDB):
def __init__(self, path, network): def __init__(self, path, network):
PrintError.__init__(self) super().__init__(network, path, Base)
self.path = path
self.network = network
self.db_requests = queue.Queue()
threading.Thread(target=self.sql_thread).start()
def sql_thread(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
Base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")
def sql(func):
def wrapper(self, *args, **kwargs):
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
@sql @sql
def get_sweep_tx(self, funding_outpoint, prev_txid): def get_sweep_tx(self, funding_outpoint, prev_txid):

51
electrum/sql_db.py Normal file
View file

@ -0,0 +1,51 @@
import os
import concurrent
import queue
import threading
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker
from .util import PrintError
def sql(func):
"""wrapper for sql methods"""
def wrapper(self, *args, **kwargs):
assert threading.currentThread() != self.sql_thread
f = concurrent.futures.Future()
self.db_requests.put((f, func, args, kwargs))
return f.result(timeout=10)
return wrapper
class SqlDB(PrintError):
def __init__(self, network, path, base):
self.base = base
self.network = network
self.path = path
self.db_requests = queue.Queue()
self.sql_thread = threading.Thread(target=self.run_sql)
self.sql_thread.start()
def run_sql(self):
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
DBSession = sessionmaker(bind=engine, autoflush=False)
self.DBSession = DBSession()
if not os.path.exists(self.path):
self.base.metadata.create_all(engine)
while self.network.asyncio_loop.is_running():
try:
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
except queue.Empty:
continue
try:
result = func(self, *args, **kwargs)
except BaseException as e:
future.set_exception(e)
continue
future.set_result(result)
# write
self.DBSession.commit()
self.print_error("SQL thread terminated")