diff --git a/lbrynet/extras/daemon/storage.py b/lbrynet/extras/daemon/storage.py index b6f98f944..d4ae93c08 100644 --- a/lbrynet/extras/daemon/storage.py +++ b/lbrynet/extras/daemon/storage.py @@ -1,18 +1,22 @@ -import asyncio import logging -import os -import traceback +import sqlite3 import typing -from binascii import hexlify, unhexlify -from lbrynet.extras.wallet.dewies import dewies_to_lbc, lbc_to_dewies +import asyncio +import binascii +from torba.client.basedatabase import SQLiteMixin from lbrynet.conf import Config +from lbrynet.extras.wallet.dewies import dewies_to_lbc, lbc_to_dewies from lbrynet.schema.claim import ClaimDict from lbrynet.schema.decode import smart_decode -from lbrynet.blob.CryptBlob import CryptBlobInfo -from lbrynet.dht.constants import dataExpireTimeout -from torba.client.basedatabase import SQLiteMixin +from lbrynet.dht.constants import data_expiration + +if typing.TYPE_CHECKING: + from lbrynet.blob.blob_file import BlobFile + from lbrynet.stream.descriptor import StreamDescriptor log = logging.getLogger(__name__) +opt_str = typing.Optional[str] +opt_int = typing.Optional[int] def calculate_effective_amount(amount: str, supports: typing.Optional[typing.List[typing.Dict]] = None) -> str: @@ -21,50 +25,97 @@ def calculate_effective_amount(amount: str, supports: typing.Optional[typing.Lis ) -def _get_next_available_file_name(download_directory, file_name): - base_name, ext = os.path.splitext(file_name) - i = 0 - while os.path.isfile(os.path.join(download_directory, file_name)): - i += 1 - file_name = "%s_%i%s" % (base_name, i, ext) - return os.path.join(download_directory, file_name) +class StoredStreamClaim: + def __init__(self, stream_hash: str, outpoint: opt_str = None, claim_id: opt_str = None, name: opt_str = None, + amount: opt_int = None, height: opt_int = None, serialized: opt_str = None, + channel_claim_id: opt_str = None, address: opt_str = None, claim_sequence: opt_int = None, + channel_name: opt_str = None): + self.stream_hash = stream_hash + self.claim_id = claim_id + self.outpoint = outpoint + self.claim_name = name + self.amount = amount + self.height = height + self.claim: typing.Optional[ClaimDict] = None if not serialized else smart_decode(serialized) + self.claim_address = address + self.claim_sequence = claim_sequence + self.channel_claim_id = channel_claim_id + self.channel_name = channel_name + + @property + def txid(self) -> typing.Optional[str]: + return None if not self.outpoint else self.outpoint.split(":")[0] + + @property + def nout(self) -> typing.Optional[int]: + return None if not self.outpoint else int(self.outpoint.split(":")[1]) + + @property + def metadata(self) -> typing.Optional[typing.Dict]: + return None if not self.claim else self.claim.claim_dict['stream']['metadata'] + + def as_dict(self) -> typing.Dict: + return { + "name": self.claim_name, + "claim_id": self.claim_id, + "address": self.claim_address, + "claim_sequence": self.claim_sequence, + "value": self.claim, + "height": self.height, + "amount": dewies_to_lbc(self.amount), + "nout": self.nout, + "txid": self.txid, + "channel_claim_id": self.channel_claim_id, + "channel_name": self.channel_name + } -def _open_file_for_writing(download_directory, suggested_file_name): - file_path = _get_next_available_file_name(download_directory, suggested_file_name) - try: - file_handle = open(file_path, 'wb') - file_handle.close() - except IOError: - log.error(traceback.format_exc()) - raise ValueError( - "Failed to open %s. Make sure you have permission to save files to that location." % file_path - ) - return os.path.basename(file_path) - - -async def open_file_for_writing(download_directory: str, suggested_file_name: str) -> str: - """ Used to touch the path of a file to be downloaded. """ - return await asyncio.get_event_loop().run_in_executor( - None, _open_file_for_writing, download_directory, suggested_file_name +def get_claims_from_stream_hashes(transaction: sqlite3.Connection, + stream_hashes: typing.List[str]) -> typing.Dict[str, StoredStreamClaim]: + query = ( + "select content_claim.stream_hash, c.*, case when c.channel_claim_id is not null then " + " (select claim_name from claim where claim_id==c.channel_claim_id) " + " else null end as channel_name " + " from content_claim " + " inner join claim c on c.claim_outpoint=content_claim.claim_outpoint and content_claim.stream_hash in {}" + " order by c.rowid desc" ) + return { + claim_info.stream_hash: claim_info + for claim_info in [ + None if not claim_info else StoredStreamClaim(*claim_info) + for claim_info in _batched_select(transaction, query, stream_hashes) + ] + } -async def looping_call(interval, fun): - while True: - try: - await fun() - except Exception as e: - log.exception('Looping call experienced exception:', exc_info=e) - await asyncio.sleep(interval) +def get_content_claim_from_outpoint(transaction: sqlite3.Connection, + outpoint: str) -> typing.Optional[StoredStreamClaim]: + query = ( + "select content_claim.stream_hash, c.*, case when c.channel_claim_id is not null then " + " (select claim_name from claim where claim_id==c.channel_claim_id) " + " else null end as channel_name " + " from content_claim " + " inner join claim c on c.claim_outpoint=content_claim.claim_outpoint and content_claim.claim_outpoint=?" + ) + claim_fields = transaction.execute(query, (outpoint, )).fetchone() + if claim_fields: + return StoredStreamClaim(*claim_fields) + + +def _batched_select(transaction, query, parameters): + for start_index in range(0, len(parameters), 900): + current_batch = parameters[start_index:start_index+900] + bind = "({})".format(','.join(['?'] * len(current_batch))) + for result in transaction.execute(query.format(bind), current_batch): + yield result class SQLiteStorage(SQLiteMixin): - CREATE_TABLES_QUERY = """ pragma foreign_keys=on; pragma journal_mode=WAL; - + create table if not exists blob ( blob_hash char(96) primary key not null, blob_length integer not null, @@ -74,7 +125,7 @@ class SQLiteStorage(SQLiteMixin): last_announced_time integer, single_announce integer ); - + create table if not exists stream ( stream_hash char(96) not null primary key, sd_hash char(96) not null references blob, @@ -82,7 +133,7 @@ class SQLiteStorage(SQLiteMixin): stream_name text not null, suggested_filename text not null ); - + create table if not exists stream_blob ( stream_hash char(96) not null references stream, blob_hash char(96) references blob, @@ -90,7 +141,7 @@ class SQLiteStorage(SQLiteMixin): iv char(32) not null, primary key (stream_hash, blob_hash) ); - + create table if not exists claim ( claim_outpoint text not null primary key, claim_id char(40) not null, @@ -110,20 +161,20 @@ class SQLiteStorage(SQLiteMixin): blob_data_rate real not null, status text not null ); - + create table if not exists content_claim ( stream_hash text unique not null references file, claim_outpoint text not null references claim, primary key (stream_hash, claim_outpoint) ); - + create table if not exists support ( support_outpoint text not null primary key, claim_id text not null, amount integer not null, address text not null ); - + create table if not exists reflected_stream ( sd_hash text not null, reflector_address text not null, @@ -136,21 +187,8 @@ class SQLiteStorage(SQLiteMixin): super().__init__(path) self.conf = conf self.content_claim_callbacks = {} - self.check_should_announce_lc = None self.loop = loop or asyncio.get_event_loop() - async def open(self): - await super().open() - if 'reflector' not in self.conf.components_to_skip: - self.check_should_announce_lc = looping_call( - 600, self.verify_will_announce_all_head_and_sd_blobs - ) - - async def close(self): - if self.check_should_announce_lc is not None: - self.check_should_announce_lc.close() - await super().close() - async def run_and_return_one_or_none(self, query, *args): for row in await self.db.execute_fetchall(query, args): if len(row) == 1: @@ -161,33 +199,29 @@ class SQLiteStorage(SQLiteMixin): rows = list(await self.db.execute_fetchall(query, args)) return [col[0] for col in rows] if rows else [] - async def run_and_return_id(self, query, *args): - return (await self.db.execute(query, args)).lastrowid - # # # # # # # # # blob functions # # # # # # # # # - def add_completed_blob(self, blob_hash, length, next_announce_time, should_announce, status="finished"): - log.debug("Adding a completed blob. blob_hash=%s, length=%i", blob_hash, length) - values = (blob_hash, length, next_announce_time or 0, int(bool(should_announce)), status, 0, 0) - return self.db.execute("insert or replace into blob values (?, ?, ?, ?, ?, ?, ?)", values) + def add_completed_blob(self, blob_hash: str): + log.debug("Adding a completed blob. blob_hash=%s", blob_hash) + return self.db.execute("update blob set status='finished' where blob.blob_hash=?", (blob_hash, )) - def set_should_announce(self, blob_hash, next_announce_time, should_announce): + def set_should_announce(self, blob_hash: str, next_announce_time: int, should_announce: int): return self.db.execute( "update blob set next_announce_time=?, should_announce=? where blob_hash=?", (next_announce_time or 0, int(bool(should_announce)), blob_hash) ) - def get_blob_status(self, blob_hash): + def get_blob_status(self, blob_hash: str): return self.run_and_return_one_or_none( "select status from blob where blob_hash=?", blob_hash ) - def add_known_blob(self, blob_hash, length): + def add_known_blob(self, blob_hash: str, length: int): return self.db.execute( "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", (blob_hash, length, 0, 0, "pending", 0, 0) ) - def should_announce(self, blob_hash): + def should_announce(self, blob_hash: str): return self.run_and_return_one_or_none( "select should_announce from blob where blob_hash=?", blob_hash ) @@ -202,22 +236,25 @@ class SQLiteStorage(SQLiteMixin): "select blob_hash from blob where should_announce=1 and status='finished'" ) - async def get_all_finished_blobs(self): - blob_hashes = await self.run_and_return_list( + def get_all_finished_blobs(self): + return self.run_and_return_list( "select blob_hash from blob where status='finished'" ) - return [unhexlify(blob_hash) for blob_hash in blob_hashes] def count_finished_blobs(self): return self.run_and_return_one_or_none( "select count(*) from blob where status='finished'" ) - def update_last_announced_blob(self, blob_hash, last_announced): - return self.db.execute( - "update blob set next_announce_time=?, last_announced_time=?, single_announce=0 where blob_hash=?", - (int(last_announced + (dataExpireTimeout / 2)), int(last_announced), blob_hash) - ) + def update_last_announced_blobs(self, blob_hashes: typing.List[str], last_announced: float): + def _update_last_announced_blobs(transaction: sqlite3.Connection): + return transaction.executemany( + "update blob set next_announce_time=?, last_announced_time=?, single_announce=0 " + "where blob_hash=?", + [(int(last_announced + (data_expiration / 2)), int(last_announced), blob_hash) + for blob_hash in blob_hashes] + ) + return self.db.run(_update_last_announced_blobs) def should_single_announce_blobs(self, blob_hashes, immediate=False): def set_single_announce(transaction): @@ -230,7 +267,7 @@ class SQLiteStorage(SQLiteMixin): ) else: transaction.execute( - "update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash, ) + "update blob set single_announce=1 where blob_hash=? and status='finished'", (blob_hash,) ) return self.db.run(set_single_announce) @@ -255,203 +292,87 @@ class SQLiteStorage(SQLiteMixin): def delete_blobs_from_db(self, blob_hashes): def delete_blobs(transaction): - for blob_hash in blob_hashes: - transaction.execute("delete from blob where blob_hash=?;", (blob_hash,)) + transaction.executemany( + "delete from blob where blob_hash=?;", [(blob_hash,) for blob_hash in blob_hashes] + ) return self.db.run(delete_blobs) def get_all_blob_hashes(self): return self.run_and_return_list("select blob_hash from blob") - # # # # # # # # # stream blob functions # # # # # # # # # - - def add_blobs_to_stream(self, stream_hash, blob_infos): - def _add_stream_blobs(transaction): - for blob_info in blob_infos: - transaction.execute("insert into stream_blob values (?, ?, ?, ?)", - (stream_hash, blob_info.get('blob_hash', None), - blob_info['blob_num'], blob_info['iv'])) - return self.db.run(_add_stream_blobs) - - async def add_known_blobs(self, blob_infos): - for blob_info in blob_infos: - if blob_info.get('blob_hash') and blob_info['length']: - await self.add_known_blob(blob_info['blob_hash'], blob_info['length']) - - def verify_will_announce_head_and_sd_blobs(self, stream_hash): - # fix should_announce for imported head and sd blobs - return self.db.execute( - "update blob set should_announce=1 " - "where should_announce=0 and " - "blob.blob_hash in " - " (select b.blob_hash from blob b inner join stream s on b.blob_hash=s.sd_hash and s.stream_hash=?) " - "or blob.blob_hash in " - " (select b.blob_hash from blob b " - " inner join stream_blob s2 on b.blob_hash=s2.blob_hash and s2.position=0 and s2.stream_hash=?)", - (stream_hash, stream_hash) - ) - - def verify_will_announce_all_head_and_sd_blobs(self): - return self.db.execute( - "update blob set should_announce=1 " - "where should_announce=0 and " - "blob.blob_hash in " - " (select b.blob_hash from blob b inner join stream s on b.blob_hash=s.sd_hash) " - "or blob.blob_hash in " - " (select b.blob_hash from blob b " - " inner join stream_blob s2 on b.blob_hash=s2.blob_hash and s2.position=0)" - ) - # # # # # # # # # stream functions # # # # # # # # # - def store_stream(self, stream_hash, sd_hash, stream_name, stream_key, suggested_file_name, - stream_blob_infos): - """ - Add a stream to the stream table + async def stream_exists(self, sd_hash: str) -> bool: + streams = await self.run_and_return_one_or_none("select stream_hash from stream where sd_hash=?", sd_hash) + return streams is not None - :param stream_hash: hash of the assembled stream - :param sd_hash: hash of the sd blob - :param stream_key: blob decryption key - :param stream_name: the name of the file the stream was generated from - :param suggested_file_name: (str) suggested file name for stream - :param stream_blob_infos: (list) of blob info dictionaries - :return: (defer.Deferred) - """ - def _store_stream(transaction): + async def file_exists(self, sd_hash: str) -> bool: + streams = await self.run_and_return_one_or_none("select f.stream_hash from file f " + "inner join stream s on " + "s.stream_hash=f.stream_hash and s.sd_hash=?", sd_hash) + return streams is not None + + def store_stream(self, sd_blob: 'BlobFile', descriptor: 'StreamDescriptor'): + def _store_stream(transaction: sqlite3.Connection): + # add the head blob and set it to be announced transaction.execute( - "insert into stream values (?, ?, ?, ?, ?);", ( - stream_hash, sd_hash, stream_key, stream_name, suggested_file_name + "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?), (?, ?, ?, ?, ?, ?, ?)", + ( + sd_blob.blob_hash, sd_blob.length, 0, 1, "pending", 0, 0, + descriptor.blobs[0].blob_hash, descriptor.blobs[0].length, 0, 1, "pending", 0, 0 ) ) - for blob_info in stream_blob_infos: - transaction.execute( - "insert into stream_blob values (?, ?, ?, ?)", ( - stream_hash, blob_info.get('blob_hash', None), - blob_info['blob_num'], blob_info['iv'] - ) + # add the rest of the blobs with announcement off + if len(descriptor.blobs) > 2: + transaction.executemany( + "insert or ignore into blob values (?, ?, ?, ?, ?, ?, ?)", + [(blob.blob_hash, blob.length, 0, 0, "pending", 0, 0) + for blob in descriptor.blobs[1:-1]] ) + # associate the blobs to the stream + transaction.execute("insert or ignore into stream values (?, ?, ?, ?, ?)", + (descriptor.stream_hash, sd_blob.blob_hash, descriptor.key, + binascii.hexlify(descriptor.stream_name.encode()).decode(), + binascii.hexlify(descriptor.suggested_file_name.encode()).decode())) + # add the stream + transaction.executemany( + "insert or ignore into stream_blob values (?, ?, ?, ?)", + [(descriptor.stream_hash, blob.blob_hash, blob.blob_num, blob.iv) + for blob in descriptor.blobs] + ) + return self.db.run(_store_stream) - async def delete_stream(self, stream_hash): - sd_hash = await self.get_sd_blob_hash_for_stream(stream_hash) - stream_blobs = await self.get_blobs_for_stream(stream_hash) - blob_hashes = [b.blob_hash for b in stream_blobs if b.blob_hash is not None] - - def _delete_stream(transaction): - transaction.execute("delete from content_claim where stream_hash=? ", (stream_hash,)) - transaction.execute("delete from file where stream_hash=? ", (stream_hash, )) - transaction.execute("delete from stream_blob where stream_hash=?", (stream_hash, )) - transaction.execute("delete from stream where stream_hash=? ", (stream_hash, )) - transaction.execute("delete from blob where blob_hash=?", (sd_hash, )) - for blob_hash in blob_hashes: - transaction.execute("delete from blob where blob_hash=?;", (blob_hash, )) - - await self.db.run(_delete_stream) - - def get_all_streams(self): - return self.run_and_return_list("select stream_hash from stream") - - def get_stream_info(self, stream_hash): - return self.run_and_return_one_or_none( - "select stream_name, stream_key, suggested_filename, sd_hash from stream " - "where stream_hash=?", stream_hash - ) - - async def check_if_stream_exists(self, stream_hash): - row = await self.run_and_return_one_or_none( - "select stream_hash from stream where stream_hash=?", stream_hash - ) - if row is not None: - return bool(len(row)) - return False - - def get_blob_num_by_hash(self, stream_hash, blob_hash): - return self.run_and_return_one_or_none( - "select position from stream_blob where stream_hash=? and blob_hash=?", - stream_hash, blob_hash - ) - - def get_stream_blob_by_position(self, stream_hash, blob_num): - return self.run_and_return_one_or_none( - "select blob_hash from stream_blob where stream_hash=? and position=?", - stream_hash, blob_num - ) - - def get_blobs_for_stream(self, stream_hash, only_completed=False): - def _get_blobs_for_stream(transaction): - crypt_blob_infos = [] - stream_blobs = transaction.execute( - "select blob_hash, position, iv from stream_blob where stream_hash=?", (stream_hash, ) - ).fetchall() - if only_completed: - lengths = transaction.execute( - "select b.blob_hash, b.blob_length from blob b " - "inner join stream_blob s ON b.blob_hash=s.blob_hash and b.status='finished' and s.stream_hash=?", - (stream_hash, ) - ).fetchall() - else: - lengths = transaction.execute( - "select b.blob_hash, b.blob_length from blob b " - "inner join stream_blob s ON b.blob_hash=s.blob_hash and s.stream_hash=?", - (stream_hash, ) - ).fetchall() - - blob_length_dict = {} - for blob_hash, length in lengths: - blob_length_dict[blob_hash] = length - - for blob_hash, position, iv in stream_blobs: - blob_length = blob_length_dict.get(blob_hash, 0) - crypt_blob_infos.append(CryptBlobInfo(blob_hash, position, blob_length, iv)) - crypt_blob_infos = sorted(crypt_blob_infos, key=lambda info: info.blob_num) - return crypt_blob_infos - return self.db.run(_get_blobs_for_stream) - - def get_pending_blobs_for_stream(self, stream_hash): - return self.run_and_return_list( - "select s.blob_hash from stream_blob s " - "inner join blob b on b.blob_hash=s.blob_hash and b.status='pending' " - "where stream_hash=?", - stream_hash - ) - - def get_stream_of_blob(self, blob_hash): - return self.run_and_return_one_or_none( - "select stream_hash from stream_blob where blob_hash=?", blob_hash - ) - - def get_sd_blob_hash_for_stream(self, stream_hash): - return self.run_and_return_one_or_none( - "select sd_hash from stream where stream_hash=?", stream_hash - ) - - def get_stream_hash_for_sd_hash(self, sd_blob_hash): - return self.run_and_return_one_or_none( - "select stream_hash from stream where sd_hash = ?", sd_blob_hash - ) + def delete_stream(self, descriptor: 'StreamDescriptor'): + def _delete_stream(transaction: sqlite3.Connection): + transaction.execute("delete from content_claim where stream_hash=? ", (descriptor.stream_hash,)) + transaction.execute("delete from file where stream_hash=? ", (descriptor.stream_hash, )) + transaction.execute("delete from stream_blob where stream_hash=?", (descriptor.stream_hash, )) + transaction.execute("delete from stream where stream_hash=? ", (descriptor.stream_hash, )) + transaction.execute("delete from blob where blob_hash=?", (descriptor.sd_hash, )) + transaction.executemany("delete from blob where blob_hash=?", + [(blob.blob_hash, ) for blob in descriptor.blobs[:-1]]) + return self.db.run(_delete_stream) # # # # # # # # # file stuff # # # # # # # # # - async def save_downloaded_file(self, stream_hash, file_name, download_directory, data_payment_rate): - # touch the closest available file to the file name - file_name = await open_file_for_writing(unhexlify(download_directory).decode(), unhexlify(file_name).decode()) - return await self.save_published_file( - stream_hash, hexlify(file_name.encode()), download_directory, data_payment_rate + def save_downloaded_file(self, stream_hash, file_name, download_directory, data_payment_rate): + return self.save_published_file( + stream_hash, binascii.hexlify(file_name.encode()).decode(), + binascii.hexlify(download_directory.encode()).decode(), data_payment_rate, + status="running" ) - def save_published_file(self, stream_hash, file_name, download_directory, data_payment_rate, status="stopped"): - return self.run_and_return_id( + def save_published_file(self, stream_hash: str, file_name: str, download_directory: str, data_payment_rate: float, + status="finished"): + return self.db.execute( "insert into file values (?, ?, ?, ?, ?)", - stream_hash, file_name, download_directory, data_payment_rate, status + (stream_hash, file_name, download_directory, data_payment_rate, status) ) - def get_filename_for_rowid(self, rowid): - return self.run_and_return_one_or_none( - "select file_name from file where rowid=?", rowid - ) - - def get_all_lbry_files(self): + async def get_all_lbry_files(self) -> typing.List[typing.Dict]: def _lbry_file_dict(rowid, stream_hash, file_name, download_dir, data_rate, status, _, sd_hash, stream_key, - stream_name, suggested_file_name): + stream_name, suggested_file_name) -> typing.Dict: return { "row_id": rowid, "stream_hash": stream_hash, @@ -465,36 +386,35 @@ class SQLiteStorage(SQLiteMixin): "suggested_file_name": suggested_file_name } - def _get_all_files(transaction): - return [ - _lbry_file_dict(*file_info) for file_info in transaction.execute( - "select file.rowid, file.*, stream.* " - "from file inner join stream on file.stream_hash=stream.stream_hash" - ).fetchall() - ] + def _get_all_files(transaction: sqlite3.Connection) -> typing.List[typing.Dict]: + file_infos = list(map(lambda a: _lbry_file_dict(*a), transaction.execute( + "select file.rowid, file.*, stream.* " + "from file inner join stream on file.stream_hash=stream.stream_hash" + ).fetchall())) + stream_hashes = [file_info['stream_hash'] for file_info in file_infos] + claim_infos = get_claims_from_stream_hashes(transaction, stream_hashes) + for index in range(len(file_infos)): # pylint: disable=consider-using-enumerate + file_infos[index]['claim'] = claim_infos.get(file_infos[index]['stream_hash']) + return file_infos - return self.db.run(_get_all_files) + results = await self.db.run(_get_all_files) + if results: + return results + return [] - async def change_file_status(self, rowid, new_status): - await self.db.execute("update file set status=? where rowid=?", (new_status, rowid)) - return new_status + def change_file_status(self, stream_hash: str, new_status: str): + log.info("update file status %s -> %s", stream_hash, new_status) + return self.db.execute("update file set status=? where stream_hash=?", (new_status, stream_hash)) - def get_lbry_file_status(self, rowid): - return self.run_and_return_one_or_none( - "select status from file where rowid = ?", rowid - ) - - def get_rowid_for_stream_hash(self, stream_hash): - return self.run_and_return_one_or_none( - "select rowid from file where stream_hash=?", stream_hash - ) + def get_all_stream_hashes(self): + return self.run_and_return_list("select stream_hash from stream") # # # # # # # # # support functions # # # # # # # # # def save_supports(self, claim_id, supports): # TODO: add 'address' to support items returned for a claim from lbrycrdd and lbryum-server def _save_support(transaction): - transaction.execute("delete from support where claim_id=?", (claim_id, )) + transaction.execute("delete from support where claim_id=?", (claim_id,)) for support in supports: transaction.execute( "insert into support values (?, ?, ?, ?)", @@ -552,30 +472,32 @@ class SQLiteStorage(SQLiteMixin): source_hash = None except AttributeError: source_hash = None - serialized = claim_info.get('hex') or hexlify(smart_decode(claim_info['value']).serialized) + serialized = claim_info.get('hex') or binascii.hexlify( + smart_decode(claim_info['value']).serialized).decode() transaction.execute( "insert or replace into claim values (?, ?, ?, ?, ?, ?, ?, ?, ?)", (outpoint, claim_id, name, amount, height, serialized, certificate_id, address, sequence) ) - if 'supports' in claim_info: # if this response doesn't have support info don't overwrite the existing - # support info + # if this response doesn't have support info don't overwrite the existing + # support info + if 'supports' in claim_info: support_callbacks.append((claim_id, claim_info['supports'])) if not source_hash: continue stream_hash = transaction.execute( "select file.stream_hash from stream " - "inner join file on file.stream_hash=stream.stream_hash where sd_hash=?", (source_hash, ) + "inner join file on file.stream_hash=stream.stream_hash where sd_hash=?", (source_hash,) ).fetchone() if not stream_hash: continue stream_hash = stream_hash[0] known_outpoint = transaction.execute( - "select claim_outpoint from content_claim where stream_hash=?", (stream_hash, ) + "select claim_outpoint from content_claim where stream_hash=?", (stream_hash,) ) known_claim_id = transaction.execute( "select claim_id from claim " "inner join content_claim c3 ON claim.claim_outpoint=c3.claim_outpoint " - "where c3.stream_hash=?", (stream_hash, ) + "where c3.stream_hash=?", (stream_hash,) ) if not known_claim_id: content_claims_to_update.append((stream_hash, outpoint)) @@ -607,14 +529,6 @@ class SQLiteStorage(SQLiteMixin): to_save.append(info['claim']) return self.save_claims(to_save) - def get_old_stream_hashes_for_claim_id(self, claim_id, new_stream_hash): - return self.run_and_return_list( - "select f.stream_hash from file f " - "inner join content_claim cc on f.stream_hash=cc.stream_hash " - "inner join claim c on c.claim_outpoint=cc.claim_outpoint and c.claim_id=? " - "where f.stream_hash!=?", claim_id, new_stream_hash - ) - @staticmethod def _save_content_claim(transaction, claim_outpoint, stream_hash): # get the claim id and serialized metadata @@ -623,7 +537,7 @@ class SQLiteStorage(SQLiteMixin): ).fetchone() if not claim_info: raise Exception("claim not found") - new_claim_id, claim = claim_info[0], ClaimDict.deserialize(unhexlify(claim_info[1])) + new_claim_id, claim = claim_info[0], ClaimDict.deserialize(binascii.unhexlify(claim_info[1])) # certificate claims should not be in the content_claim table if not claim.is_stream: @@ -661,90 +575,31 @@ class SQLiteStorage(SQLiteMixin): if stream_hash in self.content_claim_callbacks: await self.content_claim_callbacks[stream_hash]() - async def get_content_claim(self, stream_hash, include_supports=True): - def _get_claim_from_stream_hash(transaction): - claim_info = transaction.execute( - "select c.*, " - "case when c.channel_claim_id is not null then " - "(select claim_name from claim where claim_id==c.channel_claim_id) " - "else null end as channel_name from content_claim " - "inner join claim c on c.claim_outpoint=content_claim.claim_outpoint " - "and content_claim.stream_hash=? order by c.rowid desc", (stream_hash,) - ).fetchone() - if not claim_info: - return None - channel_name = claim_info[-1] - result = _format_claim_response(*claim_info[:-1]) - if channel_name: - result['channel_name'] = channel_name - return result - - result = await self.db.run(_get_claim_from_stream_hash) - if result and include_supports: - result['supports'] = await self.get_supports(result['claim_id']) - result['effective_amount'] = calculate_effective_amount(result['amount'], result['supports']) - return result - - async def get_claims_from_stream_hashes(self, stream_hashes, include_supports=True): - def _batch_get_claim(transaction): - results = {} - claim_infos = _batched_select( - transaction, - "select content_claim.stream_hash, c.* from content_claim " - "inner join claim c on c.claim_outpoint=content_claim.claim_outpoint " - "and content_claim.stream_hash in {} order by c.rowid desc", - stream_hashes) - channel_id_infos = {} - for claim_info in claim_infos: - if claim_info[7]: - streams = channel_id_infos.get(claim_info[7], []) - streams.append(claim_info[0]) - channel_id_infos[claim_info[7]] = streams - stream_hash = claim_info[0] - result = _format_claim_response(*claim_info[1:]) - results[stream_hash] = result - channel_names = _batched_select( - transaction, - "select claim_id, claim_name from claim where claim_id in {}", - tuple(channel_id_infos.keys()) - ) - for claim_id, channel_name in channel_names: - for stream_hash in channel_id_infos[claim_id]: - results[stream_hash]['channel_name'] = channel_name - return results - - claims = await self.db.run(_batch_get_claim) - if include_supports: - all_supports = {} - for support in await self.get_supports(*[claim['claim_id'] for claim in claims.values()]): - all_supports.setdefault(support['claim_id'], []).append(support) - for stream_hash in claims.keys(): - claim = claims[stream_hash] - supports = all_supports.get(claim['claim_id'], []) + async def get_content_claim(self, stream_hash: str, include_supports: typing.Optional[bool] = True) -> typing.Dict: + claims = await self.db.run(get_claims_from_stream_hashes, [stream_hash]) + claim = None + if claims: + claim = claims[stream_hash].as_dict() + if include_supports: + supports = await self.get_supports(claim['claim_id']) claim['supports'] = supports claim['effective_amount'] = calculate_effective_amount(claim['amount'], supports) - claims[stream_hash] = claim - return claims + return claim + + async def get_claims_from_stream_hashes(self, stream_hashes: typing.List[str], + include_supports: typing.Optional[bool] = True): + claims = await self.db.run(get_claims_from_stream_hashes, stream_hashes) + return {stream_hash: claim_info.as_dict() for stream_hash, claim_info in claims.items()} async def get_claim(self, claim_outpoint, include_supports=True): - def _get_claim(transaction): - claim_info = transaction.execute( - "select c.*, " - "case when c.channel_claim_id is not null then " - "(select claim_name from claim where claim_id==c.channel_claim_id) " - "else null end as channel_name from claim c where claim_outpoint = ?", - (claim_outpoint,) - ).fetchone() - channel_name = claim_info[-1] - result = _format_claim_response(*claim_info[:-1]) - if channel_name: - result['channel_name'] = channel_name - return result - - result = await self.db.run(_get_claim) + claim_info = await self.db.run(get_content_claim_from_outpoint, claim_outpoint) + if not claim_info: + return + result = claim_info.as_dict() if include_supports: - result['supports'] = await self.get_supports(result['claim_id']) - result['effective_amount'] = calculate_effective_amount(result['amount'], result['supports']) + supports = await self.get_supports(result['claim_id']) + result['supports'] = supports + result['effective_amount'] = calculate_effective_amount(result['amount'], supports) return result def get_unknown_certificate_ids(self): @@ -800,29 +655,3 @@ class SQLiteStorage(SQLiteMixin): "where r.timestamp is null or r.timestamp < ?", self.loop.time() - self.conf.auto_re_reflect_interval ) - - -# Helper functions -def _format_claim_response(outpoint, claim_id, name, amount, height, serialized, channel_id, address, claim_sequence): - r = { - "name": name, - "claim_id": claim_id, - "address": address, - "claim_sequence": claim_sequence, - "value": ClaimDict.deserialize(unhexlify(serialized)).claim_dict, - "height": height, - "amount": dewies_to_lbc(amount), - "nout": int(outpoint.split(":")[1]), - "txid": outpoint.split(":")[0], - "channel_claim_id": channel_id, - "channel_name": None - } - return r - - -def _batched_select(transaction, query, parameters): - for start_index in range(0, len(parameters), 900): - current_batch = parameters[start_index:start_index+900] - bind = "({})".format(','.join(['?'] * len(current_batch))) - for result in transaction.execute(query.format(bind), current_batch): - yield result