mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-03 20:35:13 +00:00
create parent class for sql databases
This commit is contained in:
parent
b861e2e955
commit
d8e9a9a49e
3 changed files with 71 additions and 81 deletions
|
@ -35,13 +35,11 @@ from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECK
|
||||||
import binascii
|
import binascii
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
|
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.sql import not_, or_
|
from sqlalchemy.sql import not_, or_
|
||||||
from sqlalchemy.orm import scoped_session
|
from .sql_db import SqlDB, sql
|
||||||
|
|
||||||
from . import constants
|
from . import constants
|
||||||
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
|
from .util import PrintError, bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits
|
||||||
|
@ -212,50 +210,25 @@ class Address(Base):
|
||||||
last_connected_date = Column(DateTime(), nullable=False)
|
last_connected_date = Column(DateTime(), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class ChannelDB(PrintError):
|
|
||||||
|
|
||||||
|
class ChannelDB(SqlDB):
|
||||||
|
|
||||||
NUM_MAX_RECENT_PEERS = 20
|
NUM_MAX_RECENT_PEERS = 20
|
||||||
|
|
||||||
def __init__(self, network: 'Network'):
|
def __init__(self, network: 'Network'):
|
||||||
self.network = network
|
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||||
|
super().__init__(network, path, Base)
|
||||||
|
print(Base)
|
||||||
self.num_nodes = 0
|
self.num_nodes = 0
|
||||||
self.num_channels = 0
|
self.num_channels = 0
|
||||||
self.path = os.path.join(get_headers_dir(network.config), 'channel_db.sqlite3')
|
|
||||||
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict]
|
||||||
self.ca_verifier = LNChannelVerifier(network, self)
|
self.ca_verifier = LNChannelVerifier(network, self)
|
||||||
self.db_requests = queue.Queue()
|
self.update_counts()
|
||||||
threading.Thread(target=self.sql_thread).start()
|
|
||||||
|
|
||||||
def sql_thread(self):
|
@sql
|
||||||
self.sql_thread = threading.currentThread()
|
def update_counts(self):
|
||||||
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
|
|
||||||
DBSession = sessionmaker(bind=engine, autoflush=False)
|
|
||||||
self.DBSession = DBSession()
|
|
||||||
if not os.path.exists(self.path):
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
self._update_counts()
|
self._update_counts()
|
||||||
while self.network.asyncio_loop.is_running():
|
|
||||||
try:
|
|
||||||
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
result = func(self, *args, **kwargs)
|
|
||||||
except BaseException as e:
|
|
||||||
future.set_exception(e)
|
|
||||||
continue
|
|
||||||
future.set_result(result)
|
|
||||||
# write
|
|
||||||
self.DBSession.commit()
|
|
||||||
self.print_error("SQL thread terminated")
|
|
||||||
|
|
||||||
def sql(func):
|
|
||||||
def wrapper(self, *args, **kwargs):
|
|
||||||
assert threading.currentThread() != self.sql_thread
|
|
||||||
f = concurrent.futures.Future()
|
|
||||||
self.db_requests.put((f, func, args, kwargs))
|
|
||||||
return f.result(timeout=10)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
def _update_counts(self):
|
def _update_counts(self):
|
||||||
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
self.num_channels = self.DBSession.query(ChannelInfo).count()
|
||||||
|
|
|
@ -11,9 +11,14 @@ from collections import defaultdict
|
||||||
import asyncio
|
import asyncio
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import NamedTuple, Dict
|
from typing import NamedTuple, Dict
|
||||||
|
|
||||||
import jsonrpclib
|
import jsonrpclib
|
||||||
|
|
||||||
|
from sqlalchemy import Column, ForeignKey, Integer, String, DateTime, Boolean
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.sql import not_, or_
|
||||||
|
from .sql_db import SqlDB, sql
|
||||||
|
|
||||||
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
|
from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
|
||||||
from . import wallet
|
from . import wallet
|
||||||
from .storage import WalletStorage
|
from .storage import WalletStorage
|
||||||
|
@ -37,14 +42,6 @@ class TxMinedDepth(IntEnum):
|
||||||
FREE = auto()
|
FREE = auto()
|
||||||
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
from sqlalchemy.orm.query import Query
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.sql import not_, or_
|
|
||||||
from sqlalchemy.orm import scoped_session
|
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
class SweepTx(Base):
|
class SweepTx(Base):
|
||||||
|
@ -60,42 +57,11 @@ class ChannelInfo(Base):
|
||||||
outpoint = Column(String(34))
|
outpoint = Column(String(34))
|
||||||
|
|
||||||
|
|
||||||
class SweepStore(PrintError):
|
|
||||||
|
class SweepStore(SqlDB):
|
||||||
|
|
||||||
def __init__(self, path, network):
|
def __init__(self, path, network):
|
||||||
PrintError.__init__(self)
|
super().__init__(network, path, Base)
|
||||||
self.path = path
|
|
||||||
self.network = network
|
|
||||||
self.db_requests = queue.Queue()
|
|
||||||
threading.Thread(target=self.sql_thread).start()
|
|
||||||
|
|
||||||
def sql_thread(self):
|
|
||||||
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
|
|
||||||
DBSession = sessionmaker(bind=engine, autoflush=False)
|
|
||||||
self.DBSession = DBSession()
|
|
||||||
if not os.path.exists(self.path):
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
while self.network.asyncio_loop.is_running():
|
|
||||||
try:
|
|
||||||
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
|
||||||
except queue.Empty:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
result = func(self, *args, **kwargs)
|
|
||||||
except BaseException as e:
|
|
||||||
future.set_exception(e)
|
|
||||||
continue
|
|
||||||
future.set_result(result)
|
|
||||||
# write
|
|
||||||
self.DBSession.commit()
|
|
||||||
self.print_error("SQL thread terminated")
|
|
||||||
|
|
||||||
def sql(func):
|
|
||||||
def wrapper(self, *args, **kwargs):
|
|
||||||
f = concurrent.futures.Future()
|
|
||||||
self.db_requests.put((f, func, args, kwargs))
|
|
||||||
return f.result(timeout=10)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
@sql
|
@sql
|
||||||
def get_sweep_tx(self, funding_outpoint, prev_txid):
|
def get_sweep_tx(self, funding_outpoint, prev_txid):
|
||||||
|
|
51
electrum/sql_db.py
Normal file
51
electrum/sql_db.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
import os
|
||||||
|
import concurrent
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from .util import PrintError
|
||||||
|
|
||||||
|
|
||||||
|
def sql(func):
|
||||||
|
"""wrapper for sql methods"""
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
assert threading.currentThread() != self.sql_thread
|
||||||
|
f = concurrent.futures.Future()
|
||||||
|
self.db_requests.put((f, func, args, kwargs))
|
||||||
|
return f.result(timeout=10)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
class SqlDB(PrintError):
|
||||||
|
|
||||||
|
def __init__(self, network, path, base):
|
||||||
|
self.base = base
|
||||||
|
self.network = network
|
||||||
|
self.path = path
|
||||||
|
self.db_requests = queue.Queue()
|
||||||
|
self.sql_thread = threading.Thread(target=self.run_sql)
|
||||||
|
self.sql_thread.start()
|
||||||
|
|
||||||
|
def run_sql(self):
|
||||||
|
engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
|
||||||
|
DBSession = sessionmaker(bind=engine, autoflush=False)
|
||||||
|
self.DBSession = DBSession()
|
||||||
|
if not os.path.exists(self.path):
|
||||||
|
self.base.metadata.create_all(engine)
|
||||||
|
while self.network.asyncio_loop.is_running():
|
||||||
|
try:
|
||||||
|
future, func, args, kwargs = self.db_requests.get(timeout=0.1)
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
result = func(self, *args, **kwargs)
|
||||||
|
except BaseException as e:
|
||||||
|
future.set_exception(e)
|
||||||
|
continue
|
||||||
|
future.set_result(result)
|
||||||
|
# write
|
||||||
|
self.DBSession.commit()
|
||||||
|
self.print_error("SQL thread terminated")
|
Loading…
Add table
Reference in a new issue