diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index ea24541b9..46aa3a845 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -41,7 +41,7 @@ test:datanetwork-integration: stage: test script: - pip install tox-travis - - tox -e datanetwork + - tox -e datanetwork --recreate test:blockchain-integration: stage: test @@ -94,6 +94,7 @@ build:linux: - apt-get update - apt-get install -y --no-install-recommends python3.7-dev - python3.7 <(curl -q https://bootstrap.pypa.io/get-pip.py) # make sure we get pip with python3.7 + - pip install lbry-libtorrent build:mac: extends: .build diff --git a/Makefile b/Makefile index d189bb6fa..a6221fa03 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ .PHONY: install tools lint test idea install: + pip install https://s3.amazonaws.com/files.lbry.io/python_libtorrent-1.2.4-py3-none-any.whl CFLAGS="-DSQLITE_MAX_VARIABLE_NUMBER=2500000" pip install -U https://github.com/rogerbinns/apsw/releases/download/3.30.1-r1/apsw-3.30.1-r1.zip \ --global-option=fetch \ --global-option=--version --global-option=3.30.1 --global-option=--all \ @@ -8,7 +9,7 @@ install: pip install -e . tools: - pip install mypy==0.701 + pip install mypy==0.701 pylint==2.4.4 pip install coverage astroid pylint lint: diff --git a/lbry/__init__.py b/lbry/__init__.py index 1591ec860..c7e039eed 100644 --- a/lbry/__init__.py +++ b/lbry/__init__.py @@ -1,2 +1,2 @@ -__version__ = "0.69.1" +__version__ = "0.74.0" version = tuple(map(int, __version__.split('.'))) # pylint: disable=invalid-name diff --git a/lbry/blob/blob_file.py b/lbry/blob/blob_file.py index 65e0d4a43..b8c7c461c 100644 --- a/lbry/blob/blob_file.py +++ b/lbry/blob/blob_file.py @@ -110,7 +110,7 @@ class AbstractBlob: if reader in self.readers: self.readers.remove(reader) - def _write_blob(self, blob_bytes: bytes): + def _write_blob(self, blob_bytes: bytes) -> asyncio.Task: raise NotImplementedError() def set_length(self, length) -> None: @@ -198,11 +198,17 @@ class AbstractBlob: def save_verified_blob(self, verified_bytes: bytes): if self.verified.is_set(): return - if self.is_writeable(): - self._write_blob(verified_bytes) + + def update_events(_): self.verified.set() + self.writing.clear() + + if self.is_writeable(): + self.writing.set() + task = self._write_blob(verified_bytes) + task.add_done_callback(update_events) if self.blob_completed_callback: - self.blob_completed_callback(self) + task.add_done_callback(lambda _: self.blob_completed_callback(self)) def get_blob_writer(self, peer_address: typing.Optional[str] = None, peer_port: typing.Optional[int] = None) -> HashBlobWriter: @@ -261,9 +267,11 @@ class BlobBuffer(AbstractBlob): self.verified.clear() def _write_blob(self, blob_bytes: bytes): - if self._verified_bytes: - raise OSError("already have bytes for blob") - self._verified_bytes = BytesIO(blob_bytes) + async def write(): + if self._verified_bytes: + raise OSError("already have bytes for blob") + self._verified_bytes = BytesIO(blob_bytes) + return self.loop.create_task(write()) def delete(self): if self._verified_bytes: @@ -319,8 +327,14 @@ class BlobFile(AbstractBlob): handle.close() def _write_blob(self, blob_bytes: bytes): - with open(self.file_path, 'wb') as f: - f.write(blob_bytes) + def _write_blob(): + with open(self.file_path, 'wb') as f: + f.write(blob_bytes) + + async def write_blob(): + await self.loop.run_in_executor(None, _write_blob) + + return self.loop.create_task(write_blob()) def delete(self): if os.path.isfile(self.file_path): diff --git a/lbry/blob_exchange/client.py b/lbry/blob_exchange/client.py index 
408c0d323..61920c5b7 100644 --- a/lbry/blob_exchange/client.py +++ b/lbry/blob_exchange/client.py @@ -152,6 +152,8 @@ class BlobExchangeClientProtocol(asyncio.Protocol): log.debug(msg) msg = f"downloaded {self.blob.blob_hash[:8]} from {self.peer_address}:{self.peer_port}" await asyncio.wait_for(self.writer.finished, self.peer_timeout, loop=self.loop) + # wait for the io to finish + await self.blob.verified.wait() log.info("%s at %fMB/s", msg, round((float(self._blob_bytes_received) / float(time.perf_counter() - start_time)) / 1000000.0, 2)) diff --git a/lbry/conf.py b/lbry/conf.py index fa395d8a9..e9b293933 100644 --- a/lbry/conf.py +++ b/lbry/conf.py @@ -277,14 +277,23 @@ class Strings(ListSetting): class EnvironmentAccess: PREFIX = 'LBRY_' - def __init__(self, environ: dict): - self.environ = environ + def __init__(self, config: 'BaseConfig', environ: dict): + self.configuration = config + self.data = {} + if environ: + self.load(environ) + + def load(self, environ): + for setting in self.configuration.get_settings(): + value = environ.get(f'{self.PREFIX}{setting.name.upper()}', NOT_SET) + if value != NOT_SET and not (isinstance(setting, ListSetting) and value is None): + self.data[setting.name] = setting.deserialize(value) def __contains__(self, item: str): - return f'{self.PREFIX}{item.upper()}' in self.environ + return item in self.data def __getitem__(self, item: str): - return self.environ[f'{self.PREFIX}{item.upper()}'] + return self.data[item] class ArgumentAccess: @@ -443,7 +452,7 @@ class BaseConfig: self.arguments = ArgumentAccess(self, args) def set_environment(self, environ=None): - self.environment = EnvironmentAccess(environ or os.environ) + self.environment = EnvironmentAccess(self, environ or os.environ) def set_persisted(self, config_file_path=None): if config_file_path is None: @@ -469,12 +478,12 @@ class TranscodeConfig(BaseConfig): '', previous_names=['ffmpeg_folder']) video_encoder = String('FFmpeg codec and parameters for the video encoding. ' 'Example: libaom-av1 -crf 25 -b:v 0 -strict experimental', - 'libx264 -crf 21 -preset faster -pix_fmt yuv420p') + 'libx264 -crf 24 -preset faster -pix_fmt yuv420p') video_bitrate_maximum = Integer('Maximum bits per second allowed for video streams (0 to disable).', 8400000) video_scaler = String('FFmpeg scaling parameters for reducing bitrate. ' 'Example: -vf "scale=-2:720,fps=24" -maxrate 5M -bufsize 3M', r'-vf "scale=if(gte(iw\,ih)\,min(1920\,iw)\,-2):if(lt(iw\,ih)\,min(1920\,ih)\,-2)" ' - r'-maxrate 8400K -bufsize 5000K') + r'-maxrate 5500K -bufsize 5000K') audio_encoder = String('FFmpeg codec and parameters for the audio encoding. 
' 'Example: libopus -b:a 128k', 'aac -b:a 160k') @@ -577,9 +586,14 @@ class Config(CLIConfig): ) # servers - reflector_servers = Servers("Reflector re-hosting servers", [ + reflector_servers = Servers("Reflector re-hosting servers for mirroring publishes", [ ('reflector.lbry.com', 5566) ]) + + fixed_peers = Servers("Fixed peers to fall back to if none are found on P2P for a blob", [ + ('cdn.reflector.lbry.com', 5567) + ]) + lbryum_servers = Servers("SPV wallet servers", [ ('spv11.lbry.com', 50001), ('spv12.lbry.com', 50001), @@ -622,6 +636,7 @@ class Config(CLIConfig): "Strategy to use when selecting UTXOs for a transaction", STRATEGIES, "standard") + transaction_cache_size = Integer("Transaction cache size", 100_000) save_resolved_claims = Toggle( "Save content claims to the database when they are resolved to keep file_list up to date, " "only disable this if file_x commands are not needed", True @@ -691,9 +706,10 @@ def get_darwin_directories() -> typing.Tuple[str, str, str]: def get_linux_directories() -> typing.Tuple[str, str, str]: try: with open(os.path.join(user_config_dir(), 'user-dirs.dirs'), 'r') as xdg: - down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()).group(1) - down_dir = re.sub(r'\$HOME', os.getenv('HOME') or os.path.expanduser("~/"), down_dir) - download_dir = re.sub('\"', '', down_dir) + down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()) + if down_dir: + down_dir = re.sub(r'\$HOME', os.getenv('HOME') or os.path.expanduser("~/"), down_dir.group(1)) + download_dir = re.sub('\"', '', down_dir) except OSError: download_dir = os.getenv('XDG_DOWNLOAD_DIR') if not download_dir: diff --git a/lbry/dht/protocol/protocol.py b/lbry/dht/protocol/protocol.py index 016ad50bf..7b90b5644 100644 --- a/lbry/dht/protocol/protocol.py +++ b/lbry/dht/protocol/protocol.py @@ -10,6 +10,7 @@ from asyncio.protocols import DatagramProtocol from asyncio.transports import DatagramTransport from lbry.dht import constants +from lbry.dht.serialization.bencoding import DecodeError from lbry.dht.serialization.datagram import decode_datagram, ErrorDatagram, ResponseDatagram, RequestDatagram from lbry.dht.serialization.datagram import RESPONSE_TYPE, ERROR_TYPE, PAGE_KEY from lbry.dht.error import RemoteException, TransportNotConnected @@ -554,7 +555,7 @@ class KademliaProtocol(DatagramProtocol): def datagram_received(self, datagram: bytes, address: typing.Tuple[str, int]) -> None: # pylint: disable=arguments-differ try: message = decode_datagram(datagram) - except (ValueError, TypeError): + except (ValueError, TypeError, DecodeError): self.peer_manager.report_failure(address[0], address[1]) log.warning("Couldn't decode dht datagram from %s: %s", address, binascii.hexlify(datagram).decode()) return diff --git a/lbry/extras/daemon/analytics.py b/lbry/extras/daemon/analytics.py index 828112396..f6983016c 100644 --- a/lbry/extras/daemon/analytics.py +++ b/lbry/extras/daemon/analytics.py @@ -66,7 +66,7 @@ def _download_properties(conf: Config, external_ip: str, resolve_duration: float "node_rpc_timeout": conf.node_rpc_timeout, "peer_connect_timeout": conf.peer_connect_timeout, "blob_download_timeout": conf.blob_download_timeout, - "use_fixed_peers": len(conf.reflector_servers) > 0, + "use_fixed_peers": len(conf.fixed_peers) > 0, "fixed_peer_delay": fixed_peer_delay, "added_fixed_peers": added_fixed_peers, "active_peer_count": active_peer_count, diff --git a/lbry/extras/daemon/components.py b/lbry/extras/daemon/components.py index 5271c1558..38c4d4650 100644 --- a/lbry/extras/daemon/components.py 
+++ b/lbry/extras/daemon/components.py @@ -17,11 +17,17 @@ from lbry.dht.blob_announcer import BlobAnnouncer from lbry.blob.blob_manager import BlobManager from lbry.blob_exchange.server import BlobServer from lbry.stream.stream_manager import StreamManager +from lbry.file.file_manager import FileManager from lbry.extras.daemon.component import Component from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.storage import SQLiteStorage +from lbry.torrent.torrent_manager import TorrentManager from lbry.wallet import WalletManager from lbry.wallet.usage_payment import WalletServerPayer +try: + from lbry.torrent.session import TorrentSession +except ImportError: + TorrentSession = None log = logging.getLogger(__name__) @@ -33,10 +39,11 @@ WALLET_COMPONENT = "wallet" WALLET_SERVER_PAYMENTS_COMPONENT = "wallet_server_payments" DHT_COMPONENT = "dht" HASH_ANNOUNCER_COMPONENT = "hash_announcer" -STREAM_MANAGER_COMPONENT = "stream_manager" +FILE_MANAGER_COMPONENT = "file_manager" PEER_PROTOCOL_SERVER_COMPONENT = "peer_protocol_server" UPNP_COMPONENT = "upnp" EXCHANGE_RATE_MANAGER_COMPONENT = "exchange_rate_manager" +LIBTORRENT_COMPONENT = "libtorrent_component" class DatabaseComponent(Component): @@ -319,23 +326,23 @@ class HashAnnouncerComponent(Component): } -class StreamManagerComponent(Component): - component_name = STREAM_MANAGER_COMPONENT - depends_on = [BLOB_COMPONENT, DATABASE_COMPONENT, WALLET_COMPONENT] +class FileManagerComponent(Component): + component_name = FILE_MANAGER_COMPONENT + depends_on = [BLOB_COMPONENT, DATABASE_COMPONENT, WALLET_COMPONENT, LIBTORRENT_COMPONENT] def __init__(self, component_manager): super().__init__(component_manager) - self.stream_manager: typing.Optional[StreamManager] = None + self.file_manager: typing.Optional[FileManager] = None @property - def component(self) -> typing.Optional[StreamManager]: - return self.stream_manager + def component(self) -> typing.Optional[FileManager]: + return self.file_manager async def get_status(self): - if not self.stream_manager: + if not self.file_manager: return return { - 'managed_files': len(self.stream_manager.streams), + 'managed_files': len(self.file_manager.get_filtered()), } async def start(self): @@ -344,16 +351,52 @@ class StreamManagerComponent(Component): wallet = self.component_manager.get_component(WALLET_COMPONENT) node = self.component_manager.get_component(DHT_COMPONENT) \ if self.component_manager.has_component(DHT_COMPONENT) else None + torrent = self.component_manager.get_component(LIBTORRENT_COMPONENT) if TorrentSession else None log.info('Starting the file manager') loop = asyncio.get_event_loop() - self.stream_manager = StreamManager( - loop, self.conf, blob_manager, wallet, storage, node, self.component_manager.analytics_manager + self.file_manager = FileManager( + loop, self.conf, wallet, storage, self.component_manager.analytics_manager ) - await self.stream_manager.start() + self.file_manager.source_managers['stream'] = StreamManager( + loop, self.conf, blob_manager, wallet, storage, node, + ) + if TorrentSession: + self.file_manager.source_managers['torrent'] = TorrentManager( + loop, self.conf, torrent, storage, self.component_manager.analytics_manager + ) + await self.file_manager.start() log.info('Done setting up file manager') async def stop(self): - self.stream_manager.stop() + self.file_manager.stop() + + +class TorrentComponent(Component): + component_name = LIBTORRENT_COMPONENT + + def __init__(self, component_manager): + 
super().__init__(component_manager) + self.torrent_session = None + + @property + def component(self) -> typing.Optional[TorrentSession]: + return self.torrent_session + + async def get_status(self): + if not self.torrent_session: + return + return { + 'running': True, # TODO: what to return here? + } + + async def start(self): + if TorrentSession: + self.torrent_session = TorrentSession(asyncio.get_event_loop(), None) + await self.torrent_session.bind() # TODO: specify host/port + + async def stop(self): + if self.torrent_session: + await self.torrent_session.pause() class PeerProtocolServerComponent(Component): diff --git a/lbry/extras/daemon/daemon.py b/lbry/extras/daemon/daemon.py index 9789bfa14..756ff6bc7 100644 --- a/lbry/extras/daemon/daemon.py +++ b/lbry/extras/daemon/daemon.py @@ -19,7 +19,7 @@ from functools import wraps, partial import ecdsa import base58 from aiohttp import web -from prometheus_client import generate_latest as prom_generate_latest +from prometheus_client import generate_latest as prom_generate_latest, Gauge, Histogram, Counter from google.protobuf.message import DecodeError from lbry.wallet import ( Wallet, ENCRYPT_ON_DISK, SingleKey, HierarchicalDeterministic, @@ -40,7 +40,7 @@ from lbry.error import ( from lbry.extras import system_info from lbry.extras.daemon import analytics from lbry.extras.daemon.components import WALLET_COMPONENT, DATABASE_COMPONENT, DHT_COMPONENT, BLOB_COMPONENT -from lbry.extras.daemon.components import STREAM_MANAGER_COMPONENT +from lbry.extras.daemon.components import FILE_MANAGER_COMPONENT from lbry.extras.daemon.components import EXCHANGE_RATE_MANAGER_COMPONENT, UPNP_COMPONENT from lbry.extras.daemon.componentmanager import RequiredCondition from lbry.extras.daemon.componentmanager import ComponentManager @@ -57,8 +57,8 @@ if typing.TYPE_CHECKING: from lbry.extras.daemon.components import UPnPComponent from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager from lbry.extras.daemon.storage import SQLiteStorage - from lbry.stream.stream_manager import StreamManager from lbry.wallet import WalletManager, Ledger + from lbry.file.file_manager import FileManager log = logging.getLogger(__name__) @@ -290,6 +290,11 @@ class JSONRPCServerType(type): return klass +HISTOGRAM_BUCKETS = ( + .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') +) + + class Daemon(metaclass=JSONRPCServerType): """ LBRYnet daemon, a jsonrpc interface to lbry functions @@ -297,6 +302,28 @@ class Daemon(metaclass=JSONRPCServerType): callable_methods: dict deprecated_methods: dict + pending_requests_metric = Gauge( + "pending_requests", "Number of running api requests", namespace="daemon_api", + labelnames=("method",) + ) + + requests_count_metric = Counter( + "requests_count", "Number of requests received", namespace="daemon_api", + labelnames=("method",) + ) + failed_request_metric = Counter( + "failed_request_count", "Number of failed requests", namespace="daemon_api", + labelnames=("method",) + ) + cancelled_request_metric = Counter( + "cancelled_request_count", "Number of cancelled requests", namespace="daemon_api", + labelnames=("method",) + ) + response_time_metric = Histogram( + "response_time", "Response times", namespace="daemon_api", buckets=HISTOGRAM_BUCKETS, + labelnames=("method",) + ) + def __init__(self, conf: Config, component_manager: typing.Optional[ComponentManager] = None): self.conf = conf self.platform_info = system_info.get_platform() @@ -345,8 +372,8 @@ class 
Daemon(metaclass=JSONRPCServerType): return self.component_manager.get_component(DATABASE_COMPONENT) @property - def stream_manager(self) -> typing.Optional['StreamManager']: - return self.component_manager.get_component(STREAM_MANAGER_COMPONENT) + def file_manager(self) -> typing.Optional['FileManager']: + return self.component_manager.get_component(FILE_MANAGER_COMPONENT) @property def exchange_rate_manager(self) -> typing.Optional['ExchangeRateManager']: @@ -457,7 +484,6 @@ class Daemon(metaclass=JSONRPCServerType): log.info("Starting LBRYNet Daemon") log.debug("Settings: %s", json.dumps(self.conf.settings_dict, indent=2)) log.info("Platform: %s", json.dumps(self.platform_info, indent=2)) - self.need_connection_status_refresh.set() self._connection_status_task = self.component_manager.loop.create_task( self.keep_connection_status_up_to_date() @@ -583,8 +609,8 @@ class Daemon(metaclass=JSONRPCServerType): else: name, claim_id = name_and_claim_id.split("/") uri = f"lbry://{name}#{claim_id}" - if not self.stream_manager.started.is_set(): - await self.stream_manager.started.wait() + if not self.file_manager.started.is_set(): + await self.file_manager.started.wait() stream = await self.jsonrpc_get(uri) if isinstance(stream, dict): raise web.HTTPServerError(text=stream['error']) @@ -608,11 +634,11 @@ class Daemon(metaclass=JSONRPCServerType): async def _handle_stream_range_request(self, request: web.Request): sd_hash = request.path.split("/stream/")[1] - if not self.stream_manager.started.is_set(): - await self.stream_manager.started.wait() - if sd_hash not in self.stream_manager.streams: + if not self.file_manager.started.is_set(): + await self.file_manager.started.wait() + if sd_hash not in self.file_manager.streams: return web.HTTPNotFound() - return await self.stream_manager.stream_partial_content(request, sd_hash) + return await self.file_manager.stream_partial_content(request, sd_hash) async def _process_rpc_call(self, data): args = data.get('params', {}) @@ -663,20 +689,27 @@ class Daemon(metaclass=JSONRPCServerType): JSONRPCError.CODE_INVALID_PARAMS, params_error_message, ) - + self.pending_requests_metric.labels(method=function_name).inc() + self.requests_count_metric.labels(method=function_name).inc() + start = time.perf_counter() try: result = method(self, *_args, **_kwargs) if asyncio.iscoroutine(result): result = await result return result except asyncio.CancelledError: + self.cancelled_request_metric.labels(method=function_name).inc() log.info("cancelled API call for: %s", function_name) raise except Exception as e: # pylint: disable=broad-except + self.failed_request_metric.labels(method=function_name).inc() log.exception("error handling api request") return JSONRPCError.create_command_exception( command=function_name, args=_args, kwargs=_kwargs, exception=e, traceback=format_exc() ) + finally: + self.pending_requests_metric.labels(method=function_name).dec() + self.response_time_metric.labels(method=function_name).observe(time.perf_counter() - start) def _verify_method_is_callable(self, function_path): if function_path not in self.callable_methods: @@ -825,7 +858,8 @@ class Daemon(metaclass=JSONRPCServerType): 'exchange_rate_manager': (bool), 'hash_announcer': (bool), 'peer_protocol_server': (bool), - 'stream_manager': (bool), + 'file_manager': (bool), + 'libtorrent_component': (bool), 'upnp': (bool), 'wallet': (bool), }, @@ -852,6 +886,9 @@ class Daemon(metaclass=JSONRPCServerType): } ], }, + 'libtorrent_component': { + 'running': (bool) libtorrent was detected and started 
successfully, + }, 'dht': { 'node_id': (str) lbry dht node id - hex encoded, 'peers_in_routing_table': (int) the number of peers in the routing table, @@ -873,7 +910,7 @@ class Daemon(metaclass=JSONRPCServerType): 'hash_announcer': { 'announce_queue_size': (int) number of blobs currently queued to be announced }, - 'stream_manager': { + 'file_manager': { 'managed_files': (int) count of files in the stream manager, }, 'upnp': { @@ -1044,7 +1081,7 @@ class Daemon(metaclass=JSONRPCServerType): return results @requires(WALLET_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT, - STREAM_MANAGER_COMPONENT) + FILE_MANAGER_COMPONENT) async def jsonrpc_get( self, uri, file_name=None, download_directory=None, timeout=None, save_file=None, wallet_id=None): """ @@ -1070,7 +1107,7 @@ class Daemon(metaclass=JSONRPCServerType): if download_directory and not os.path.isdir(download_directory): return {"error": f"specified download directory \"{download_directory}\" does not exist"} try: - stream = await self.stream_manager.download_stream_from_uri( + stream = await self.file_manager.download_from_uri( uri, self.exchange_rate_manager, timeout, file_name, download_directory, save_file=save_file, wallet=wallet ) @@ -1916,7 +1953,7 @@ class Daemon(metaclass=JSONRPCServerType): File management. """ - @requires(STREAM_MANAGER_COMPONENT) + @requires(FILE_MANAGER_COMPONENT) async def jsonrpc_file_list(self, sort=None, reverse=False, comparison=None, wallet_id=None, page=None, page_size=None, **kwargs): """ @@ -1928,9 +1965,11 @@ class Daemon(metaclass=JSONRPCServerType): [--outpoint=] [--txid=] [--nout=] [--channel_claim_id=] [--channel_name=] [--claim_name=] [--blobs_in_stream=] - [--blobs_remaining=] [--sort=] - [--comparison=] [--full_status=] [--reverse] - [--page=] [--page_size=] [--wallet_id=] + [--download_path=] [--blobs_remaining=] + [--uploading_to_reflector=] [--is_fully_reflected=] + [--status=] [--completed=] [--sort=] [--comparison=] + [--full_status=] [--reverse] [--page=] [--page_size=] + [--wallet_id=] Options: --sd_hash= : (str) get file with matching sd hash @@ -1947,6 +1986,11 @@ class Daemon(metaclass=JSONRPCServerType): --channel_name= : (str) get file with matching channel name --claim_name= : (str) get file with matching claim name --blobs_in_stream : (int) get file with matching blobs in stream + --download_path= : (str) get file with matching download path + --uploading_to_reflector= : (bool) get files currently uploading to reflector + --is_fully_reflected= : (bool) get files that have been uploaded to reflector + --status= : (str) match by status, ( running | finished | stopped ) + --completed= : (bool) match only completed --blobs_remaining= : (int) amount of remaining blobs to download --sort= : (str) field to sort by (one of the above filter fields) --comparison= : (str) logical comparison, (eq | ne | g | ge | l | le | in) @@ -1961,7 +2005,7 @@ class Daemon(metaclass=JSONRPCServerType): comparison = comparison or 'eq' paginated = paginate_list( - self.stream_manager.get_filtered_streams(sort, reverse, comparison, **kwargs), page, page_size + self.file_manager.get_filtered(sort, reverse, comparison, **kwargs), page, page_size ) if paginated['items']: receipts = { @@ -1975,7 +2019,7 @@ class Daemon(metaclass=JSONRPCServerType): stream.purchase_receipt = receipts.get(stream.claim_id) return paginated - @requires(STREAM_MANAGER_COMPONENT) + @requires(FILE_MANAGER_COMPONENT) async def jsonrpc_file_set_status(self, status, **kwargs): """ Start or stop 
downloading a file @@ -1999,12 +2043,14 @@ class Daemon(metaclass=JSONRPCServerType): if status not in ['start', 'stop']: raise Exception('Status must be "start" or "stop".') - streams = self.stream_manager.get_filtered_streams(**kwargs) + streams = self.file_manager.get_filtered(**kwargs) if not streams: raise Exception(f'Unable to find a file for {kwargs}') stream = streams[0] if status == 'start' and not stream.running: - await stream.save_file(node=self.stream_manager.node) + if not hasattr(stream, 'bt_infohash') and 'dht' not in self.conf.components_to_skip: + stream.downloader.node = self.dht_node + await stream.save_file() msg = "Resumed download" elif status == 'stop' and stream.running: await stream.stop() @@ -2016,7 +2062,7 @@ class Daemon(metaclass=JSONRPCServerType): ) return msg - @requires(STREAM_MANAGER_COMPONENT) + @requires(FILE_MANAGER_COMPONENT) async def jsonrpc_file_delete(self, delete_from_download_dir=False, delete_all=False, **kwargs): """ Delete a LBRY file @@ -2048,7 +2094,7 @@ class Daemon(metaclass=JSONRPCServerType): (bool) true if deletion was successful """ - streams = self.stream_manager.get_filtered_streams(**kwargs) + streams = self.file_manager.get_filtered(**kwargs) if len(streams) > 1: if not delete_all: @@ -2065,12 +2111,12 @@ class Daemon(metaclass=JSONRPCServerType): else: for stream in streams: message = f"Deleted file {stream.file_name}" - await self.stream_manager.delete_stream(stream, delete_file=delete_from_download_dir) + await self.file_manager.delete(stream, delete_file=delete_from_download_dir) log.info(message) result = True return result - @requires(STREAM_MANAGER_COMPONENT) + @requires(FILE_MANAGER_COMPONENT) async def jsonrpc_file_save(self, file_name=None, download_directory=None, **kwargs): """ Start saving a file to disk. @@ -2097,7 +2143,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {File} """ - streams = self.stream_manager.get_filtered_streams(**kwargs) + streams = self.file_manager.get_filtered(**kwargs) if len(streams) > 1: log.warning("There are %i matching files, use narrower filters to select one", len(streams)) @@ -2106,6 +2152,8 @@ class Daemon(metaclass=JSONRPCServerType): log.warning("There is no file to save") return False stream = streams[0] + if not hasattr(stream, 'bt_infohash') and 'dht' not in self.conf.components_to_skip: + stream.downloader.node = self.dht_node await stream.save_file(file_name, download_directory) return stream @@ -2237,7 +2285,7 @@ class Daemon(metaclass=JSONRPCServerType): Returns: {Paginated[Output]} """ kwargs['type'] = claim_type or CLAIM_TYPE_NAMES - if 'is_spent' not in kwargs: + if not kwargs.get('is_spent', False): kwargs['is_not_spent'] = True return self.jsonrpc_txo_list(**kwargs) @@ -2513,7 +2561,7 @@ class Daemon(metaclass=JSONRPCServerType): name, claim, amount, claim_address, funding_accounts, funding_accounts[0] ) txo = tx.outputs[0] - txo.generate_channel_private_key() + await txo.generate_channel_private_key() await tx.sign(funding_accounts) @@ -2665,7 +2713,7 @@ class Daemon(metaclass=JSONRPCServerType): new_txo = tx.outputs[0] if new_signing_key: - new_txo.generate_channel_private_key() + await new_txo.generate_channel_private_key() else: new_txo.private_key = old_txo.private_key @@ -2872,7 +2920,7 @@ class Daemon(metaclass=JSONRPCServerType): Create, update, abandon, list and inspect your stream claims. 
""" - @requires(WALLET_COMPONENT, STREAM_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) + @requires(WALLET_COMPONENT, FILE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) async def jsonrpc_publish(self, name, **kwargs): """ Create or replace a stream claim at a given name (use 'stream create/update' for more control). @@ -2994,7 +3042,7 @@ class Daemon(metaclass=JSONRPCServerType): f"to update a specific stream claim." ) - @requires(WALLET_COMPONENT, STREAM_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) + @requires(WALLET_COMPONENT, FILE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) async def jsonrpc_stream_repost(self, name, bid, claim_id, allow_duplicate_name=False, channel_id=None, channel_name=None, channel_account_id=None, account_id=None, wallet_id=None, claim_address=None, funding_account_ids=None, preview=False, blocking=False): @@ -3066,7 +3114,7 @@ class Daemon(metaclass=JSONRPCServerType): return tx - @requires(WALLET_COMPONENT, STREAM_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) + @requires(WALLET_COMPONENT, FILE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) async def jsonrpc_stream_create( self, name, bid, file_path, allow_duplicate_name=False, channel_id=None, channel_name=None, channel_account_id=None, @@ -3204,7 +3252,7 @@ class Daemon(metaclass=JSONRPCServerType): file_stream = None if not preview: - file_stream = await self.stream_manager.create_stream(file_path) + file_stream = await self.file_manager.create_stream(file_path) claim.stream.source.sd_hash = file_stream.sd_hash new_txo.script.generate() @@ -3224,7 +3272,7 @@ class Daemon(metaclass=JSONRPCServerType): return tx - @requires(WALLET_COMPONENT, STREAM_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) + @requires(WALLET_COMPONENT, FILE_MANAGER_COMPONENT, BLOB_COMPONENT, DATABASE_COMPONENT) async def jsonrpc_stream_update( self, claim_id, bid=None, file_path=None, channel_id=None, channel_name=None, channel_account_id=None, clear_channel=False, @@ -3414,11 +3462,12 @@ class Daemon(metaclass=JSONRPCServerType): stream_hash = None if not preview: - old_stream = self.stream_manager.streams.get(old_txo.claim.stream.source.sd_hash, None) + old_stream = self.file_manager.get_filtered(sd_hash=old_txo.claim.stream.source.sd_hash) + old_stream = old_stream[0] if old_stream else None if file_path is not None: if old_stream: - await self.stream_manager.delete_stream(old_stream, delete_file=False) - file_stream = await self.stream_manager.create_stream(file_path) + await self.file_manager.delete(old_stream, delete_file=False) + file_stream = await self.file_manager.create_stream(file_path) new_txo.claim.stream.source.sd_hash = file_stream.sd_hash new_txo.script.generate() stream_hash = file_stream.stream_hash @@ -3580,7 +3629,6 @@ class Daemon(metaclass=JSONRPCServerType): given name. default: false. 
--title= : (str) title of the collection --description=<description> : (str) description of the collection - --clear_languages : (bool) clear existing languages (prior to adding new ones) --tags=<tags> : (list) content tags --clear_languages : (bool) clear existing languages (prior to adding new ones) --languages=<languages> : (list) languages used by the collection, @@ -4550,9 +4598,9 @@ class Daemon(metaclass=JSONRPCServerType): """ if not blob_hash or not is_valid_blobhash(blob_hash): return f"Invalid blob hash to delete '{blob_hash}'" - streams = self.stream_manager.get_filtered_streams(sd_hash=blob_hash) + streams = self.file_manager.get_filtered(sd_hash=blob_hash) if streams: - await self.stream_manager.delete_stream(streams[0]) + await self.file_manager.delete(streams[0]) else: await self.blob_manager.delete_blobs([blob_hash]) return "Deleted %s" % blob_hash @@ -4725,7 +4773,7 @@ class Daemon(metaclass=JSONRPCServerType): raise NotImplementedError() - @requires(STREAM_MANAGER_COMPONENT) + @requires(FILE_MANAGER_COMPONENT) async def jsonrpc_file_reflect(self, **kwargs): """ Reflect all the blobs in a file matching the filter criteria @@ -4754,8 +4802,8 @@ class Daemon(metaclass=JSONRPCServerType): else: server, port = random.choice(self.conf.reflector_servers) reflected = await asyncio.gather(*[ - self.stream_manager.reflect_stream(stream, server, port) - for stream in self.stream_manager.get_filtered_streams(**kwargs) + self.file_manager['stream'].reflect_stream(stream, server, port) + for stream in self.file_manager.get_filtered_streams(**kwargs) ]) total = [] for reflected_for_stream in reflected: @@ -5301,10 +5349,10 @@ class Daemon(metaclass=JSONRPCServerType): results = await self.ledger.resolve(accounts, urls, **kwargs) if self.conf.save_resolved_claims and results: try: - claims = self.stream_manager._convert_to_old_resolve_output(self.wallet_manager, results) - await self.storage.save_claims_for_resolve([ - value for value in claims.values() if 'error' not in value - ]) + await self.storage.save_claim_from_output( + self.ledger, + *(result for result in results.values() if isinstance(result, Output)) + ) except DecodeError: pass return results diff --git a/lbry/extras/daemon/json_response_encoder.py b/lbry/extras/daemon/json_response_encoder.py index a4246bfab..99d487cd2 100644 --- a/lbry/extras/daemon/json_response_encoder.py +++ b/lbry/extras/daemon/json_response_encoder.py @@ -7,6 +7,7 @@ from json import JSONEncoder from google.protobuf.message import DecodeError from lbry.schema.claim import Claim +from lbry.torrent.torrent_manager import TorrentSource from lbry.wallet import Wallet, Ledger, Account, Transaction, Output from lbry.wallet.bip32 import PubKey from lbry.wallet.dewies import dewies_to_lbc @@ -109,7 +110,8 @@ def encode_file_doc(): 'channel_claim_id': '(str) None if claim is not found or not signed', 'channel_name': '(str) None if claim is not found or not signed', 'claim_name': '(str) None if claim is not found else the claim name', - 'reflector_progress': '(int) reflector upload progress, 0 to 100' + 'reflector_progress': '(int) reflector upload progress, 0 to 100', + 'uploading_to_reflector': '(bool) set to True when currently uploading to reflector' } @@ -125,7 +127,7 @@ class JSONResponseEncoder(JSONEncoder): return self.encode_account(obj) if isinstance(obj, Wallet): return self.encode_wallet(obj) - if isinstance(obj, ManagedStream): + if isinstance(obj, (ManagedStream, TorrentSource)): return self.encode_file(obj) if isinstance(obj, Transaction): 
return self.encode_transaction(obj) @@ -272,26 +274,32 @@ class JSONResponseEncoder(JSONEncoder): output_exists = managed_stream.output_file_exists tx_height = managed_stream.stream_claim_info.height best_height = self.ledger.headers.height - return { - 'streaming_url': managed_stream.stream_url, + is_stream = hasattr(managed_stream, 'stream_hash') + if is_stream: + total_bytes_lower_bound = managed_stream.descriptor.lower_bound_decrypted_length() + total_bytes = managed_stream.descriptor.upper_bound_decrypted_length() + else: + total_bytes_lower_bound = total_bytes = managed_stream.torrent_length + result = { + 'streaming_url': None, 'completed': managed_stream.completed, - 'file_name': managed_stream.file_name if output_exists else None, - 'download_directory': managed_stream.download_directory if output_exists else None, - 'download_path': managed_stream.full_path if output_exists else None, + 'file_name': None, + 'download_directory': None, + 'download_path': None, 'points_paid': 0.0, 'stopped': not managed_stream.running, - 'stream_hash': managed_stream.stream_hash, - 'stream_name': managed_stream.descriptor.stream_name, - 'suggested_file_name': managed_stream.descriptor.suggested_file_name, - 'sd_hash': managed_stream.descriptor.sd_hash, - 'mime_type': managed_stream.mime_type, - 'key': managed_stream.descriptor.key, - 'total_bytes_lower_bound': managed_stream.descriptor.lower_bound_decrypted_length(), - 'total_bytes': managed_stream.descriptor.upper_bound_decrypted_length(), + 'stream_hash': None, + 'stream_name': None, + 'suggested_file_name': None, + 'sd_hash': None, + 'mime_type': None, + 'key': None, + 'total_bytes_lower_bound': total_bytes_lower_bound, + 'total_bytes': total_bytes, 'written_bytes': managed_stream.written_bytes, - 'blobs_completed': managed_stream.blobs_completed, - 'blobs_in_stream': managed_stream.blobs_in_stream, - 'blobs_remaining': managed_stream.blobs_remaining, + 'blobs_completed': None, + 'blobs_in_stream': None, + 'blobs_remaining': None, 'status': managed_stream.status, 'claim_id': managed_stream.claim_id, 'txid': managed_stream.txid, @@ -308,9 +316,37 @@ class JSONResponseEncoder(JSONEncoder): 'height': tx_height, 'confirmations': (best_height + 1) - tx_height if tx_height > 0 else tx_height, 'timestamp': self.ledger.headers.estimated_timestamp(tx_height), - 'is_fully_reflected': managed_stream.is_fully_reflected, - 'reflector_progress': managed_stream.reflector_progress + 'is_fully_reflected': False, + 'reflector_progress': False, + 'uploading_to_reflector': False } + if is_stream: + result.update({ + 'streaming_url': managed_stream.stream_url, + 'stream_hash': managed_stream.stream_hash, + 'stream_name': managed_stream.descriptor.stream_name, + 'suggested_file_name': managed_stream.descriptor.suggested_file_name, + 'sd_hash': managed_stream.descriptor.sd_hash, + 'mime_type': managed_stream.mime_type, + 'key': managed_stream.descriptor.key, + 'blobs_completed': managed_stream.blobs_completed, + 'blobs_in_stream': managed_stream.blobs_in_stream, + 'blobs_remaining': managed_stream.blobs_remaining, + 'is_fully_reflected': managed_stream.is_fully_reflected, + 'reflector_progress': managed_stream.reflector_progress, + 'uploading_to_reflector': managed_stream.uploading_to_reflector + }) + else: + result.update({ + 'streaming_url': f'file://{managed_stream.full_path}', + }) + if output_exists: + result.update({ + 'file_name': managed_stream.file_name, + 'download_directory': managed_stream.download_directory, + 'download_path': managed_stream.full_path, + 
}) + return result def encode_claim(self, claim): encoded = getattr(claim, claim.claim_type).to_dict() diff --git a/lbry/extras/daemon/storage.py b/lbry/extras/daemon/storage.py index 11a61e45e..1387f94a7 100644 --- a/lbry/extras/daemon/storage.py +++ b/lbry/extras/daemon/storage.py @@ -9,7 +9,7 @@ from typing import Optional from lbry.wallet import SQLiteMixin from lbry.conf import Config from lbry.wallet.dewies import dewies_to_lbc, lbc_to_dewies -from lbry.wallet.transaction import Transaction +from lbry.wallet.transaction import Transaction, Output from lbry.schema.claim import Claim from lbry.dht.constants import DATA_EXPIRATION from lbry.blob.blob_info import BlobInfo @@ -727,6 +727,19 @@ class SQLiteStorage(SQLiteMixin): if claim_id_to_supports: await self.save_supports(claim_id_to_supports) + def save_claim_from_output(self, ledger, *outputs: Output): + return self.save_claims([{ + "claim_id": output.claim_id, + "name": output.claim_name, + "amount": dewies_to_lbc(output.amount), + "address": output.get_address(ledger), + "txid": output.tx_ref.id, + "nout": output.position, + "value": output.claim, + "height": output.tx_ref.height, + "claim_sequence": -1, + } for output in outputs]) + def save_claims_for_resolve(self, claim_infos): to_save = {} for info in claim_infos: @@ -740,7 +753,8 @@ class SQLiteStorage(SQLiteMixin): return self.save_claims(to_save.values()) @staticmethod - def _save_content_claim(transaction, claim_outpoint, stream_hash): + def _save_content_claim(transaction, claim_outpoint, stream_hash=None, bt_infohash=None): + assert stream_hash or bt_infohash # get the claim id and serialized metadata claim_info = transaction.execute( "select claim_id, serialized_metadata from claim where claim_outpoint=?", (claim_outpoint,) @@ -788,6 +802,19 @@ class SQLiteStorage(SQLiteMixin): if stream_hash in self.content_claim_callbacks: await self.content_claim_callbacks[stream_hash]() + async def save_torrent_content_claim(self, bt_infohash, claim_outpoint, length, name): + def _save_torrent(transaction): + transaction.execute( + "insert or replace into torrent values (?, NULL, ?, ?)", (bt_infohash, length, name) + ).fetchall() + transaction.execute( + "insert or replace into content_claim values (NULL, ?, ?)", (bt_infohash, claim_outpoint) + ).fetchall() + await self.db.run(_save_torrent) + # update corresponding ManagedEncryptedFileDownloader object + if bt_infohash in self.content_claim_callbacks: + await self.content_claim_callbacks[bt_infohash]() + async def get_content_claim(self, stream_hash: str, include_supports: typing.Optional[bool] = True) -> typing.Dict: claims = await self.db.run(get_claims_from_stream_hashes, [stream_hash]) claim = None @@ -799,6 +826,10 @@ class SQLiteStorage(SQLiteMixin): claim['effective_amount'] = calculate_effective_amount(claim['amount'], supports) return claim + async def get_content_claim_for_torrent(self, bt_infohash): + claims = await self.db.run(get_claims_from_torrent_info_hashes, [bt_infohash]) + return claims[bt_infohash].as_dict() if claims else None + # # # # # # # # # reflector functions # # # # # # # # # def update_reflected_stream(self, sd_hash, reflector_address, success=True): diff --git a/lbry/file/__init__.py b/lbry/file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lbry/file/file_manager.py b/lbry/file/file_manager.py new file mode 100644 index 000000000..906362858 --- /dev/null +++ b/lbry/file/file_manager.py @@ -0,0 +1,290 @@ +import asyncio +import logging +import typing +from typing import 
Optional +from aiohttp.web import Request +from lbry.error import ResolveError, DownloadSDTimeoutError, InsufficientFundsError +from lbry.error import ResolveTimeoutError, DownloadDataTimeoutError, KeyFeeAboveMaxAllowedError +from lbry.stream.managed_stream import ManagedStream +from lbry.torrent.torrent_manager import TorrentSource +from lbry.utils import cache_concurrent +from lbry.schema.url import URL +from lbry.wallet.dewies import dewies_to_lbc +from lbry.file.source_manager import SourceManager +from lbry.file.source import ManagedDownloadSource +if typing.TYPE_CHECKING: + from lbry.conf import Config + from lbry.extras.daemon.analytics import AnalyticsManager + from lbry.extras.daemon.storage import SQLiteStorage + from lbry.wallet import WalletManager, Output + from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager + +log = logging.getLogger(__name__) + + +class FileManager: + def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', wallet_manager: 'WalletManager', + storage: 'SQLiteStorage', analytics_manager: Optional['AnalyticsManager'] = None): + self.loop = loop + self.config = config + self.wallet_manager = wallet_manager + self.storage = storage + self.analytics_manager = analytics_manager + self.source_managers: typing.Dict[str, SourceManager] = {} + self.started = asyncio.Event() + + @property + def streams(self): + return self.source_managers['stream']._sources + + async def create_stream(self, file_path: str, key: Optional[bytes] = None, **kwargs) -> ManagedDownloadSource: + if 'stream' in self.source_managers: + return await self.source_managers['stream'].create(file_path, key, **kwargs) + raise NotImplementedError + + async def start(self): + await asyncio.gather(*(source_manager.start() for source_manager in self.source_managers.values())) + for manager in self.source_managers.values(): + await manager.started.wait() + self.started.set() + + def stop(self): + for manager in self.source_managers.values(): + # fixme: pop or not? 
+ manager.stop() + self.started.clear() + + @cache_concurrent + async def download_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', + timeout: Optional[float] = None, file_name: Optional[str] = None, + download_directory: Optional[str] = None, + save_file: Optional[bool] = None, resolve_timeout: float = 3.0, + wallet: Optional['Wallet'] = None) -> ManagedDownloadSource: + + wallet = wallet or self.wallet_manager.default_wallet + timeout = timeout or self.config.download_timeout + start_time = self.loop.time() + resolved_time = None + stream = None + claim = None + error = None + outpoint = None + if save_file is None: + save_file = self.config.save_files + if file_name and not save_file: + save_file = True + if save_file: + download_directory = download_directory or self.config.download_dir + else: + download_directory = None + + payment = None + try: + # resolve the claim + if not URL.parse(uri).has_stream: + raise ResolveError("cannot download a channel claim, specify a /path") + try: + resolved_result = await asyncio.wait_for( + self.wallet_manager.ledger.resolve(wallet.accounts, [uri], include_purchase_receipt=True), + resolve_timeout + ) + except asyncio.TimeoutError: + raise ResolveTimeoutError(uri) + except Exception as err: + if isinstance(err, asyncio.CancelledError): + raise + log.exception("Unexpected error resolving stream:") + raise ResolveError(f"Unexpected error resolving stream: {str(err)}") + if 'error' in resolved_result: + raise ResolveError(f"Unexpected error resolving uri for download: {resolved_result['error']}") + if not resolved_result or uri not in resolved_result: + raise ResolveError(f"Failed to resolve stream at '{uri}'") + txo = resolved_result[uri] + if isinstance(txo, dict): + raise ResolveError(f"Failed to resolve stream at '{uri}': {txo}") + claim = txo.claim + outpoint = f"{txo.tx_ref.id}:{txo.position}" + resolved_time = self.loop.time() - start_time + await self.storage.save_claim_from_output(self.wallet_manager.ledger, txo) + + #################### + # update or replace + #################### + + if claim.stream.source.bt_infohash: + source_manager = self.source_managers['torrent'] + existing = source_manager.get_filtered(bt_infohash=claim.stream.source.bt_infohash) + else: + source_manager = self.source_managers['stream'] + existing = source_manager.get_filtered(sd_hash=claim.stream.source.sd_hash) + + # resume or update an existing stream, if the stream changed: download it and delete the old one after + to_replace, updated_stream = None, None + if existing and existing[0].claim_id != txo.claim_id: + raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {txo.claim_id}") + if existing: + log.info("claim contains a metadata only update to a stream we have") + if claim.stream.source.bt_infohash: + await self.storage.save_torrent_content_claim( + existing[0].identifier, outpoint, existing[0].torrent_length, existing[0].torrent_name + ) + claim_info = await self.storage.get_content_claim_for_torrent(existing[0].identifier) + existing[0].set_claim(claim_info, claim) + else: + await self.storage.save_content_claim( + existing[0].stream_hash, outpoint + ) + await source_manager._update_content_claim(existing[0]) + updated_stream = existing[0] + else: + existing_for_claim_id = self.get_filtered(claim_id=txo.claim_id) + if existing_for_claim_id: + log.info("claim contains an update to a stream we have, downloading it") + if save_file and existing_for_claim_id[0].output_file_exists: + save_file = False + if not 
claim.stream.source.bt_infohash: + existing_for_claim_id[0].downloader.node = source_manager.node + await existing_for_claim_id[0].start(timeout=timeout, save_now=save_file) + if not existing_for_claim_id[0].output_file_exists and ( + save_file or file_name or download_directory): + await existing_for_claim_id[0].save_file( + file_name=file_name, download_directory=download_directory + ) + to_replace = existing_for_claim_id[0] + + # resume or update an existing stream, if the stream changed: download it and delete the old one after + if updated_stream: + log.info("already have stream for %s", uri) + if save_file and updated_stream.output_file_exists: + save_file = False + if not claim.stream.source.bt_infohash: + updated_stream.downloader.node = source_manager.node + await updated_stream.start(timeout=timeout, save_now=save_file) + if not updated_stream.output_file_exists and (save_file or file_name or download_directory): + await updated_stream.save_file( + file_name=file_name, download_directory=download_directory + ) + return updated_stream + + #################### + # pay fee + #################### + + if not to_replace and txo.has_price and not txo.purchase_receipt: + payment = await self.wallet_manager.create_purchase_transaction( + wallet.accounts, txo, exchange_rate_manager + ) + + #################### + # make downloader and wait for start + #################### + + if not claim.stream.source.bt_infohash: + # fixme: this shouldnt be here + stream = ManagedStream( + self.loop, self.config, source_manager.blob_manager, claim.stream.source.sd_hash, + download_directory, file_name, ManagedStream.STATUS_RUNNING, content_fee=payment, + analytics_manager=self.analytics_manager + ) + stream.downloader.node = source_manager.node + else: + stream = TorrentSource( + self.loop, self.config, self.storage, identifier=claim.stream.source.bt_infohash, + file_name=file_name, download_directory=download_directory or self.config.download_dir, + status=ManagedStream.STATUS_RUNNING, + analytics_manager=self.analytics_manager, + torrent_session=source_manager.torrent_session + ) + log.info("starting download for %s", uri) + + before_download = self.loop.time() + await stream.start(timeout, save_file) + + #################### + # success case: delete to_replace if applicable, broadcast fee payment + #################### + + if to_replace: # delete old stream now that the replacement has started downloading + await source_manager.delete(to_replace) + + if payment is not None: + await self.wallet_manager.broadcast_or_release(payment) + payment = None # to avoid releasing in `finally` later + log.info("paid fee of %s for %s", dewies_to_lbc(stream.content_fee.outputs[0].amount), uri) + await self.storage.save_content_fee(stream.stream_hash, stream.content_fee) + + source_manager.add(stream) + + if not claim.stream.source.bt_infohash: + await self.storage.save_content_claim(stream.stream_hash, outpoint) + else: + await self.storage.save_torrent_content_claim( + stream.identifier, outpoint, stream.torrent_length, stream.torrent_name + ) + claim_info = await self.storage.get_content_claim_for_torrent(stream.identifier) + stream.set_claim(claim_info, claim) + if save_file: + await asyncio.wait_for(stream.save_file(), timeout - (self.loop.time() - before_download), + loop=self.loop) + return stream + except asyncio.TimeoutError: + error = DownloadDataTimeoutError(stream.sd_hash) + raise error + except Exception as err: # forgive data timeout, don't delete stream + expected = (DownloadSDTimeoutError, 
DownloadDataTimeoutError, InsufficientFundsError,
+                        KeyFeeAboveMaxAllowedError)
+            if isinstance(err, expected):
+                log.warning("Failed to download %s: %s", uri, str(err))
+            elif isinstance(err, asyncio.CancelledError):
+                pass
+            else:
+                log.exception("Unexpected error downloading stream:")
+            error = err
+            raise
+        finally:
+            if payment is not None:
+                # payment is set to None after broadcasting, if we're here an exception probably happened
+                await self.wallet_manager.ledger.release_tx(payment)
+            if self.analytics_manager and claim and claim.stream.source.bt_infohash:
+                # TODO: analytics for torrents
+                pass
+            elif self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or
+                                                                   stream.downloader.time_to_first_bytes))):
+                server = self.wallet_manager.ledger.network.client.server
+                self.loop.create_task(
+                    self.analytics_manager.send_time_to_first_bytes(
+                        resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id,
+                        uri, outpoint,
+                        None if not stream else len(stream.downloader.blob_downloader.active_connections),
+                        None if not stream else len(stream.downloader.blob_downloader.scores),
+                        None if not stream else len(stream.downloader.blob_downloader.connection_failures),
+                        False if not stream else stream.downloader.added_fixed_peers,
+                        self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay,
+                        None if not stream else stream.sd_hash,
+                        None if not stream else stream.downloader.time_to_descriptor,
+                        None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash,
+                        None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length,
+                        None if not stream else stream.downloader.time_to_first_bytes,
+                        None if not error else error.__class__.__name__,
+                        None if not error else str(error),
+                        None if not server else f"{server[0]}:{server[1]}"
+                    )
+                )
+
+    async def stream_partial_content(self, request: Request, sd_hash: str):
+        return await self.source_managers['stream'].stream_partial_content(request, sd_hash)
+
+    def get_filtered(self, *args, **kwargs) -> typing.List[ManagedDownloadSource]:
+        """
+        Get a list of filtered and sorted ManagedStream objects
+
+        :param sort_by: field to sort by
+        :param reverse: reverse sorting
+        :param comparison: comparison operator used for filtering
+        :param search_by: fields and values to filter by
+        """
+        return sum((manager.get_filtered(*args, **kwargs) for manager in self.source_managers.values()), [])
+
+    async def delete(self, source: ManagedDownloadSource, delete_file=False):
+        for manager in self.source_managers.values():
+            await manager.delete(source, delete_file)
diff --git a/lbry/file/source.py b/lbry/file/source.py
new file mode 100644
index 000000000..b661eb594
--- /dev/null
+++ b/lbry/file/source.py
@@ -0,0 +1,161 @@
+import os
+import asyncio
+import typing
+import logging
+import binascii
+from typing import Optional
+from lbry.utils import generate_id
+from lbry.extras.daemon.storage import StoredContentClaim
+
+if typing.TYPE_CHECKING:
+    from lbry.conf import Config
+    from lbry.extras.daemon.analytics import AnalyticsManager
+    from lbry.wallet.transaction import Transaction
+    from lbry.extras.daemon.storage import SQLiteStorage
+
+log = logging.getLogger(__name__)
+
+
+class ManagedDownloadSource:
+    STATUS_RUNNING = "running"
+    STATUS_STOPPED = "stopped"
+    STATUS_FINISHED = "finished"
+
+    SAVING_ID = 1
+    STREAMING_ID = 2
+
+    def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', storage: 'SQLiteStorage', identifier: str,
+                 file_name: Optional[str] = None, download_directory: Optional[str] = None,
+                 status: Optional[str] = STATUS_STOPPED, claim: Optional[StoredContentClaim] = None,
+                 download_id: Optional[str] = None, rowid: Optional[int] = None,
+                 content_fee: Optional['Transaction'] = None,
+                 analytics_manager: Optional['AnalyticsManager'] = None,
+                 added_on: Optional[int] = None):
+        self.loop = loop
+        self.storage = storage
+        self.config = config
+        self.identifier = identifier
+        self.download_directory = download_directory
+        self._file_name = file_name
+        self._status = status
+        self.stream_claim_info = claim
+        self.download_id = download_id or binascii.hexlify(generate_id()).decode()
+        self.rowid = rowid
+        self.content_fee = content_fee
+        self.purchase_receipt = None
+        self._added_on = added_on
+        self.analytics_manager = analytics_manager
+
+        self.saving = asyncio.Event(loop=self.loop)
+        self.finished_writing = asyncio.Event(loop=self.loop)
+        self.started_writing = asyncio.Event(loop=self.loop)
+        self.finished_write_attempt = asyncio.Event(loop=self.loop)
+
+    # @classmethod
+    # async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', file_path: str,
+    #                  key: Optional[bytes] = None,
+    #                  iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedDownloadSource':
+    #     raise NotImplementedError()
+
+    async def start(self, timeout: Optional[float] = None, save_now: Optional[bool] = False):
+        raise NotImplementedError()
+
+    async def stop(self, finished: bool = False):
+        raise NotImplementedError()
+
+    async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None):
+        raise NotImplementedError()
+
+    def stop_tasks(self):
+        raise NotImplementedError()
+
+    def set_claim(self, claim_info: typing.Dict, claim: 'Claim'):
+        self.stream_claim_info = StoredContentClaim(
+            f"{claim_info['txid']}:{claim_info['nout']}", claim_info['claim_id'],
+            claim_info['name'], claim_info['amount'], claim_info['height'],
+            binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'],
+            claim_info['claim_sequence'], claim_info.get('channel_name')
+        )
+
+    # async def update_content_claim(self, claim_info: Optional[typing.Dict] = None):
+    #     if not claim_info:
+    #         claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash)
+    #     self.set_claim(claim_info, claim_info['value'])
+
+    @property
+    def file_name(self) -> Optional[str]:
+        return self._file_name
+
+    @property
+    def added_on(self) -> Optional[int]:
+        return self._added_on
+
+    @property
+    def status(self) -> str:
+        return self._status
+
+    @property
+    def completed(self):
+        raise NotImplementedError()
+
+    # @property
+    # def stream_url(self):
+    #     return f"http://{self.config.streaming_host}:{self.config.streaming_port}/stream/{self.sd_hash}
+
+    @property
+    def finished(self) -> bool:
+        return self.status == self.STATUS_FINISHED
+
+    @property
+    def running(self) -> bool:
+        return self.status == self.STATUS_RUNNING
+
+    @property
+    def claim_id(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.claim_id
+
+    @property
+    def txid(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.txid
+
+    @property
+    def nout(self) -> Optional[int]:
+        return None if not self.stream_claim_info else self.stream_claim_info.nout
+
+    @property
+    def outpoint(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.outpoint
+
+    @property
+    def claim_height(self) -> Optional[int]:
+        return None if not self.stream_claim_info else self.stream_claim_info.height
+
+    @property
+    def channel_claim_id(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.channel_claim_id
+
+    @property
+    def channel_name(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.channel_name
+
+    @property
+    def claim_name(self) -> Optional[str]:
+        return None if not self.stream_claim_info else self.stream_claim_info.claim_name
+
+    @property
+    def metadata(self) -> Optional[typing.Dict]:
+        return None if not self.stream_claim_info else self.stream_claim_info.claim.stream.to_dict()
+
+    @property
+    def metadata_protobuf(self) -> bytes:
+        if self.stream_claim_info:
+            return binascii.hexlify(self.stream_claim_info.claim.to_bytes())
+
+    @property
+    def full_path(self) -> Optional[str]:
+        return os.path.join(self.download_directory, os.path.basename(self.file_name)) \
+            if self.file_name and self.download_directory else None
+
+    @property
+    def output_file_exists(self):
+        return os.path.isfile(self.full_path) if self.full_path else False
diff --git a/lbry/file/source_manager.py b/lbry/file/source_manager.py
new file mode 100644
index 000000000..b4babc7a9
--- /dev/null
+++ b/lbry/file/source_manager.py
@@ -0,0 +1,138 @@
+import os
+import asyncio
+import logging
+import typing
+from typing import Optional
+from lbry.file.source import ManagedDownloadSource
+if typing.TYPE_CHECKING:
+    from lbry.conf import Config
+    from lbry.extras.daemon.analytics import AnalyticsManager
+    from lbry.extras.daemon.storage import SQLiteStorage
+
+log = logging.getLogger(__name__)
+
+COMPARISON_OPERATORS = {
+    'eq': lambda a, b: a == b,
+    'ne': lambda a, b: a != b,
+    'g': lambda a, b: a > b,
+    'l': lambda a, b: a < b,
+    'ge': lambda a, b: a >= b,
+    'le': lambda a, b: a <= b,
+}
+
+
+class SourceManager:
+    filter_fields = {
+        'rowid',
+        'status',
+        'file_name',
+        'added_on',
+        'download_path',
+        'claim_name',
+        'claim_height',
+        'claim_id',
+        'outpoint',
+        'txid',
+        'nout',
+        'channel_claim_id',
+        'channel_name',
+        'completed'
+    }
+
+    set_filter_fields = {
+        "claim_ids": "claim_id",
+        "channel_claim_ids": "channel_claim_id",
+        "outpoints": "outpoint"
+    }
+
+    source_class = ManagedDownloadSource
+
+    def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', storage: 'SQLiteStorage',
+                 analytics_manager: Optional['AnalyticsManager'] = None):
+        self.loop = loop
+        self.config = config
+        self.storage = storage
+        self.analytics_manager = analytics_manager
+        self._sources: typing.Dict[str, ManagedDownloadSource] = {}
+        self.started = asyncio.Event(loop=self.loop)
+
+    def add(self, source: ManagedDownloadSource):
+        self._sources[source.identifier] = source
+
+    def remove(self, source: ManagedDownloadSource):
+        if source.identifier not in self._sources:
+            return
+        self._sources.pop(source.identifier)
+        source.stop_tasks()
+
+    async def initialize_from_database(self):
+        raise NotImplementedError()
+
+    async def start(self):
+        await self.initialize_from_database()
+        self.started.set()
+
+    def stop(self):
+        while self._sources:
+            _, source = self._sources.popitem()
+            source.stop_tasks()
+        self.started.clear()
+
+    async def create(self, file_path: str, key: Optional[bytes] = None,
+                     iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedDownloadSource:
+        raise NotImplementedError()
+
+    async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False):
+        self.remove(source)
+        if delete_file and source.output_file_exists:
+            os.remove(source.full_path)
+
+    def get_filtered(self, sort_by: Optional[str] = None, reverse: Optional[bool] = False,
+                     comparison: Optional[str] = None, **search_by) -> typing.List[ManagedDownloadSource]:
+        """
+        Get a list of filtered and sorted ManagedStream objects
+
+        :param sort_by: field to sort by
+        :param reverse: reverse sorting
+        :param comparison: comparison operator used for filtering
+        :param search_by: fields and values to filter by
+        """
+        if sort_by and sort_by not in self.filter_fields:
+            raise ValueError(f"'{sort_by}' is not a valid field to sort by")
+        if comparison and comparison not in COMPARISON_OPERATORS:
+            raise ValueError(f"'{comparison}' is not a valid comparison")
+        if 'full_status' in search_by:
+            del search_by['full_status']
+
+        for search in search_by:
+            if search not in self.filter_fields:
+                raise ValueError(f"'{search}' is not a valid search operation")
+
+        compare_sets = {}
+        if isinstance(search_by.get('claim_id'), list):
+            compare_sets['claim_ids'] = search_by.pop('claim_id')
+        if isinstance(search_by.get('outpoint'), list):
+            compare_sets['outpoints'] = search_by.pop('outpoint')
+        if isinstance(search_by.get('channel_claim_id'), list):
+            compare_sets['channel_claim_ids'] = search_by.pop('channel_claim_id')
+
+        if search_by or compare_sets:
+            comparison = comparison or 'eq'
+            streams = []
+            for stream in self._sources.values():
+                if compare_sets and not all(
+                        getattr(stream, self.set_filter_fields[set_search]) in val
+                        for set_search, val in compare_sets.items()):
+                    continue
+                if search_by and not all(
+                        COMPARISON_OPERATORS[comparison](getattr(stream, search), val)
+                        for search, val in search_by.items()):
+                    continue
+                streams.append(stream)
+        else:
+            streams = list(self._sources.values())
+        if sort_by:
+            streams.sort(key=lambda s: getattr(s, sort_by))
+            if reverse:
+                streams.reverse()
+        return streams
diff --git a/lbry/prometheus.py b/lbry/prometheus.py
new file mode 100644
index 000000000..220ee97bd
--- /dev/null
+++ b/lbry/prometheus.py
@@ -0,0 +1,32 @@
+import logging
+from aiohttp import web
+from prometheus_client import generate_latest as prom_generate_latest
+
+
+class PrometheusServer:
+    def __init__(self, logger=None):
+        self.runner = None
+        self.logger = logger or logging.getLogger(__name__)
+
+    async def start(self, interface: str, port: int):
+        prom_app = web.Application()
+        prom_app.router.add_get('/metrics', self.handle_metrics_get_request)
+        self.runner = web.AppRunner(prom_app)
+        await self.runner.setup()
+
+        metrics_site = web.TCPSite(self.runner, interface, port, shutdown_timeout=.5)
+        await metrics_site.start()
+        self.logger.info('metrics server listening on %s:%i', *metrics_site._server.sockets[0].getsockname()[:2])
+
+    async def handle_metrics_get_request(self, request: web.Request):
+        try:
+            return web.Response(
+                text=prom_generate_latest().decode(),
+                content_type='text/plain; version=0.0.4'
+            )
+        except Exception:
+            self.logger.exception('could not generate prometheus data')
+            raise
+
+    async def stop(self):
+        await self.runner.cleanup()
diff --git a/lbry/stream/descriptor.py b/lbry/stream/descriptor.py
index 1c0305ddc..2e2626b9b 100644
--- a/lbry/stream/descriptor.py
+++ b/lbry/stream/descriptor.py
@@ -44,18 +44,25 @@ def random_iv_generator() -> typing.Generator[bytes, None, None]:
     yield os.urandom(AES.block_size // 8)
-def file_reader(file_path: str):
+def read_bytes(file_path: str, offset: int, to_read: int):
+    with open(file_path, 'rb') as f:
+        f.seek(offset)
+        return
f.read(to_read) + + +async def file_reader(file_path: str): length = int(os.stat(file_path).st_size) offset = 0 - with open(file_path, 'rb') as stream_file: - while offset < length: - bytes_to_read = min((length - offset), MAX_BLOB_SIZE - 1) - if not bytes_to_read: - break - blob_bytes = stream_file.read(bytes_to_read) - yield blob_bytes - offset += bytes_to_read + while offset < length: + bytes_to_read = min((length - offset), MAX_BLOB_SIZE - 1) + if not bytes_to_read: + break + blob_bytes = await asyncio.get_event_loop().run_in_executor( + None, read_bytes, file_path, offset, bytes_to_read + ) + yield blob_bytes + offset += bytes_to_read def sanitize_file_name(dirty_name: str, default_file_name: str = 'lbry_download'): @@ -245,7 +252,7 @@ class StreamDescriptor: iv_generator = iv_generator or random_iv_generator() key = key or os.urandom(AES.block_size // 8) blob_num = -1 - for blob_bytes in file_reader(file_path): + async for blob_bytes in file_reader(file_path): blob_num += 1 blob_info = await BlobFile.create_from_unencrypted( loop, blob_dir, key, next(iv_generator), blob_bytes, blob_num, blob_completed_callback diff --git a/lbry/stream/downloader.py b/lbry/stream/downloader.py index 9fe98ac54..94537e034 100644 --- a/lbry/stream/downloader.py +++ b/lbry/stream/downloader.py @@ -51,15 +51,15 @@ class StreamDownloader: def _delayed_add_fixed_peers(): self.added_fixed_peers = True self.peer_queue.put_nowait([ - make_kademlia_peer(None, address, None, tcp_port=port + 1, allow_localhost=True) + make_kademlia_peer(None, address, None, tcp_port=port, allow_localhost=True) for address, port in addresses ]) - if not self.config.reflector_servers: + if not self.config.fixed_peers: return addresses = [ - (await resolve_host(url, port + 1, proto='tcp'), port) - for url, port in self.config.reflector_servers + (await resolve_host(url, port, proto='tcp'), port) + for url, port in self.config.fixed_peers ] if 'dht' in self.config.components_to_skip or not self.node or not \ len(self.node.protocol.routing_table.get_peers()) > 0: @@ -92,8 +92,8 @@ class StreamDownloader: async def start(self, node: typing.Optional['Node'] = None, connection_id: int = 0): # set up peer accumulation - if node: - self.node = node + self.node = node or self.node # fixme: this shouldnt be set here! 
+ if self.node: if self.accumulate_task and not self.accumulate_task.done(): self.accumulate_task.cancel() _, self.accumulate_task = self.node.accumulate_peers(self.search_queue, self.peer_queue) diff --git a/lbry/stream/managed_stream.py b/lbry/stream/managed_stream.py index 5071d879e..da625c381 100644 --- a/lbry/stream/managed_stream.py +++ b/lbry/stream/managed_stream.py @@ -3,9 +3,8 @@ import asyncio import time import typing import logging -import binascii +from typing import Optional from aiohttp.web import Request, StreamResponse, HTTPRequestRangeNotSatisfiable -from lbry.utils import generate_id from lbry.error import DownloadSDTimeoutError from lbry.schema.mime_types import guess_media_type from lbry.stream.downloader import StreamDownloader @@ -13,6 +12,7 @@ from lbry.stream.descriptor import StreamDescriptor, sanitize_file_name from lbry.stream.reflector.client import StreamReflectorClient from lbry.extras.daemon.storage import StoredContentClaim from lbry.blob import MAX_BLOB_SIZE +from lbry.file.source import ManagedDownloadSource if typing.TYPE_CHECKING: from lbry.conf import Config @@ -40,78 +40,35 @@ async def get_next_available_file_name(loop: asyncio.AbstractEventLoop, download return await loop.run_in_executor(None, _get_next_available_file_name, download_directory, file_name) -class ManagedStream: - STATUS_RUNNING = "running" - STATUS_STOPPED = "stopped" - STATUS_FINISHED = "finished" - - SAVING_ID = 1 - STREAMING_ID = 2 - - __slots__ = [ - 'loop', - 'config', - 'blob_manager', - 'sd_hash', - 'download_directory', - '_file_name', - '_added_on', - '_status', - 'stream_claim_info', - 'download_id', - 'rowid', - 'content_fee', - 'purchase_receipt', - 'downloader', - 'analytics_manager', - 'fully_reflected', - 'reflector_progress', - 'file_output_task', - 'delayed_stop_task', - 'streaming_responses', - 'streaming', - '_running', - 'saving', - 'finished_writing', - 'started_writing', - 'finished_write_attempt' - ] - +class ManagedStream(ManagedDownloadSource): def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', - sd_hash: str, download_directory: typing.Optional[str] = None, file_name: typing.Optional[str] = None, - status: typing.Optional[str] = STATUS_STOPPED, claim: typing.Optional[StoredContentClaim] = None, - download_id: typing.Optional[str] = None, rowid: typing.Optional[int] = None, - descriptor: typing.Optional[StreamDescriptor] = None, - content_fee: typing.Optional['Transaction'] = None, - analytics_manager: typing.Optional['AnalyticsManager'] = None, - added_on: typing.Optional[int] = None): - self.loop = loop - self.config = config + sd_hash: str, download_directory: Optional[str] = None, file_name: Optional[str] = None, + status: Optional[str] = ManagedDownloadSource.STATUS_STOPPED, + claim: Optional[StoredContentClaim] = None, + download_id: Optional[str] = None, rowid: Optional[int] = None, + descriptor: Optional[StreamDescriptor] = None, + content_fee: Optional['Transaction'] = None, + analytics_manager: Optional['AnalyticsManager'] = None, + added_on: Optional[int] = None): + super().__init__(loop, config, blob_manager.storage, sd_hash, file_name, download_directory, status, claim, + download_id, rowid, content_fee, analytics_manager, added_on) self.blob_manager = blob_manager - self.sd_hash = sd_hash - self.download_directory = download_directory - self._file_name = file_name - self._status = status - self.stream_claim_info = claim - self.download_id = download_id or binascii.hexlify(generate_id()).decode() - 
self.rowid = rowid - self.content_fee = content_fee self.purchase_receipt = None - self._added_on = added_on self.downloader = StreamDownloader(self.loop, self.config, self.blob_manager, sd_hash, descriptor) self.analytics_manager = analytics_manager - self.fully_reflected = asyncio.Event(loop=self.loop) self.reflector_progress = 0 + self.uploading_to_reflector = False self.file_output_task: typing.Optional[asyncio.Task] = None self.delayed_stop_task: typing.Optional[asyncio.Task] = None self.streaming_responses: typing.List[typing.Tuple[Request, StreamResponse]] = [] + self.fully_reflected = asyncio.Event(loop=self.loop) self.streaming = asyncio.Event(loop=self.loop) self._running = asyncio.Event(loop=self.loop) - self.saving = asyncio.Event(loop=self.loop) - self.finished_writing = asyncio.Event(loop=self.loop) - self.started_writing = asyncio.Event(loop=self.loop) - self.finished_write_attempt = asyncio.Event(loop=self.loop) + + @property + def sd_hash(self) -> str: + return self.identifier @property def is_fully_reflected(self) -> bool: @@ -126,17 +83,9 @@ class ManagedStream: return self.descriptor.stream_hash @property - def file_name(self) -> typing.Optional[str]: + def file_name(self) -> Optional[str]: return self._file_name or (self.descriptor.suggested_file_name if self.descriptor else None) - @property - def added_on(self) -> typing.Optional[int]: - return self._added_on - - @property - def status(self) -> str: - return self._status - @property def written_bytes(self) -> int: return 0 if not self.output_file_exists else os.stat(self.full_path).st_size @@ -154,55 +103,6 @@ class ManagedStream: self._status = status await self.blob_manager.storage.change_file_status(self.stream_hash, status) - @property - def finished(self) -> bool: - return self.status == self.STATUS_FINISHED - - @property - def running(self) -> bool: - return self.status == self.STATUS_RUNNING - - @property - def claim_id(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.claim_id - - @property - def txid(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.txid - - @property - def nout(self) -> typing.Optional[int]: - return None if not self.stream_claim_info else self.stream_claim_info.nout - - @property - def outpoint(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.outpoint - - @property - def claim_height(self) -> typing.Optional[int]: - return None if not self.stream_claim_info else self.stream_claim_info.height - - @property - def channel_claim_id(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.channel_claim_id - - @property - def channel_name(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.channel_name - - @property - def claim_name(self) -> typing.Optional[str]: - return None if not self.stream_claim_info else self.stream_claim_info.claim_name - - @property - def metadata(self) -> typing.Optional[typing.Dict]: - return None if not self.stream_claim_info else self.stream_claim_info.claim.stream.to_dict() - - @property - def metadata_protobuf(self) -> bytes: - if self.stream_claim_info: - return binascii.hexlify(self.stream_claim_info.claim.to_bytes()) - @property def blobs_completed(self) -> int: return sum([1 if b.blob_hash in self.blob_manager.completed_blob_hashes else 0 @@ -216,39 +116,34 @@ class ManagedStream: def blobs_remaining(self) 
-> int: return self.blobs_in_stream - self.blobs_completed - @property - def full_path(self) -> typing.Optional[str]: - return os.path.join(self.download_directory, os.path.basename(self.file_name)) \ - if self.file_name and self.download_directory else None - - @property - def output_file_exists(self): - return os.path.isfile(self.full_path) if self.full_path else False - @property def mime_type(self): return guess_media_type(os.path.basename(self.descriptor.suggested_file_name))[0] - @classmethod - async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', - file_path: str, key: typing.Optional[bytes] = None, - iv_generator: typing.Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedStream': - """ - Generate a stream from a file and save it to the db - """ - descriptor = await StreamDescriptor.create_stream( - loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator, - blob_completed_callback=blob_manager.blob_completed - ) - await blob_manager.storage.store_stream( - blob_manager.get_blob(descriptor.sd_hash), descriptor - ) - row_id = await blob_manager.storage.save_published_file(descriptor.stream_hash, os.path.basename(file_path), - os.path.dirname(file_path), 0) - return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), - os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) + @property + def download_path(self): + return f"{self.download_directory}/{self._file_name}" if self.download_directory and self._file_name else None - async def start(self, node: typing.Optional['Node'] = None, timeout: typing.Optional[float] = None, + # @classmethod + # async def create(cls, loop: asyncio.AbstractEventLoop, config: 'Config', + # file_path: str, key: Optional[bytes] = None, + # iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> 'ManagedDownloadSource': + # """ + # Generate a stream from a file and save it to the db + # """ + # descriptor = await StreamDescriptor.create_stream( + # loop, blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator, + # blob_completed_callback=blob_manager.blob_completed + # ) + # await blob_manager.storage.store_stream( + # blob_manager.get_blob(descriptor.sd_hash), descriptor + # ) + # row_id = await blob_manager.storage.save_published_file(descriptor.stream_hash, os.path.basename(file_path), + # os.path.dirname(file_path), 0) + # return cls(loop, config, blob_manager, descriptor.sd_hash, os.path.dirname(file_path), + # os.path.basename(file_path), status=cls.STATUS_FINISHED, rowid=row_id, descriptor=descriptor) + + async def start(self, timeout: Optional[float] = None, save_now: bool = False): timeout = timeout or self.config.download_timeout if self._running.is_set(): @@ -256,7 +151,7 @@ class ManagedStream: log.info("start downloader for stream (sd hash: %s)", self.sd_hash) self._running.set() try: - await asyncio.wait_for(self.downloader.start(node), timeout, loop=self.loop) + await asyncio.wait_for(self.downloader.start(), timeout, loop=self.loop) except asyncio.TimeoutError: self._running.clear() raise DownloadSDTimeoutError(self.sd_hash) @@ -266,6 +161,11 @@ class ManagedStream: self.delayed_stop_task = self.loop.create_task(self._delayed_stop()) if not await self.blob_manager.storage.file_exists(self.sd_hash): if save_now: + if not self._file_name: + self._file_name = await get_next_available_file_name( + self.loop, self.download_directory, + self._file_name or 
sanitize_file_name(self.descriptor.suggested_file_name) + ) file_name, download_dir = self._file_name, self.download_directory else: file_name, download_dir = None, None @@ -285,7 +185,7 @@ class ManagedStream: if (finished and self.status != self.STATUS_FINISHED) or self.status == self.STATUS_RUNNING: await self.update_status(self.STATUS_FINISHED if finished else self.STATUS_STOPPED) - async def _aiter_read_stream(self, start_blob_num: typing.Optional[int] = 0, connection_id: int = 0)\ + async def _aiter_read_stream(self, start_blob_num: Optional[int] = 0, connection_id: int = 0)\ -> typing.AsyncIterator[typing.Tuple['BlobInfo', bytes]]: if start_blob_num >= len(self.descriptor.blobs[:-1]): raise IndexError(start_blob_num) @@ -297,13 +197,13 @@ class ManagedStream: decrypted = await self.downloader.read_blob(blob_info, connection_id) yield (blob_info, decrypted) - async def stream_file(self, request: Request, node: typing.Optional['Node'] = None) -> StreamResponse: + async def stream_file(self, request: Request) -> StreamResponse: log.info("stream file to browser for lbry://%s#%s (sd hash %s...)", self.claim_name, self.claim_id, self.sd_hash[:6]) headers, size, skip_blobs, first_blob_start_offset = self._prepare_range_response_headers( request.headers.get('range', 'bytes=0-') ) - await self.start(node) + await self.start() response = StreamResponse( status=206, headers=headers @@ -341,9 +241,10 @@ class ManagedStream: self.streaming.clear() @staticmethod - def _write_decrypted_blob(handle: typing.IO, data: bytes): - handle.write(data) - handle.flush() + def _write_decrypted_blob(output_path: str, data: bytes): + with open(output_path, 'ab') as handle: + handle.write(data) + handle.flush() async def _save_file(self, output_path: str): log.info("save file for lbry://%s#%s (sd hash %s...) 
-> %s", self.claim_name, self.claim_id, self.sd_hash[:6], @@ -353,12 +254,12 @@ class ManagedStream: self.finished_writing.clear() self.started_writing.clear() try: - with open(output_path, 'wb') as file_write_handle: - async for blob_info, decrypted in self._aiter_read_stream(connection_id=self.SAVING_ID): - log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) - await self.loop.run_in_executor(None, self._write_decrypted_blob, file_write_handle, decrypted) - if not self.started_writing.is_set(): - self.started_writing.set() + open(output_path, 'wb').close() + async for blob_info, decrypted in self._aiter_read_stream(connection_id=self.SAVING_ID): + log.info("write blob %i/%i", blob_info.blob_num + 1, len(self.descriptor.blobs) - 1) + await self.loop.run_in_executor(None, self._write_decrypted_blob, output_path, decrypted) + if not self.started_writing.is_set(): + self.started_writing.set() await self.update_status(ManagedStream.STATUS_FINISHED) if self.analytics_manager: self.loop.create_task(self.analytics_manager.send_download_finished( @@ -388,9 +289,8 @@ class ManagedStream: self.saving.clear() self.finished_write_attempt.set() - async def save_file(self, file_name: typing.Optional[str] = None, download_directory: typing.Optional[str] = None, - node: typing.Optional['Node'] = None): - await self.start(node) + async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None): + await self.start() if self.file_output_task and not self.file_output_task.done(): # cancel an already running save task self.file_output_task.cancel() self.download_directory = download_directory or self.download_directory or self.config.download_dir @@ -432,6 +332,7 @@ class ManagedStream: sent = [] protocol = StreamReflectorClient(self.blob_manager, self.descriptor) try: + self.uploading_to_reflector = True await self.loop.create_connection(lambda: protocol, host, port) await protocol.send_handshake() sent_sd, needed = await protocol.send_descriptor() @@ -458,20 +359,13 @@ class ManagedStream: finally: if protocol.transport: protocol.transport.close() + self.uploading_to_reflector = False if not self.fully_reflected.is_set(): self.fully_reflected.set() await self.blob_manager.storage.update_reflected_stream(self.sd_hash, f"{host}:{port}") return sent - def set_claim(self, claim_info: typing.Dict, claim: 'Claim'): - self.stream_claim_info = StoredContentClaim( - f"{claim_info['txid']}:{claim_info['nout']}", claim_info['claim_id'], - claim_info['name'], claim_info['amount'], claim_info['height'], - binascii.hexlify(claim.to_bytes()).decode(), claim.signing_channel_id, claim_info['address'], - claim_info['claim_sequence'], claim_info.get('channel_name') - ) - - async def update_content_claim(self, claim_info: typing.Optional[typing.Dict] = None): + async def update_content_claim(self, claim_info: Optional[typing.Dict] = None): if not claim_info: claim_info = await self.blob_manager.storage.get_content_claim(self.stream_hash) self.set_claim(claim_info, claim_info['value']) diff --git a/lbry/stream/reflector/client.py b/lbry/stream/reflector/client.py index 7a8032c99..fbc084b9b 100644 --- a/lbry/stream/reflector/client.py +++ b/lbry/stream/reflector/client.py @@ -35,6 +35,8 @@ class StreamReflectorClient(asyncio.Protocol): def connection_lost(self, exc: typing.Optional[Exception]): self.transport = None self.connected.clear() + if self.pending_request: + self.pending_request.cancel() if self.reflected_blobs: log.info("Finished sending reflector %i 
blobs", len(self.reflected_blobs)) @@ -56,11 +58,11 @@ class StreamReflectorClient(asyncio.Protocol): self.response_buff = b'' return - async def send_request(self, request_dict: typing.Dict): + async def send_request(self, request_dict: typing.Dict, timeout: int = 180): msg = json.dumps(request_dict) self.transport.write(msg.encode()) try: - self.pending_request = self.loop.create_task(self.response_queue.get()) + self.pending_request = self.loop.create_task(asyncio.wait_for(self.response_queue.get(), timeout)) return await self.pending_request finally: self.pending_request = None @@ -87,7 +89,7 @@ class StreamReflectorClient(asyncio.Protocol): sent_sd = False if response['send_sd_blob']: await sd_blob.sendfile(self) - received = await self.response_queue.get() + received = await asyncio.wait_for(self.response_queue.get(), 30) if received.get('received_sd_blob'): sent_sd = True if not needed: @@ -111,7 +113,7 @@ class StreamReflectorClient(asyncio.Protocol): raise ValueError("I don't know whether to send the blob or not!") if response['send_blob']: await blob.sendfile(self) - received = await self.response_queue.get() + received = await asyncio.wait_for(self.response_queue.get(), 30) if received.get('received_blob'): self.reflected_blobs.append(blob.blob_hash) log.info("Sent reflector blob %s", blob.blob_hash[:8]) diff --git a/lbry/stream/stream_manager.py b/lbry/stream/stream_manager.py index 4fb37e99a..8e66aff96 100644 --- a/lbry/stream/stream_manager.py +++ b/lbry/stream/stream_manager.py @@ -6,93 +6,67 @@ import random import typing from typing import Optional from aiohttp.web import Request -from lbry.error import ResolveError, InvalidStreamDescriptorError, DownloadSDTimeoutError, InsufficientFundsError -from lbry.error import ResolveTimeoutError, DownloadDataTimeoutError, KeyFeeAboveMaxAllowedError -from lbry.utils import cache_concurrent +from lbry.error import InvalidStreamDescriptorError +from lbry.file.source_manager import SourceManager from lbry.stream.descriptor import StreamDescriptor from lbry.stream.managed_stream import ManagedStream -from lbry.schema.claim import Claim -from lbry.schema.url import URL -from lbry.wallet.dewies import dewies_to_lbc -from lbry.wallet import Output - +from lbry.file.source import ManagedDownloadSource if typing.TYPE_CHECKING: from lbry.conf import Config from lbry.blob.blob_manager import BlobManager from lbry.dht.node import Node + from lbry.wallet.wallet import WalletManager + from lbry.wallet.transaction import Transaction from lbry.extras.daemon.analytics import AnalyticsManager from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim - from lbry.extras.daemon.exchange_rate_manager import ExchangeRateManager - from lbry.wallet.transaction import Transaction - from lbry.wallet.manager import WalletManager - from lbry.wallet.wallet import Wallet log = logging.getLogger(__name__) -FILTER_FIELDS = [ - 'rowid', - 'status', - 'file_name', - 'added_on', - 'sd_hash', - 'stream_hash', - 'claim_name', - 'claim_height', - 'claim_id', - 'outpoint', - 'txid', - 'nout', - 'channel_claim_id', - 'channel_name', - 'full_status', # TODO: remove - 'blobs_remaining', - 'blobs_in_stream' -] -SET_FILTER_FIELDS = { - "claim_ids": "claim_id", - "channel_claim_ids": "channel_claim_id", - "outpoints": "outpoint" -} - -COMPARISON_OPERATORS = { - 'eq': lambda a, b: a == b, - 'ne': lambda a, b: a != b, - 'g': lambda a, b: a > b, - 'l': lambda a, b: a < b, - 'ge': lambda a, b: a >= b, - 'le': lambda a, b: a <= b, - 'in': lambda a, b: a in b -} - - 
-def path_or_none(path) -> Optional[str]: - if not path: +def path_or_none(encoded_path) -> Optional[str]: + if not encoded_path: return - return binascii.unhexlify(path).decode() + return binascii.unhexlify(encoded_path).decode() -class StreamManager: +class StreamManager(SourceManager): + _sources: typing.Dict[str, ManagedStream] + + filter_fields = SourceManager.filter_fields + filter_fields.update({ + 'sd_hash', + 'stream_hash', + 'full_status', # TODO: remove + 'blobs_remaining', + 'blobs_in_stream', + 'uploading_to_reflector', + 'is_fully_reflected' + }) + def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', blob_manager: 'BlobManager', wallet_manager: 'WalletManager', storage: 'SQLiteStorage', node: Optional['Node'], analytics_manager: Optional['AnalyticsManager'] = None): - self.loop = loop - self.config = config + super().__init__(loop, config, storage, analytics_manager) self.blob_manager = blob_manager self.wallet_manager = wallet_manager - self.storage = storage self.node = node - self.analytics_manager = analytics_manager - self.streams: typing.Dict[str, ManagedStream] = {} self.resume_saving_task: Optional[asyncio.Task] = None self.re_reflect_task: Optional[asyncio.Task] = None self.update_stream_finished_futs: typing.List[asyncio.Future] = [] self.running_reflector_uploads: typing.Dict[str, asyncio.Task] = {} self.started = asyncio.Event(loop=self.loop) + @property + def streams(self): + return self._sources + + def add(self, source: ManagedStream): + super().add(source) + self.storage.content_claim_callbacks[source.stream_hash] = lambda: self._update_content_claim(source) + async def _update_content_claim(self, stream: ManagedStream): claim_info = await self.storage.get_content_claim(stream.stream_hash) - self.streams.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value']) + self._sources.setdefault(stream.sd_hash, stream).set_claim(claim_info, claim_info['value']) async def recover_streams(self, file_infos: typing.List[typing.Dict]): to_restore = [] @@ -123,10 +97,10 @@ class StreamManager: # if self.blob_manager._save_blobs: # log.info("Recovered %i/%i attempted streams", len(to_restore), len(file_infos)) - async def add_stream(self, rowid: int, sd_hash: str, file_name: Optional[str], - download_directory: Optional[str], status: str, - claim: Optional['StoredContentClaim'], content_fee: Optional['Transaction'], - added_on: Optional[int], fully_reflected: bool): + async def _load_stream(self, rowid: int, sd_hash: str, file_name: Optional[str], + download_directory: Optional[str], status: str, + claim: Optional['StoredContentClaim'], content_fee: Optional['Transaction'], + added_on: Optional[int], fully_reflected: Optional[bool]): try: descriptor = await self.blob_manager.get_stream_descriptor(sd_hash) except InvalidStreamDescriptorError as err: @@ -139,10 +113,9 @@ class StreamManager: ) if fully_reflected: stream.fully_reflected.set() - self.streams[sd_hash] = stream - self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) + self.add(stream) - async def load_and_resume_streams_from_database(self): + async def initialize_from_database(self): to_recover = [] to_start = [] @@ -156,7 +129,6 @@ class StreamManager: to_recover.append(file_info) to_start.append(file_info) if to_recover: - log.info("Recover %i files", len(to_recover)) await self.recover_streams(to_recover) log.info("Initializing %i files", len(to_start)) @@ -167,7 +139,7 @@ class StreamManager: download_directory = 
path_or_none(file_info['download_directory']) if file_name and download_directory and not file_info['saved_file'] and file_info['status'] == 'running': to_resume_saving.append((file_name, download_directory, file_info['sd_hash'])) - add_stream_tasks.append(self.loop.create_task(self.add_stream( + add_stream_tasks.append(self.loop.create_task(self._load_stream( file_info['rowid'], file_info['sd_hash'], file_name, download_directory, file_info['status'], file_info['claim'], file_info['content_fee'], @@ -175,25 +147,22 @@ class StreamManager: ))) if add_stream_tasks: await asyncio.gather(*add_stream_tasks, loop=self.loop) - log.info("Started stream manager with %i files", len(self.streams)) + log.info("Started stream manager with %i files", len(self._sources)) if not self.node: log.info("no DHT node given, resuming downloads trusting that we can contact reflector") if to_resume_saving: - self.resume_saving_task = self.loop.create_task(self.resume(to_resume_saving)) - - async def resume(self, to_resume_saving): - log.info("Resuming saving %i files", len(to_resume_saving)) - await asyncio.gather( - *(self.streams[sd_hash].save_file(file_name, download_directory, node=self.node) - for (file_name, download_directory, sd_hash) in to_resume_saving), - loop=self.loop - ) + log.info("Resuming saving %i files", len(to_resume_saving)) + self.resume_saving_task = asyncio.ensure_future(asyncio.gather( + *(self._sources[sd_hash].save_file(file_name, download_directory) + for (file_name, download_directory, sd_hash) in to_resume_saving), + loop=self.loop + )) async def reflect_streams(self): while True: if self.config.reflect_streams and self.config.reflector_servers: sd_hashes = await self.storage.get_streams_to_re_reflect() - sd_hashes = [sd for sd in sd_hashes if sd in self.streams] + sd_hashes = [sd for sd in sd_hashes if sd in self._sources] batch = [] while sd_hashes: stream = self.streams[sd_hashes.pop()] @@ -209,18 +178,15 @@ class StreamManager: await asyncio.sleep(300, loop=self.loop) async def start(self): - await self.load_and_resume_streams_from_database() + await super().start() self.re_reflect_task = self.loop.create_task(self.reflect_streams()) - self.started.set() def stop(self): + super().stop() if self.resume_saving_task and not self.resume_saving_task.done(): self.resume_saving_task.cancel() if self.re_reflect_task and not self.re_reflect_task.done(): self.re_reflect_task.cancel() - while self.streams: - _, stream = self.streams.popitem() - stream.stop_tasks() while self.update_stream_finished_futs: self.update_stream_finished_futs.pop().cancel() while self.running_reflector_uploads: @@ -243,280 +209,45 @@ class StreamManager: ) return task - async def create_stream(self, file_path: str, key: Optional[bytes] = None, - iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream: - stream = await ManagedStream.create(self.loop, self.config, self.blob_manager, file_path, key, iv_generator) + async def create(self, file_path: str, key: Optional[bytes] = None, + iv_generator: Optional[typing.Generator[bytes, None, None]] = None) -> ManagedStream: + descriptor = await StreamDescriptor.create_stream( + self.loop, self.blob_manager.blob_dir, file_path, key=key, iv_generator=iv_generator, + blob_completed_callback=self.blob_manager.blob_completed + ) + await self.storage.store_stream( + self.blob_manager.get_blob(descriptor.sd_hash), descriptor + ) + row_id = await self.storage.save_published_file( + descriptor.stream_hash, os.path.basename(file_path), 
os.path.dirname(file_path), 0 + ) + stream = ManagedStream( + self.loop, self.config, self.blob_manager, descriptor.sd_hash, os.path.dirname(file_path), + os.path.basename(file_path), status=ManagedDownloadSource.STATUS_FINISHED, + rowid=row_id, descriptor=descriptor + ) self.streams[stream.sd_hash] = stream self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) if self.config.reflect_streams and self.config.reflector_servers: self.reflect_stream(stream) return stream - async def delete_stream(self, stream: ManagedStream, delete_file: Optional[bool] = False): - if stream.sd_hash in self.running_reflector_uploads: - self.running_reflector_uploads[stream.sd_hash].cancel() - stream.stop_tasks() - if stream.sd_hash in self.streams: - del self.streams[stream.sd_hash] - blob_hashes = [stream.sd_hash] + [b.blob_hash for b in stream.descriptor.blobs[:-1]] + async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): + if not isinstance(source, ManagedStream): + return + if source.identifier in self.running_reflector_uploads: + self.running_reflector_uploads[source.identifier].cancel() + source.stop_tasks() + if source.identifier in self.streams: + del self.streams[source.identifier] + blob_hashes = [source.identifier] + [b.blob_hash for b in source.descriptor.blobs[:-1]] await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False) - await self.storage.delete_stream(stream.descriptor) - if delete_file and stream.output_file_exists: - os.remove(stream.full_path) - - def get_stream_by_stream_hash(self, stream_hash: str) -> Optional[ManagedStream]: - streams = tuple(filter(lambda stream: stream.stream_hash == stream_hash, self.streams.values())) - if streams: - return streams[0] - - def get_filtered_streams(self, sort_by: Optional[str] = None, reverse: Optional[bool] = False, - comparison: Optional[str] = None, - **search_by) -> typing.List[ManagedStream]: - """ - Get a list of filtered and sorted ManagedStream objects - - :param sort_by: field to sort by - :param reverse: reverse sorting - :param comparison: comparison operator used for filtering - :param search_by: fields and values to filter by - """ - if sort_by and sort_by not in FILTER_FIELDS: - raise ValueError(f"'{sort_by}' is not a valid field to sort by") - if comparison and comparison not in COMPARISON_OPERATORS: - raise ValueError(f"'{comparison}' is not a valid comparison") - if 'full_status' in search_by: - del search_by['full_status'] - - for search in search_by: - if search not in FILTER_FIELDS: - raise ValueError(f"'{search}' is not a valid search operation") - - compare_sets = {} - if isinstance(search_by.get('claim_id'), list): - compare_sets['claim_ids'] = search_by.pop('claim_id') - if isinstance(search_by.get('outpoint'), list): - compare_sets['outpoints'] = search_by.pop('outpoint') - if isinstance(search_by.get('channel_claim_id'), list): - compare_sets['channel_claim_ids'] = search_by.pop('channel_claim_id') - - if search_by: - comparison = comparison or 'eq' - streams = [] - for stream in self.streams.values(): - matched = False - for set_search, val in compare_sets.items(): - if COMPARISON_OPERATORS[comparison](getattr(stream, SET_FILTER_FIELDS[set_search]), val): - streams.append(stream) - matched = True - break - if matched: - continue - for search, val in search_by.items(): - this_stream = getattr(stream, search) - if COMPARISON_OPERATORS[comparison](this_stream, val): - streams.append(stream) - break - else: - streams = 
list(self.streams.values()) - if sort_by: - streams.sort(key=lambda s: getattr(s, sort_by)) - if reverse: - streams.reverse() - return streams - - async def _check_update_or_replace(self, outpoint: str, claim_id: str, claim: Claim - ) -> typing.Tuple[Optional[ManagedStream], Optional[ManagedStream]]: - existing = self.get_filtered_streams(outpoint=outpoint) - if existing: - return existing[0], None - existing = self.get_filtered_streams(sd_hash=claim.stream.source.sd_hash) - if existing and existing[0].claim_id != claim_id: - raise ResolveError(f"stream for {existing[0].claim_id} collides with existing download {claim_id}") - if existing: - log.info("claim contains a metadata only update to a stream we have") - await self.storage.save_content_claim( - existing[0].stream_hash, outpoint - ) - await self._update_content_claim(existing[0]) - return existing[0], None - else: - existing_for_claim_id = self.get_filtered_streams(claim_id=claim_id) - if existing_for_claim_id: - log.info("claim contains an update to a stream we have, downloading it") - return None, existing_for_claim_id[0] - return None, None - - @staticmethod - def _convert_to_old_resolve_output(wallet_manager, resolves): - result = {} - for url, txo in resolves.items(): - if isinstance(txo, Output): - tx_height = txo.tx_ref.height - best_height = wallet_manager.ledger.headers.height - result[url] = { - 'name': txo.claim_name, - 'value': txo.claim, - 'protobuf': binascii.hexlify(txo.claim.to_bytes()), - 'claim_id': txo.claim_id, - 'txid': txo.tx_ref.id, - 'nout': txo.position, - 'amount': dewies_to_lbc(txo.amount), - 'effective_amount': txo.meta.get('effective_amount', 0), - 'height': tx_height, - 'confirmations': (best_height+1) - tx_height if tx_height > 0 else tx_height, - 'claim_sequence': -1, - 'address': txo.get_address(wallet_manager.ledger), - 'valid_at_height': txo.meta.get('activation_height', None), - 'timestamp': wallet_manager.ledger.headers.estimated_timestamp(tx_height), - 'supports': [] - } - else: - result[url] = txo - return result - - @cache_concurrent - async def download_stream_from_uri(self, uri, exchange_rate_manager: 'ExchangeRateManager', - timeout: Optional[float] = None, - file_name: Optional[str] = None, - download_directory: Optional[str] = None, - save_file: Optional[bool] = None, - resolve_timeout: float = 3.0, - wallet: Optional['Wallet'] = None) -> ManagedStream: - manager = self.wallet_manager - wallet = wallet or manager.default_wallet - timeout = timeout or self.config.download_timeout - start_time = self.loop.time() - resolved_time = None - stream = None - txo: Optional[Output] = None - error = None - outpoint = None - if save_file is None: - save_file = self.config.save_files - if file_name and not save_file: - save_file = True - if save_file: - download_directory = download_directory or self.config.download_dir - else: - download_directory = None - - payment = None - try: - # resolve the claim - if not URL.parse(uri).has_stream: - raise ResolveError("cannot download a channel claim, specify a /path") - try: - response = await asyncio.wait_for( - manager.ledger.resolve(wallet.accounts, [uri], include_purchase_receipt=True), - resolve_timeout - ) - resolved_result = self._convert_to_old_resolve_output(manager, response) - except asyncio.TimeoutError: - raise ResolveTimeoutError(uri) - except Exception as err: - if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 - raise - log.exception("Unexpected error resolving stream:") - raise ResolveError(f"Unexpected error 
resolving stream: {str(err)}") - await self.storage.save_claims_for_resolve([ - value for value in resolved_result.values() if 'error' not in value - ]) - resolved = resolved_result.get(uri, {}) - resolved = resolved if 'value' in resolved else resolved.get('claim') - if not resolved: - raise ResolveError(f"Failed to resolve stream at '{uri}'") - if 'error' in resolved: - raise ResolveError(f"error resolving stream: {resolved['error']}") - txo = response[uri] - - claim = Claim.from_bytes(binascii.unhexlify(resolved['protobuf'])) - outpoint = f"{resolved['txid']}:{resolved['nout']}" - resolved_time = self.loop.time() - start_time - - # resume or update an existing stream, if the stream changed: download it and delete the old one after - updated_stream, to_replace = await self._check_update_or_replace(outpoint, resolved['claim_id'], claim) - if updated_stream: - log.info("already have stream for %s", uri) - if save_file and updated_stream.output_file_exists: - save_file = False - await updated_stream.start(node=self.node, timeout=timeout, save_now=save_file) - if not updated_stream.output_file_exists and (save_file or file_name or download_directory): - await updated_stream.save_file( - file_name=file_name, download_directory=download_directory, node=self.node - ) - return updated_stream - - if not to_replace and txo.has_price and not txo.purchase_receipt: - payment = await manager.create_purchase_transaction( - wallet.accounts, txo, exchange_rate_manager - ) - - stream = ManagedStream( - self.loop, self.config, self.blob_manager, claim.stream.source.sd_hash, download_directory, - file_name, ManagedStream.STATUS_RUNNING, content_fee=payment, - analytics_manager=self.analytics_manager - ) - log.info("starting download for %s", uri) - - before_download = self.loop.time() - await stream.start(self.node, timeout) - stream.set_claim(resolved, claim) - if to_replace: # delete old stream now that the replacement has started downloading - await self.delete_stream(to_replace) - - if payment is not None: - await manager.broadcast_or_release(payment) - payment = None # to avoid releasing in `finally` later - log.info("paid fee of %s for %s", dewies_to_lbc(stream.content_fee.outputs[0].amount), uri) - await self.storage.save_content_fee(stream.stream_hash, stream.content_fee) - - self.streams[stream.sd_hash] = stream - self.storage.content_claim_callbacks[stream.stream_hash] = lambda: self._update_content_claim(stream) - await self.storage.save_content_claim(stream.stream_hash, outpoint) - if save_file: - await asyncio.wait_for(stream.save_file(node=self.node), timeout - (self.loop.time() - before_download), - loop=self.loop) - return stream - except asyncio.TimeoutError: - error = DownloadDataTimeoutError(stream.sd_hash) - raise error - except Exception as err: # forgive data timeout, don't delete stream - expected = (DownloadSDTimeoutError, DownloadDataTimeoutError, InsufficientFundsError, - KeyFeeAboveMaxAllowedError) - if isinstance(err, expected): - log.warning("Failed to download %s: %s", uri, str(err)) - elif isinstance(err, asyncio.CancelledError): - pass - else: - log.exception("Unexpected error downloading stream:") - error = err - raise - finally: - if payment is not None: - # payment is set to None after broadcasting, if we're here an exception probably happened - await manager.ledger.release_tx(payment) - if self.analytics_manager and (error or (stream and (stream.downloader.time_to_descriptor or - stream.downloader.time_to_first_bytes))): - server = 
self.wallet_manager.ledger.network.client.server - self.loop.create_task( - self.analytics_manager.send_time_to_first_bytes( - resolved_time, self.loop.time() - start_time, None if not stream else stream.download_id, - uri, outpoint, - None if not stream else len(stream.downloader.blob_downloader.active_connections), - None if not stream else len(stream.downloader.blob_downloader.scores), - None if not stream else len(stream.downloader.blob_downloader.connection_failures), - False if not stream else stream.downloader.added_fixed_peers, - self.config.fixed_peer_delay if not stream else stream.downloader.fixed_peers_delay, - None if not stream else stream.sd_hash, - None if not stream else stream.downloader.time_to_descriptor, - None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].blob_hash, - None if not (stream and stream.descriptor) else stream.descriptor.blobs[0].length, - None if not stream else stream.downloader.time_to_first_bytes, - None if not error else error.__class__.__name__, - None if not error else str(error), - None if not server else f"{server[0]}:{server[1]}" - ) - ) + await self.storage.delete_stream(source.descriptor) + if delete_file and source.output_file_exists: + os.remove(source.full_path) async def stream_partial_content(self, request: Request, sd_hash: str): - return await self.streams[sd_hash].stream_file(request, self.node) + stream = self._sources[sd_hash] + if not stream.downloader.node: + stream.downloader.node = self.node + return await stream.stream_file(request) diff --git a/lbry/testcase.py b/lbry/testcase.py index dcdaa83e5..6dc1e2eb9 100644 --- a/lbry/testcase.py +++ b/lbry/testcase.py @@ -386,6 +386,7 @@ class CommandTestCase(IntegrationTestCase): conf.blockchain_name = 'lbrycrd_regtest' conf.lbryum_servers = [('127.0.0.1', 50001)] conf.reflector_servers = [('127.0.0.1', 5566)] + conf.fixed_peers = [('127.0.0.1', 5567)] conf.known_dht_nodes = [] conf.blob_lru_cache_size = self.blob_lru_cache_size conf.components_to_skip = [ diff --git a/lbry/torrent/__init__.py b/lbry/torrent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lbry/torrent/session.py b/lbry/torrent/session.py new file mode 100644 index 000000000..feff53f75 --- /dev/null +++ b/lbry/torrent/session.py @@ -0,0 +1,290 @@ +import asyncio +import binascii +import os +import logging +import random +from hashlib import sha1 +from tempfile import mkdtemp +from typing import Optional + +import libtorrent + + +NOTIFICATION_MASKS = [ + "error", + "peer", + "port_mapping", + "storage", + "tracker", + "debug", + "status", + "progress", + "ip_block", + "dht", + "stats", + "session_log", + "torrent_log", + "peer_log", + "incoming_request", + "dht_log", + "dht_operation", + "port_mapping_log", + "picker_log", + "file_progress", + "piece_progress", + "upload", + "block_progress" +] +log = logging.getLogger(__name__) + + +DEFAULT_FLAGS = ( # fixme: somehow the logic here is inverted? 
+ libtorrent.add_torrent_params_flags_t.flag_auto_managed + | libtorrent.add_torrent_params_flags_t.flag_update_subscribe +) + + +def get_notification_type(notification) -> str: + for i, notification_type in enumerate(NOTIFICATION_MASKS): + if (1 << i) & notification: + return notification_type + raise ValueError("unrecognized notification type") + + +class TorrentHandle: + def __init__(self, loop, executor, handle): + self._loop = loop + self._executor = executor + self._handle: libtorrent.torrent_handle = handle + self.started = asyncio.Event(loop=loop) + self.finished = asyncio.Event(loop=loop) + self.metadata_completed = asyncio.Event(loop=loop) + self.size = 0 + self.total_wanted_done = 0 + self.name = '' + self.tasks = [] + self.torrent_file: Optional[libtorrent.file_storage] = None + self._base_path = None + self._handle.set_sequential_download(1) + + @property + def largest_file(self) -> Optional[str]: + if not self.torrent_file: + return None + index = self.largest_file_index + return os.path.join(self._base_path, self.torrent_file.at(index).path) + + @property + def largest_file_index(self): + largest_size, index = 0, 0 + for file_num in range(self.torrent_file.num_files()): + if self.torrent_file.file_size(file_num) > largest_size: + largest_size = self.torrent_file.file_size(file_num) + index = file_num + return index + + def stop_tasks(self): + while self.tasks: + self.tasks.pop().cancel() + + def _show_status(self): + # fixme: cleanup + if not self._handle.is_valid(): + return + status = self._handle.status() + if status.has_metadata: + self.size = status.total_wanted + self.total_wanted_done = status.total_wanted_done + self.name = status.name + if not self.metadata_completed.is_set(): + self.metadata_completed.set() + log.info("Metadata completed for btih:%s - %s", status.info_hash, self.name) + self.torrent_file = self._handle.get_torrent_info().files() + self._base_path = status.save_path + first_piece = self.torrent_file.at(self.largest_file_index).offset + if not self.started.is_set(): + if self._handle.have_piece(first_piece): + self.started.set() + else: + # prioritize it + self._handle.set_piece_deadline(first_piece, 100) + if not status.is_seeding: + log.debug('%.2f%% complete (down: %.1f kB/s up: %.1f kB/s peers: %d seeds: %d) %s - %s', + status.progress * 100, status.download_rate / 1000, status.upload_rate / 1000, + status.num_peers, status.num_seeds, status.state, status.save_path) + elif not self.finished.is_set(): + self.finished.set() + log.info("Torrent finished: %s", self.name) + + async def status_loop(self): + while True: + self._show_status() + if self.finished.is_set(): + break + await asyncio.sleep(0.1, loop=self._loop) + + async def pause(self): + await self._loop.run_in_executor( + self._executor, self._handle.pause + ) + + async def resume(self): + await self._loop.run_in_executor( + self._executor, lambda: self._handle.resume() # pylint: disable=unnecessary-lambda + ) + + +class TorrentSession: + def __init__(self, loop, executor): + self._loop = loop + self._executor = executor + self._session: Optional[libtorrent.session] = None + self._handles = {} + self.tasks = [] + self.wait_start = True + + async def add_fake_torrent(self): + tmpdir = mkdtemp() + info, btih = _create_fake_torrent(tmpdir) + flags = libtorrent.add_torrent_params_flags_t.flag_seed_mode + handle = self._session.add_torrent({ + 'ti': info, 'save_path': tmpdir, 'flags': flags + }) + self._handles[btih] = TorrentHandle(self._loop, self._executor, handle) + return btih + + async def 
bind(self, interface: str = '0.0.0.0', port: int = 10889): + settings = { + 'listen_interfaces': f"{interface}:{port}", + 'enable_outgoing_utp': True, + 'enable_incoming_utp': True, + 'enable_outgoing_tcp': False, + 'enable_incoming_tcp': False + } + self._session = await self._loop.run_in_executor( + self._executor, libtorrent.session, settings # pylint: disable=c-extension-no-member + ) + self.tasks.append(self._loop.create_task(self.process_alerts())) + + def stop(self): + while self.tasks: + self.tasks.pop().cancel() + self._session.save_state() + self._session.pause() + self._session.stop_dht() + self._session.stop_lsd() + self._session.stop_natpmp() + self._session.stop_upnp() + self._session = None + + def _pop_alerts(self): + for alert in self._session.pop_alerts(): + log.info("torrent alert: %s", alert) + + async def process_alerts(self): + while True: + await self._loop.run_in_executor( + self._executor, self._pop_alerts + ) + await asyncio.sleep(1, loop=self._loop) + + async def pause(self): + await self._loop.run_in_executor( + self._executor, lambda: self._session.save_state() # pylint: disable=unnecessary-lambda + ) + await self._loop.run_in_executor( + self._executor, lambda: self._session.pause() # pylint: disable=unnecessary-lambda + ) + + async def resume(self): + await self._loop.run_in_executor( + self._executor, self._session.resume + ) + + def _add_torrent(self, btih: str, download_directory: Optional[str]): + params = {'info_hash': binascii.unhexlify(btih.encode()), 'flags': DEFAULT_FLAGS} + if download_directory: + params['save_path'] = download_directory + handle = self._session.add_torrent(params) + handle.force_dht_announce() + self._handles[btih] = TorrentHandle(self._loop, self._executor, handle) + + def full_path(self, btih): + return self._handles[btih].largest_file + + async def add_torrent(self, btih, download_path): + await self._loop.run_in_executor( + self._executor, self._add_torrent, btih, download_path + ) + self._handles[btih].tasks.append(self._loop.create_task(self._handles[btih].status_loop())) + await self._handles[btih].metadata_completed.wait() + if self.wait_start: + # fixme: temporary until we add streaming support, otherwise playback fails! 
+ await self._handles[btih].started.wait() + + def remove_torrent(self, btih, remove_files=False): + if btih in self._handles: + handle = self._handles[btih] + handle.stop_tasks() + self._session.remove_torrent(handle._handle, 1 if remove_files else 0) + self._handles.pop(btih) + + async def save_file(self, btih, download_directory): + handle = self._handles[btih] + await handle.resume() + + def get_size(self, btih): + return self._handles[btih].size + + def get_name(self, btih): + return self._handles[btih].name + + def get_downloaded(self, btih): + return self._handles[btih].total_wanted_done + + def is_completed(self, btih): + return self._handles[btih].finished.is_set() + + +def get_magnet_uri(btih): + return f"magnet:?xt=urn:btih:{btih}" + + +def _create_fake_torrent(tmpdir): + # beware, that's just for testing + path = os.path.join(tmpdir, 'tmp') + with open(path, 'wb') as myfile: + size = myfile.write(bytes([random.randint(0, 255) for _ in range(40)]) * 1024) + file_storage = libtorrent.file_storage() + file_storage.add_file('tmp', size) + t = libtorrent.create_torrent(file_storage, 0, 4 * 1024 * 1024) + libtorrent.set_piece_hashes(t, tmpdir) + info = libtorrent.torrent_info(t.generate()) + btih = sha1(info.metadata()).hexdigest() + return info, btih + + +async def main(): + if os.path.exists("~/Downloads/ubuntu-18.04.3-live-server-amd64.torrent"): + os.remove("~/Downloads/ubuntu-18.04.3-live-server-amd64.torrent") + if os.path.exists("~/Downloads/ubuntu-18.04.3-live-server-amd64.iso"): + os.remove("~/Downloads/ubuntu-18.04.3-live-server-amd64.iso") + + btih = "dd8255ecdc7ca55fb0bbf81323d87062db1f6d1c" + + executor = None + session = TorrentSession(asyncio.get_event_loop(), executor) + session2 = TorrentSession(asyncio.get_event_loop(), executor) + await session.bind('localhost', port=4040) + await session2.bind('localhost', port=4041) + btih = await session.add_fake_torrent() + session2._session.add_dht_node(('localhost', 4040)) + await session2.add_torrent(btih, "/tmp/down") + while True: + await asyncio.sleep(100) + await session.pause() + executor.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/lbry/torrent/torrent.py b/lbry/torrent/torrent.py new file mode 100644 index 000000000..04a8544c7 --- /dev/null +++ b/lbry/torrent/torrent.py @@ -0,0 +1,72 @@ +import asyncio +import logging +import typing + + +log = logging.getLogger(__name__) + + +class TorrentInfo: + __slots__ = ('dht_seeds', 'http_seeds', 'trackers', 'total_size') + + def __init__(self, dht_seeds: typing.Tuple[typing.Tuple[str, int]], + http_seeds: typing.Tuple[typing.Dict[str, typing.Any]], + trackers: typing.Tuple[typing.Tuple[str, int]], total_size: int): + self.dht_seeds = dht_seeds + self.http_seeds = http_seeds + self.trackers = trackers + self.total_size = total_size + + @classmethod + def from_libtorrent_info(cls, torrent_info): + return cls( + torrent_info.nodes(), tuple( + { + 'url': web_seed['url'], + 'type': web_seed['type'], + 'auth': web_seed['auth'] + } for web_seed in torrent_info.web_seeds() + ), tuple( + (tracker.url, tracker.tier) for tracker in torrent_info.trackers() + ), torrent_info.total_size() + ) + + +class Torrent: + def __init__(self, loop, handle): + self._loop = loop + self._handle = handle + self.finished = asyncio.Event(loop=loop) + + def _threaded_update_status(self): + status = self._handle.status() + if not status.is_seeding: + log.info( + '%.2f%% complete (down: %.1f kB/s up: %.1f kB/s peers: %d) %s', + status.progress * 100, status.download_rate / 
1000, status.upload_rate / 1000, + status.num_peers, status.state + ) + elif not self.finished.is_set(): + self.finished.set() + + async def wait_for_finished(self): + while True: + await self._loop.run_in_executor( + None, self._threaded_update_status + ) + if self.finished.is_set(): + log.info("finished downloading torrent!") + await self.pause() + break + await asyncio.sleep(1, loop=self._loop) + + async def pause(self): + log.info("pause torrent") + await self._loop.run_in_executor( + None, self._handle.pause + ) + + async def resume(self): + await self._loop.run_in_executor( + None, self._handle.resume + ) diff --git a/lbry/torrent/torrent_manager.py b/lbry/torrent/torrent_manager.py new file mode 100644 index 000000000..cf9106731 --- /dev/null +++ b/lbry/torrent/torrent_manager.py @@ -0,0 +1,140 @@ +import asyncio +import binascii +import logging +import os +import typing +from typing import Optional +from aiohttp.web import Request +from lbry.file.source_manager import SourceManager +from lbry.file.source import ManagedDownloadSource + +if typing.TYPE_CHECKING: + from lbry.torrent.session import TorrentSession + from lbry.conf import Config + from lbry.wallet.transaction import Transaction + from lbry.extras.daemon.analytics import AnalyticsManager + from lbry.extras.daemon.storage import SQLiteStorage, StoredContentClaim + from lbry.extras.daemon.storage import StoredContentClaim + +log = logging.getLogger(__name__) + + +def path_or_none(encoded_path) -> Optional[str]: + if not encoded_path: + return + return binascii.unhexlify(encoded_path).decode() + + +class TorrentSource(ManagedDownloadSource): + STATUS_STOPPED = "stopped" + filter_fields = SourceManager.filter_fields + filter_fields.update({ + 'bt_infohash' + }) + + def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', storage: 'SQLiteStorage', identifier: str, + file_name: Optional[str] = None, download_directory: Optional[str] = None, + status: Optional[str] = STATUS_STOPPED, claim: Optional['StoredContentClaim'] = None, + download_id: Optional[str] = None, rowid: Optional[int] = None, + content_fee: Optional['Transaction'] = None, + analytics_manager: Optional['AnalyticsManager'] = None, + added_on: Optional[int] = None, torrent_session: Optional['TorrentSession'] = None): + super().__init__(loop, config, storage, identifier, file_name, download_directory, status, claim, download_id, + rowid, content_fee, analytics_manager, added_on) + self.torrent_session = torrent_session + + @property + def full_path(self) -> Optional[str]: + full_path = self.torrent_session.full_path(self.identifier) + self.download_directory = os.path.dirname(full_path) + return full_path + + async def start(self, timeout: Optional[float] = None, save_now: Optional[bool] = False): + await self.torrent_session.add_torrent(self.identifier, self.download_directory) + + async def stop(self, finished: bool = False): + await self.torrent_session.remove_torrent(self.identifier) + + async def save_file(self, file_name: Optional[str] = None, download_directory: Optional[str] = None): + await self.torrent_session.save_file(self.identifier, download_directory) + + @property + def torrent_length(self): + return self.torrent_session.get_size(self.identifier) + + @property + def written_bytes(self): + return self.torrent_session.get_downloaded(self.identifier) + + @property + def torrent_name(self): + return self.torrent_session.get_name(self.identifier) + + @property + def bt_infohash(self): + return self.identifier + + def stop_tasks(self): + pass 
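The new `TorrentSession` wraps a libtorrent session behind a small async API: `bind()` creates the session on the executor, `add_torrent()` resolves once metadata (and, until streaming support lands, the start of the download) is available, and the `get_*`/`is_completed()` accessors read from the cached `TorrentHandle`. A minimal sketch of driving that API end to end, reusing the Ubuntu info-hash and download path from the demo `main()` above; the `fetch` helper name is illustrative and error handling is omitted:

```python
# Sketch only: drives the TorrentSession API added in this diff.
# The info-hash and paths come from the demo main() above; `fetch` is illustrative.
import asyncio

from lbry.torrent.session import TorrentSession


async def fetch(btih: str, download_path: str):
    session = TorrentSession(asyncio.get_event_loop(), None)  # None -> default executor
    await session.bind('0.0.0.0', port=10889)
    try:
        # waits for metadata (and, until streaming lands, for the download to start)
        await session.add_torrent(btih, download_path)
        while not session.is_completed(btih):
            print(f"{session.get_name(btih)}: "
                  f"{session.get_downloaded(btih)} / {session.get_size(btih)} bytes")
            await asyncio.sleep(1)
        print("largest file at:", session.full_path(btih))
    finally:
        session.remove_torrent(btih)  # synchronous in the diff, so not awaited
        session.stop()


# asyncio.run(fetch("dd8255ecdc7ca55fb0bbf81323d87062db1f6d1c", "/tmp/down"))
```

Note that `remove_torrent()` and `stop()` are plain synchronous methods in the diff, which is why the sketch does not await them.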
+ + @property + def completed(self): + return self.torrent_session.is_completed(self.identifier) + + +class TorrentManager(SourceManager): + _sources: typing.Dict[str, ManagedDownloadSource] + + filter_fields = set(SourceManager.filter_fields) + filter_fields.update({ + 'bt_infohash', + 'blobs_remaining', # TODO: here they call them "parts", but its pretty much the same concept + 'blobs_in_stream' + }) + + def __init__(self, loop: asyncio.AbstractEventLoop, config: 'Config', torrent_session: 'TorrentSession', + storage: 'SQLiteStorage', analytics_manager: Optional['AnalyticsManager'] = None): + super().__init__(loop, config, storage, analytics_manager) + self.torrent_session: 'TorrentSession' = torrent_session + + async def recover_streams(self, file_infos: typing.List[typing.Dict]): + raise NotImplementedError + + async def _load_stream(self, rowid: int, bt_infohash: str, file_name: Optional[str], + download_directory: Optional[str], status: str, + claim: Optional['StoredContentClaim'], content_fee: Optional['Transaction'], + added_on: Optional[int]): + stream = TorrentSource( + self.loop, self.config, self.storage, identifier=bt_infohash, file_name=file_name, + download_directory=download_directory, status=status, claim=claim, rowid=rowid, + content_fee=content_fee, analytics_manager=self.analytics_manager, added_on=added_on, + torrent_session=self.torrent_session + ) + self.add(stream) + + async def initialize_from_database(self): + pass + + async def start(self): + await super().start() + + def stop(self): + super().stop() + log.info("finished stopping the torrent manager") + + async def delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): + await super().delete(source, delete_file) + self.torrent_session.remove_torrent(source.identifier, delete_file) + + async def create(self, file_path: str, key: Optional[bytes] = None, + iv_generator: Optional[typing.Generator[bytes, None, None]] = None): + raise NotImplementedError + + async def _delete(self, source: ManagedDownloadSource, delete_file: Optional[bool] = False): + raise NotImplementedError + # blob_hashes = [source.sd_hash] + [b.blob_hash for b in source.descriptor.blobs[:-1]] + # await self.blob_manager.delete_blobs(blob_hashes, delete_from_db=False) + # await self.storage.delete_stream(source.descriptor) + + async def stream_partial_content(self, request: Request, sd_hash: str): + raise NotImplementedError diff --git a/lbry/utils.py b/lbry/utils.py index c24d8a971..0db443cd9 100644 --- a/lbry/utils.py +++ b/lbry/utils.py @@ -3,6 +3,7 @@ import codecs import datetime import random import socket +import time import string import sys import json @@ -282,3 +283,25 @@ async def get_external_ip() -> typing.Optional[str]: # used if upnp is disabled def is_running_from_bundle(): # see https://pyinstaller.readthedocs.io/en/stable/runtime-information.html return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') + + +class LockWithMetrics(asyncio.Lock): + def __init__(self, acquire_metric, held_time_metric, loop=None): + super().__init__(loop=loop) + self._acquire_metric = acquire_metric + self._lock_held_time_metric = held_time_metric + self._lock_acquired_time = None + + async def acquire(self): + start = time.perf_counter() + try: + return await super().acquire() + finally: + self._lock_acquired_time = time.perf_counter() + self._acquire_metric.observe(self._lock_acquired_time - start) + + def release(self): + try: + return super().release() + finally: + 
self._lock_held_time_metric.observe(time.perf_counter() - self._lock_acquired_time) diff --git a/lbry/wallet/__init__.py b/lbry/wallet/__init__.py index 7ed88527b..98cfc8bb6 100644 --- a/lbry/wallet/__init__.py +++ b/lbry/wallet/__init__.py @@ -2,7 +2,7 @@ __node_daemon__ = 'lbrycrdd' __node_cli__ = 'lbrycrd-cli' __node_bin__ = '' __node_url__ = ( - 'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.4/lbrycrd-linux-1744.zip' + 'https://github.com/lbryio/lbrycrd/releases/download/v0.17.4.5/lbrycrd-linux-1745.zip' ) __spvserver__ = 'lbry.wallet.server.coin.LBCRegTest' diff --git a/lbry/wallet/account.py b/lbry/wallet/account.py index 3a3d4c3f3..fdefde985 100644 --- a/lbry/wallet/account.py +++ b/lbry/wallet/account.py @@ -525,11 +525,13 @@ class Account: channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) self.channel_keys[channel_pubkey_hash] = private_key.to_pem().decode() - def get_channel_private_key(self, public_key_bytes): + async def get_channel_private_key(self, public_key_bytes): channel_pubkey_hash = self.ledger.public_key_to_address(public_key_bytes) private_key_pem = self.channel_keys.get(channel_pubkey_hash) if private_key_pem: - return ecdsa.SigningKey.from_pem(private_key_pem, hashfunc=sha256) + return await asyncio.get_event_loop().run_in_executor( + None, ecdsa.SigningKey.from_pem, private_key_pem, sha256 + ) async def maybe_migrate_certificates(self): def to_der(private_key_pem): diff --git a/lbry/wallet/database.py b/lbry/wallet/database.py index 5aec29649..15c866017 100644 --- a/lbry/wallet/database.py +++ b/lbry/wallet/database.py @@ -10,6 +10,8 @@ from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable, Dict, Optional from datetime import date +from prometheus_client import Gauge, Counter, Histogram +from lbry.utils import LockWithMetrics from .bip32 import PubKey from .transaction import Transaction, Output, OutputScript, TXRefImmutable @@ -20,6 +22,10 @@ from .util import date_to_julian_day log = logging.getLogger(__name__) sqlite3.enable_callback_tracebacks(True) +HISTOGRAM_BUCKETS = ( + .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') +) + @dataclass class ReaderProcessState: @@ -64,15 +70,36 @@ else: class AIOSQLite: reader_executor: ReaderExecutorClass + waiting_writes_metric = Gauge( + "waiting_writes_count", "Number of waiting db writes", namespace="daemon_database" + ) + waiting_reads_metric = Gauge( + "waiting_reads_count", "Number of waiting db writes", namespace="daemon_database" + ) + write_count_metric = Counter( + "write_count", "Number of database writes", namespace="daemon_database" + ) + read_count_metric = Counter( + "read_count", "Number of database reads", namespace="daemon_database" + ) + acquire_write_lock_metric = Histogram( + f'write_lock_acquired', 'Time to acquire the write lock', namespace="daemon_database", buckets=HISTOGRAM_BUCKETS + ) + held_write_lock_metric = Histogram( + f'write_lock_held', 'Length of time the write lock is held for', namespace="daemon_database", + buckets=HISTOGRAM_BUCKETS + ) + def __init__(self): # has to be single threaded as there is no mapping of thread:connection self.writer_executor = ThreadPoolExecutor(max_workers=1) self.writer_connection: Optional[sqlite3.Connection] = None self._closing = False self.query_count = 0 - self.write_lock = asyncio.Lock() + self.write_lock = 
LockWithMetrics(self.acquire_write_lock_metric, self.held_write_lock_metric) self.writers = 0 self.read_ready = asyncio.Event() + self.urgent_read_done = asyncio.Event() @classmethod async def connect(cls, path: Union[bytes, str], *args, **kwargs): @@ -88,6 +115,7 @@ class AIOSQLite: ) await asyncio.get_event_loop().run_in_executor(db.writer_executor, _connect_writer) db.read_ready.set() + db.urgent_read_done.set() return db async def close(self): @@ -112,12 +140,28 @@ class AIOSQLite: read_only=False, fetch_all: bool = False) -> List[dict]: read_only_fn = run_read_only_fetchall if fetch_all else run_read_only_fetchone parameters = parameters if parameters is not None else [] + still_waiting = False + urgent_read = False if read_only: - while self.writers: - await self.read_ready.wait() - return await asyncio.get_event_loop().run_in_executor( - self.reader_executor, read_only_fn, sql, parameters - ) + self.waiting_reads_metric.inc() + self.read_count_metric.inc() + try: + while self.writers: # more writes can come in while we are waiting for the first + if not urgent_read and still_waiting and self.urgent_read_done.is_set(): + # throttle the writes if they pile up + self.urgent_read_done.clear() + urgent_read = True + # wait until the running writes have finished + await self.read_ready.wait() + still_waiting = True + return await asyncio.get_event_loop().run_in_executor( + self.reader_executor, read_only_fn, sql, parameters + ) + finally: + if urgent_read: + # unthrottle the writers if they had to be throttled + self.urgent_read_done.set() + self.waiting_reads_metric.dec() if fetch_all: return await self.run(lambda conn: conn.execute(sql, parameters).fetchall()) return await self.run(lambda conn: conn.execute(sql, parameters).fetchone()) @@ -135,17 +179,32 @@ class AIOSQLite: return self.run(lambda conn: conn.execute(sql, parameters)) async def run(self, fun, *args, **kwargs): + self.write_count_metric.inc() + self.waiting_writes_metric.inc() + # it's possible many writes are coming in one after the other, these can + # block reader calls for a long time + # if the reader waits for the writers to finish and then has to wait for + # yet more, it will clear the urgent_read_done event to block more writers + # piling on + try: + await self.urgent_read_done.wait() + except Exception as e: + self.waiting_writes_metric.dec() + raise e self.writers += 1 + # block readers self.read_ready.clear() - async with self.write_lock: - try: + try: + async with self.write_lock: return await asyncio.get_event_loop().run_in_executor( self.writer_executor, lambda: self.__run_transaction(fun, *args, **kwargs) ) - finally: - self.writers -= 1 - if not self.writers: - self.read_ready.set() + finally: + self.writers -= 1 + self.waiting_writes_metric.dec() + if not self.writers: + # unblock the readers once the last enqueued writer finishes + self.read_ready.set() def __run_transaction(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], *args, **kwargs): self.writer_connection.execute('begin') @@ -160,10 +219,26 @@ class AIOSQLite: log.warning("rolled back") raise - def run_with_foreign_keys_disabled(self, fun, *args, **kwargs) -> Awaitable: - return asyncio.get_event_loop().run_in_executor( - self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs - ) + async def run_with_foreign_keys_disabled(self, fun, *args, **kwargs): + self.write_count_metric.inc() + self.waiting_writes_metric.inc() + try: + await self.urgent_read_done.wait() + except Exception as e: + 
self.waiting_writes_metric.dec() + raise e + self.writers += 1 + self.read_ready.clear() + try: + async with self.write_lock: + return await asyncio.get_event_loop().run_in_executor( + self.writer_executor, self.__run_transaction_with_foreign_keys_disabled, fun, args, kwargs + ) + finally: + self.writers -= 1 + self.waiting_writes_metric.dec() + if not self.writers: + self.read_ready.set() def __run_transaction_with_foreign_keys_disabled(self, fun: Callable[[sqlite3.Connection, Any, Any], Any], @@ -579,7 +654,7 @@ class Database(SQLiteMixin): return self.db.run(__many) async def reserve_outputs(self, txos, is_reserved=True): - txoids = ((is_reserved, txo.id) for txo in txos) + txoids = [(is_reserved, txo.id) for txo in txos] await self.db.executemany("UPDATE txo SET is_reserved = ? WHERE txoid = ?", txoids) async def release_outputs(self, txos): @@ -843,7 +918,7 @@ class Database(SQLiteMixin): channel_ids.add(txo.claim.signing_channel_id) if txo.claim.is_channel and wallet: for account in wallet.accounts: - private_key = account.get_channel_private_key( + private_key = await account.get_channel_private_key( txo.claim.channel.public_key_bytes ) if private_key: diff --git a/lbry/wallet/header.py b/lbry/wallet/header.py index fd5a85c61..c67b0dd1d 100644 --- a/lbry/wallet/header.py +++ b/lbry/wallet/header.py @@ -5,7 +5,6 @@ import asyncio import logging import zlib from datetime import date -from concurrent.futures.thread import ThreadPoolExecutor from io import BytesIO from typing import Optional, Iterator, Tuple, Callable @@ -42,23 +41,22 @@ class Headers: validate_difficulty: bool = True def __init__(self, path) -> None: - if path == ':memory:': - self.io = BytesIO() + self.io = None self.path = path self._size: Optional[int] = None self.chunk_getter: Optional[Callable] = None - self.executor = ThreadPoolExecutor(1) self.known_missing_checkpointed_chunks = set() self.check_chunk_lock = asyncio.Lock() async def open(self): - if not self.executor: - self.executor = ThreadPoolExecutor(1) + self.io = BytesIO() if self.path != ':memory:': - if not os.path.exists(self.path): - self.io = open(self.path, 'w+b') - else: - self.io = open(self.path, 'r+b') + def _readit(): + if os.path.exists(self.path): + with open(self.path, 'r+b') as header_file: + self.io.seek(0) + self.io.write(header_file.read()) + await asyncio.get_event_loop().run_in_executor(None, _readit) bytes_size = self.io.seek(0, os.SEEK_END) self._size = bytes_size // self.header_size max_checkpointed_height = max(self.checkpoints.keys() or [-1]) + 1000 @@ -72,10 +70,14 @@ class Headers: await self.get_all_missing_headers() async def close(self): - if self.executor: - self.executor.shutdown() - self.executor = None - self.io.close() + if self.io is not None: + def _close(): + flags = 'r+b' if os.path.exists(self.path) else 'w+b' + with open(self.path, flags) as header_file: + header_file.write(self.io.getbuffer()) + await asyncio.get_event_loop().run_in_executor(None, _close) + self.io.close() + self.io = None @staticmethod def serialize(header): @@ -135,28 +137,30 @@ class Headers: except struct.error: raise IndexError(f"failed to get {height}, at {len(self)}") - def estimated_timestamp(self, height): + def estimated_timestamp(self, height, try_real_headers=True): if height <= 0: return + if try_real_headers and self.has_header(height): + offset = height * self.header_size + return struct.unpack('<I', self.io.getbuffer()[offset + 100: offset + 104])[0] return int(self.first_block_timestamp + (height * self.timestamp_average_offset)) 
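With the header store now kept entirely in a `BytesIO` buffer, `estimated_timestamp()` can return the real on-chain timestamp with a buffer slice instead of always falling back to the average-block-time estimate. A minimal sketch of that lookup as a standalone function, assuming the 112-byte LBRY header layout (4-byte version, 32-byte previous-block hash, 32-byte merkle root and 32-byte claim-trie root precede the little-endian uint32 timestamp, i.e. offset 100, matching the slice in the hunk above):

```python
# Sketch only: the buffer-slice timestamp lookup used by estimated_timestamp() above.
# Offsets assume the 112-byte LBRY header: version(4) + prev hash(32) +
# merkle root(32) + claim trie root(32) precede the uint32 timestamp.
import struct
from io import BytesIO

HEADER_SIZE = 112
TIMESTAMP_OFFSET = 100


def header_timestamp(headers: BytesIO, height: int) -> int:
    offset = height * HEADER_SIZE + TIMESTAMP_OFFSET
    return struct.unpack('<I', headers.getbuffer()[offset:offset + 4])[0]
```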
def estimated_julian_day(self, height): - return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height))) + return date_to_julian_day(date.fromtimestamp(self.estimated_timestamp(height, False))) async def get_raw_header(self, height) -> bytes: if self.chunk_getter: await self.ensure_chunk_at(height) if not 0 <= height <= self.height: raise IndexError(f"{height} is out of bounds, current height: {self.height}") - return await asyncio.get_running_loop().run_in_executor(self.executor, self._read, height) + return self._read(height) def _read(self, height, count=1): - self.io.seek(height * self.header_size, os.SEEK_SET) - return self.io.read(self.header_size * count) + offset = height * self.header_size + return bytes(self.io.getbuffer()[offset: offset + self.header_size * count]) def chunk_hash(self, start, count): - self.io.seek(start * self.header_size, os.SEEK_SET) - return self.hash_header(self.io.read(count * self.header_size)).decode() + return self.hash_header(self._read(start, count)).decode() async def ensure_checkpointed_size(self): max_checkpointed_height = max(self.checkpoints.keys() or [-1]) @@ -165,7 +169,7 @@ class Headers: async def ensure_chunk_at(self, height): async with self.check_chunk_lock: - if await self.has_header(height): + if self.has_header(height): log.debug("has header %s", height) return return await self.fetch_chunk(height) @@ -179,7 +183,7 @@ class Headers: ) chunk_hash = self.hash_header(chunk).decode() if self.checkpoints.get(start) == chunk_hash: - await asyncio.get_running_loop().run_in_executor(self.executor, self._write, start, chunk) + self._write(start, chunk) if start in self.known_missing_checkpointed_chunks: self.known_missing_checkpointed_chunks.remove(start) return @@ -189,27 +193,23 @@ class Headers: f"Checkpoint mismatch at height {start}. Expected {self.checkpoints[start]}, but got {chunk_hash} instead." 
) - async def has_header(self, height): + def has_header(self, height): normalized_height = (height // 1000) * 1000 if normalized_height in self.checkpoints: return normalized_height not in self.known_missing_checkpointed_chunks - def _has_header(height): - empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' - all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' - return self.chunk_hash(height, 1) not in (empty, all_zeroes) - return await asyncio.get_running_loop().run_in_executor(self.executor, _has_header, height) + empty = '56944c5d3f98413ef45cf54545538103cc9f298e0575820ad3591376e2e0f65d' + all_zeroes = '789d737d4f448e554b318c94063bbfa63e9ccda6e208f5648ca76ee68896557b' + return self.chunk_hash(height, 1) not in (empty, all_zeroes) async def get_all_missing_headers(self): # Heavy operation done in one optimized shot - def _io_checkall(): - for chunk_height, expected_hash in reversed(list(self.checkpoints.items())): - if chunk_height in self.known_missing_checkpointed_chunks: - continue - if self.chunk_hash(chunk_height, 1000) != expected_hash: - self.known_missing_checkpointed_chunks.add(chunk_height) - return self.known_missing_checkpointed_chunks - return await asyncio.get_running_loop().run_in_executor(self.executor, _io_checkall) + for chunk_height, expected_hash in reversed(list(self.checkpoints.items())): + if chunk_height in self.known_missing_checkpointed_chunks: + continue + if self.chunk_hash(chunk_height, 1000) != expected_hash: + self.known_missing_checkpointed_chunks.add(chunk_height) + return self.known_missing_checkpointed_chunks @property def height(self) -> int: @@ -241,7 +241,7 @@ class Headers: bail = True chunk = chunk[:(height-e.height)*self.header_size] if chunk: - added += await asyncio.get_running_loop().run_in_executor(self.executor, self._write, height, chunk) + added += self._write(height, chunk) if bail: break return added @@ -306,9 +306,7 @@ class Headers: previous_header_hash = fail = None batch_size = 36 for height in range(start_height, self.height, batch_size): - headers = await asyncio.get_running_loop().run_in_executor( - self.executor, self._read, height, batch_size - ) + headers = self._read(height, batch_size) if len(headers) % self.header_size != 0: headers = headers[:(len(headers) // self.header_size) * self.header_size] for header_hash, header in self._iterate_headers(height, headers): @@ -324,12 +322,11 @@ class Headers: assert start_height > 0 and height == start_height if fail: log.warning("Header file corrupted at height %s, truncating it.", height - 1) - def __truncate(at_height): - self.io.seek(max(0, (at_height - 1)) * self.header_size, os.SEEK_SET) - self.io.truncate() - self.io.flush() - self._size = self.io.seek(0, os.SEEK_END) // self.header_size - return await asyncio.get_running_loop().run_in_executor(self.executor, __truncate, height) + self.io.seek(max(0, (height - 1)) * self.header_size, os.SEEK_SET) + self.io.truncate() + self.io.flush() + self._size = self.io.seek(0, os.SEEK_END) // self.header_size + return previous_header_hash = header_hash @classmethod diff --git a/lbry/wallet/ledger.py b/lbry/wallet/ledger.py index 79e7ef6b2..50adf7467 100644 --- a/lbry/wallet/ledger.py +++ b/lbry/wallet/ledger.py @@ -1,5 +1,6 @@ import os import copy +import time import asyncio import logging from io import StringIO @@ -157,7 +158,7 @@ class Ledger(metaclass=LedgerRegistry): self._on_ready_controller = StreamController() self.on_ready = self._on_ready_controller.stream - self._tx_cache = 
pylru.lrucache(100000) + self._tx_cache = pylru.lrucache(self.config.get("tx_cache_size", 100_000)) self._update_tasks = TaskGroup() self._other_tasks = TaskGroup() # that we dont need to start self._utxo_reservation_lock = asyncio.Lock() @@ -639,6 +640,7 @@ class Ledger(metaclass=LedgerRegistry): return self.network.broadcast(hexlify(tx.raw).decode()) async def wait(self, tx: Transaction, height=-1, timeout=1): + timeout = timeout or 600 # after 10 minutes there is almost 0 hope addresses = set() for txi in tx.inputs: if txi.txo_ref.txo is not None: @@ -648,13 +650,20 @@ class Ledger(metaclass=LedgerRegistry): for txo in tx.outputs: if txo.has_address: addresses.add(self.hash160_to_address(txo.pubkey_hash)) + start = int(time.perf_counter()) + while timeout and (int(time.perf_counter()) - start) <= timeout: + if await self._wait_round(tx, height, addresses): + return + raise asyncio.TimeoutError('Timed out waiting for transaction.') + + async def _wait_round(self, tx: Transaction, height: int, addresses: Iterable[str]): records = await self.db.get_addresses(address__in=addresses) _, pending = await asyncio.wait([ self.on_transaction.where(partial( lambda a, e: a == e.address and e.tx.height >= height and e.tx.id == tx.id, address_record['address'] )) for address_record in records - ], timeout=timeout) + ], timeout=1) if pending: records = await self.db.get_addresses(address__in=addresses) for record in records: @@ -666,8 +675,9 @@ class Ledger(metaclass=LedgerRegistry): if txid == tx.id and local_height >= height: found = True if not found: - print(record['history'], addresses, tx.id) - raise asyncio.TimeoutError('Timed out waiting for transaction.') + log.debug("timeout: %s, %s, %s", record['history'], addresses, tx.id) + return False + return True async def _inflate_outputs( self, query, accounts, @@ -684,14 +694,26 @@ class Ledger(metaclass=LedgerRegistry): self.cache_transaction(*tx) for tx in outputs.txs )) - txos, blocked = outputs.inflate(txs) + _txos, blocked = outputs.inflate(txs) + + txos = [] + for txo in _txos: + if isinstance(txo, Output): + # transactions and outputs are cached and shared between wallets + # we don't want to leak informaion between wallet so we add the + # wallet specific metadata on throw away copies of the txos + txo = copy.copy(txo) + channel = txo.channel + txo.purchase_receipt = None + txo.update_annotations(None) + txo.channel = channel + txos.append(txo) includes = ( include_purchase_receipt, include_is_my_output, include_sent_supports, include_sent_tips ) if accounts and any(includes): - copies = [] receipts = {} if include_purchase_receipt: priced_claims = [] @@ -708,46 +730,38 @@ class Ledger(metaclass=LedgerRegistry): } for txo in txos: if isinstance(txo, Output) and txo.can_decode_claim: - # transactions and outputs are cached and shared between wallets - # we don't want to leak informaion between wallet so we add the - # wallet specific metadata on throw away copies of the txos - txo_copy = copy.copy(txo) - copies.append(txo_copy) if include_purchase_receipt: - txo_copy.purchase_receipt = receipts.get(txo.claim_id) + txo.purchase_receipt = receipts.get(txo.claim_id) if include_is_my_output: mine = await self.db.get_txo_count( claim_id=txo.claim_id, txo_type__in=CLAIM_TYPES, is_my_output=True, is_spent=False, accounts=accounts ) if mine: - txo_copy.is_my_output = True + txo.is_my_output = True else: - txo_copy.is_my_output = False + txo.is_my_output = False if include_sent_supports: supports = await self.db.get_txo_sum( claim_id=txo.claim_id, 
txo_type=TXO_TYPES['support'], is_my_input=True, is_my_output=True, is_spent=False, accounts=accounts ) - txo_copy.sent_supports = supports + txo.sent_supports = supports if include_sent_tips: tips = await self.db.get_txo_sum( claim_id=txo.claim_id, txo_type=TXO_TYPES['support'], is_my_input=True, is_my_output=False, accounts=accounts ) - txo_copy.sent_tips = tips + txo.sent_tips = tips if include_received_tips: tips = await self.db.get_txo_sum( claim_id=txo.claim_id, txo_type=TXO_TYPES['support'], is_my_input=False, is_my_output=True, accounts=accounts ) - txo_copy.received_tips = tips - else: - copies.append(txo) - txos = copies + txo.received_tips = tips return txos, blocked, outputs.offset, outputs.total async def resolve(self, accounts, urls, **kwargs): diff --git a/lbry/wallet/manager.py b/lbry/wallet/manager.py index 20658a2e5..22f0e5d21 100644 --- a/lbry/wallet/manager.py +++ b/lbry/wallet/manager.py @@ -184,6 +184,7 @@ class WalletManager: 'auto_connect': True, 'default_servers': config.lbryum_servers, 'data_path': config.wallet_dir, + 'tx_cache_size': config.transaction_cache_size } wallets_directory = os.path.join(config.wallet_dir, 'wallets') diff --git a/lbry/wallet/network.py b/lbry/wallet/network.py index b117a0164..94ffd5e4f 100644 --- a/lbry/wallet/network.py +++ b/lbry/wallet/network.py @@ -87,7 +87,7 @@ class ClientSession(BaseClientSession): raise except asyncio.CancelledError: log.info("cancelled sending %s to %s:%i", method, *self.server) - self.synchronous_close() + # self.synchronous_close() raise finally: self.pending_amount -= 1 diff --git a/lbry/wallet/rpc/session.py b/lbry/wallet/rpc/session.py index 53c164f4f..58d8a10fc 100644 --- a/lbry/wallet/rpc/session.py +++ b/lbry/wallet/rpc/session.py @@ -33,13 +33,16 @@ from asyncio import Event, CancelledError import logging import time from contextlib import suppress - +from prometheus_client import Counter, Histogram from lbry.wallet.tasks import TaskGroup from .jsonrpc import Request, JSONRPCConnection, JSONRPCv2, JSONRPC, Batch, Notification from .jsonrpc import RPCError, ProtocolError from .framing import BadMagicError, BadChecksumError, OversizedPayloadError, BitcoinFramer, NewlineFramer -from lbry.wallet.server.prometheus import NOTIFICATION_COUNT, RESPONSE_TIMES, REQUEST_ERRORS_COUNT, RESET_CONNECTIONS + +HISTOGRAM_BUCKETS = ( + .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') +) class Connector: @@ -372,10 +375,26 @@ class BatchRequest: raise BatchError(self) +NAMESPACE = "wallet_server" + + class RPCSession(SessionBase): """Base class for protocols where a message can lead to a response, for example JSON RPC.""" + RESPONSE_TIMES = Histogram("response_time", "Response times", namespace=NAMESPACE, + labelnames=("method", "version"), buckets=HISTOGRAM_BUCKETS) + NOTIFICATION_COUNT = Counter("notification", "Number of notifications sent (for subscriptions)", + namespace=NAMESPACE, labelnames=("method", "version")) + REQUEST_ERRORS_COUNT = Counter( + "request_error", "Number of requests that returned errors", namespace=NAMESPACE, + labelnames=("method", "version") + ) + RESET_CONNECTIONS = Counter( + "reset_clients", "Number of reset connections by client version", + namespace=NAMESPACE, labelnames=("version",) + ) + def __init__(self, *, framer=None, loop=None, connection=None): super().__init__(framer=framer, loop=loop) self.connection = connection or self.default_connection() @@ -388,7 +407,7 @@ class RPCSession(SessionBase): except MemoryError: 
self.logger.warning('received oversized message from %s:%s, dropping connection', self._address[0], self._address[1]) - RESET_CONNECTIONS.labels(version=self.client_version).inc() + self.RESET_CONNECTIONS.labels(version=self.client_version).inc() self._close() return @@ -422,7 +441,7 @@ class RPCSession(SessionBase): 'internal server error') if isinstance(request, Request): message = request.send_result(result) - RESPONSE_TIMES.labels( + self.RESPONSE_TIMES.labels( method=request.method, version=self.client_version ).observe(time.perf_counter() - start) @@ -430,7 +449,7 @@ class RPCSession(SessionBase): await self._send_message(message) if isinstance(result, Exception): self._bump_errors() - REQUEST_ERRORS_COUNT.labels( + self.REQUEST_ERRORS_COUNT.labels( method=request.method, version=self.client_version ).inc() @@ -467,7 +486,7 @@ class RPCSession(SessionBase): async def send_notification(self, method, args=()): """Send an RPC notification over the network.""" message = self.connection.send_notification(Notification(method, args)) - NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc() + self.NOTIFICATION_COUNT.labels(method=method, version=self.client_version).inc() await self._send_message(message) def send_batch(self, raise_errors=False): diff --git a/lbry/wallet/server/block_processor.py b/lbry/wallet/server/block_processor.py index 44eba7d1a..69a57a2eb 100644 --- a/lbry/wallet/server/block_processor.py +++ b/lbry/wallet/server/block_processor.py @@ -3,6 +3,7 @@ import asyncio from struct import pack, unpack from concurrent.futures.thread import ThreadPoolExecutor from typing import Optional +from prometheus_client import Gauge, Histogram import lbry from lbry.schema.claim import Claim from lbry.wallet.server.db.writer import SQLDB @@ -10,7 +11,6 @@ from lbry.wallet.server.daemon import DaemonError from lbry.wallet.server.hash import hash_to_hex_str, HASHX_LEN from lbry.wallet.server.util import chunks, class_logger from lbry.wallet.server.leveldb import FlushData -from lbry.wallet.server.prometheus import BLOCK_COUNT, BLOCK_UPDATE_TIMES, REORG_COUNT class Prefetcher: @@ -129,6 +129,12 @@ class ChainError(Exception): """Raised on error processing blocks.""" +NAMESPACE = "wallet_server" +HISTOGRAM_BUCKETS = ( + .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') +) + + class BlockProcessor: """Process blocks and update the DB state to match. @@ -136,6 +142,16 @@ class BlockProcessor: Coordinate backing up in case of chain reorganisations. 
""" + block_count_metric = Gauge( + "block_count", "Number of processed blocks", namespace=NAMESPACE + ) + block_update_time_metric = Histogram( + "block_time", "Block update times", namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS + ) + reorg_count_metric = Gauge( + "reorg_count", "Number of reorgs", namespace=NAMESPACE + ) + def __init__(self, env, db, daemon, notifications): self.env = env self.db = db @@ -199,8 +215,8 @@ class BlockProcessor: cache.clear() await self._maybe_flush() processed_time = time.perf_counter() - start - BLOCK_COUNT.set(self.height) - BLOCK_UPDATE_TIMES.observe(processed_time) + self.block_count_metric.set(self.height) + self.block_update_time_metric.observe(processed_time) if not self.db.first_sync: s = '' if len(blocks) == 1 else 's' self.logger.info('processed {:,d} block{} in {:.1f}s'.format(len(blocks), s, processed_time)) @@ -255,7 +271,7 @@ class BlockProcessor: last -= len(raw_blocks) await self.run_in_thread_with_lock(self.db.sql.delete_claims_above_height, self.height) await self.prefetcher.reset_height(self.height) - REORG_COUNT.inc() + self.reorg_count_metric.inc() async def reorg_hashes(self, count): """Return a pair (start, last, hashes) of blocks to back up during a diff --git a/lbry/wallet/server/daemon.py b/lbry/wallet/server/daemon.py index 960c47024..44a366b6a 100644 --- a/lbry/wallet/server/daemon.py +++ b/lbry/wallet/server/daemon.py @@ -6,11 +6,12 @@ from functools import wraps from pylru import lrucache import aiohttp +from prometheus_client import Gauge, Histogram from lbry.wallet.rpc.jsonrpc import RPCError from lbry.wallet.server.util import hex_to_bytes, class_logger from lbry.wallet.rpc import JSONRPC -from lbry.wallet.server.prometheus import LBRYCRD_REQUEST_TIMES, LBRYCRD_PENDING_COUNT + class DaemonError(Exception): """Raised when the daemon returns an error in its results.""" @@ -24,12 +25,23 @@ class WorkQueueFullError(Exception): """Internal - when the daemon's work queue is full.""" +NAMESPACE = "wallet_server" + + class Daemon: """Handles connections to a daemon at the given URL.""" WARMING_UP = -28 id_counter = itertools.count() + lbrycrd_request_time_metric = Histogram( + "lbrycrd_request", "lbrycrd requests count", namespace=NAMESPACE, labelnames=("method",) + ) + lbrycrd_pending_count_metric = Gauge( + "lbrycrd_pending_count", "Number of lbrycrd rpcs that are in flight", namespace=NAMESPACE, + labelnames=("method",) + ) + def __init__(self, coin, url, max_workqueue=10, init_retry=0.25, max_retry=4.0): self.coin = coin @@ -129,7 +141,7 @@ class Daemon: while True: try: for method in methods: - LBRYCRD_PENDING_COUNT.labels(method=method).inc() + self.lbrycrd_pending_count_metric.labels(method=method).inc() result = await self._send_data(data) result = processor(result) if on_good_message: @@ -154,7 +166,7 @@ class Daemon: on_good_message = 'running normally' finally: for method in methods: - LBRYCRD_PENDING_COUNT.labels(method=method).dec() + self.lbrycrd_pending_count_metric.labels(method=method).dec() await asyncio.sleep(retry) retry = max(min(self.max_retry, retry * 2), self.init_retry) @@ -175,7 +187,7 @@ class Daemon: if params: payload['params'] = params result = await self._send(payload, processor) - LBRYCRD_REQUEST_TIMES.labels(method=method).observe(time.perf_counter() - start) + self.lbrycrd_request_time_metric.labels(method=method).observe(time.perf_counter() - start) return result async def _send_vector(self, method, params_iterable, replace_errs=False): @@ -200,7 +212,7 @@ class Daemon: result = [] if payload: 
result = await self._send(payload, processor) - LBRYCRD_REQUEST_TIMES.labels(method=method).observe(time.perf_counter()-start) + self.lbrycrd_request_time_metric.labels(method=method).observe(time.perf_counter() - start) return result async def _is_rpc_available(self, method): diff --git a/lbry/wallet/server/db/reader.py b/lbry/wallet/server/db/reader.py index a98330efb..ae252355a 100644 --- a/lbry/wallet/server/db/reader.py +++ b/lbry/wallet/server/db/reader.py @@ -547,10 +547,10 @@ def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_coun if for_count or attr == 'tag': if attr == 'tag': any_queries[f'#_any_{attr}'] = f""" - (claim.claim_type != {CLAIM_TYPES['repost']} + ((claim.claim_type != {CLAIM_TYPES['repost']} AND claim.claim_hash IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) OR (claim.claim_type == {CLAIM_TYPES['repost']} AND - claim.reposted_claim_hash IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) + claim.reposted_claim_hash IN (SELECT claim_hash FROM tag WHERE tag IN ({values})))) """ else: any_queries[f'#_any_{attr}'] = f""" @@ -606,10 +606,10 @@ def _apply_constraints_for_array_attributes(constraints, attr, cleaner, for_coun if for_count: if attr == 'tag': constraints[f'#_not_{attr}'] = f""" - (claim.claim_type != {CLAIM_TYPES['repost']} - AND claim.claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) AND + ((claim.claim_type != {CLAIM_TYPES['repost']} + AND claim.claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) OR (claim.claim_type == {CLAIM_TYPES['repost']} AND - claim.reposted_claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values}))) + claim.reposted_claim_hash NOT IN (SELECT claim_hash FROM tag WHERE tag IN ({values})))) """ else: constraints[f'#_not_{attr}'] = f""" diff --git a/lbry/wallet/server/prometheus.py b/lbry/wallet/server/prometheus.py deleted file mode 100644 index e28976bf9..000000000 --- a/lbry/wallet/server/prometheus.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -from aiohttp import web -from prometheus_client import Counter, Info, generate_latest as prom_generate_latest, Histogram, Gauge -from lbry import __version__ as version -from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG -from lbry.wallet.server import util -import lbry.wallet.server.version as wallet_server_version - -NAMESPACE = "wallet_server" -CPU_COUNT = f"{os.cpu_count()}" -VERSION_INFO = Info('build', 'Wallet server build info (e.g. 
version, commit hash)', namespace=NAMESPACE) -VERSION_INFO.info({ - 'build': BUILD, - "commit": COMMIT_HASH, - "docker_tag": DOCKER_TAG, - 'version': version, - "min_version": util.version_string(wallet_server_version.PROTOCOL_MIN), - "cpu_count": CPU_COUNT -}) -SESSIONS_COUNT = Gauge("session_count", "Number of connected client sessions", namespace=NAMESPACE, - labelnames=("version", )) -REQUESTS_COUNT = Counter("requests_count", "Number of requests received", namespace=NAMESPACE, - labelnames=("method", "version")) -RESPONSE_TIMES = Histogram("response_time", "Response times", namespace=NAMESPACE, labelnames=("method", "version")) -NOTIFICATION_COUNT = Counter("notification", "Number of notifications sent (for subscriptions)", - namespace=NAMESPACE, labelnames=("method", "version")) -REQUEST_ERRORS_COUNT = Counter("request_error", "Number of requests that returned errors", namespace=NAMESPACE, - labelnames=("method", "version")) -SQLITE_INTERRUPT_COUNT = Counter("interrupt", "Number of interrupted queries", namespace=NAMESPACE) -SQLITE_OPERATIONAL_ERROR_COUNT = Counter( - "operational_error", "Number of queries that raised operational errors", namespace=NAMESPACE -) -SQLITE_INTERNAL_ERROR_COUNT = Counter( - "internal_error", "Number of queries raising unexpected errors", namespace=NAMESPACE -) -SQLITE_EXECUTOR_TIMES = Histogram("executor_time", "SQLite executor times", namespace=NAMESPACE) -SQLITE_PENDING_COUNT = Gauge( - "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE -) -LBRYCRD_REQUEST_TIMES = Histogram( - "lbrycrd_request", "lbrycrd requests count", namespace=NAMESPACE, labelnames=("method",) -) -LBRYCRD_PENDING_COUNT = Gauge( - "lbrycrd_pending_count", "Number of lbrycrd rpcs that are in flight", namespace=NAMESPACE, labelnames=("method",) -) -CLIENT_VERSIONS = Counter( - "clients", "Number of connections received per client version", - namespace=NAMESPACE, labelnames=("version",) -) -BLOCK_COUNT = Gauge( - "block_count", "Number of processed blocks", namespace=NAMESPACE -) -BLOCK_UPDATE_TIMES = Histogram("block_time", "Block update times", namespace=NAMESPACE) -REORG_COUNT = Gauge( - "reorg_count", "Number of reorgs", namespace=NAMESPACE -) -RESET_CONNECTIONS = Counter( - "reset_clients", "Number of reset connections by client version", - namespace=NAMESPACE, labelnames=("version",) -) - - -class PrometheusServer: - def __init__(self): - self.logger = util.class_logger(__name__, self.__class__.__name__) - self.runner = None - - async def start(self, port: int): - prom_app = web.Application() - prom_app.router.add_get('/metrics', self.handle_metrics_get_request) - self.runner = web.AppRunner(prom_app) - await self.runner.setup() - - metrics_site = web.TCPSite(self.runner, "0.0.0.0", port, shutdown_timeout=.5) - await metrics_site.start() - self.logger.info('metrics server listening on %s:%i', *metrics_site._server.sockets[0].getsockname()[:2]) - - async def handle_metrics_get_request(self, request: web.Request): - try: - return web.Response( - text=prom_generate_latest().decode(), - content_type='text/plain; version=0.0.4' - ) - except Exception: - self.logger.exception('could not generate prometheus data') - raise - - async def stop(self): - await self.runner.cleanup() diff --git a/lbry/wallet/server/server.py b/lbry/wallet/server/server.py index 4d0374ba4..cca84c852 100644 --- a/lbry/wallet/server/server.py +++ b/lbry/wallet/server/server.py @@ -6,7 +6,7 @@ import typing import lbry from lbry.wallet.server.mempool import MemPool, 
MemPoolAPI -from lbry.wallet.server.prometheus import PrometheusServer +from lbry.prometheus import PrometheusServer class Notifications: @@ -143,4 +143,4 @@ class Server: async def start_prometheus(self): if not self.prometheus_server and self.env.prometheus_port: self.prometheus_server = PrometheusServer() - await self.prometheus_server.start(self.env.prometheus_port) + await self.prometheus_server.start("0.0.0.0", self.env.prometheus_port) diff --git a/lbry/wallet/server/session.py b/lbry/wallet/server/session.py index 9a9e23558..186755b73 100644 --- a/lbry/wallet/server/session.py +++ b/lbry/wallet/server/session.py @@ -20,16 +20,15 @@ from functools import partial from binascii import hexlify from pylru import lrucache from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from prometheus_client import Counter, Info, Histogram, Gauge import lbry +from lbry.build_info import BUILD, COMMIT_HASH, DOCKER_TAG from lbry.wallet.server.block_processor import LBRYBlockProcessor from lbry.wallet.server.db.writer import LBRYLevelDB from lbry.wallet.server.db import reader from lbry.wallet.server.websocket import AdminWebSocket from lbry.wallet.server.metrics import ServerLoadData, APICallMetrics -from lbry.wallet.server.prometheus import REQUESTS_COUNT, SQLITE_INTERRUPT_COUNT, SQLITE_INTERNAL_ERROR_COUNT -from lbry.wallet.server.prometheus import SQLITE_OPERATIONAL_ERROR_COUNT, SQLITE_EXECUTOR_TIMES, SESSIONS_COUNT -from lbry.wallet.server.prometheus import SQLITE_PENDING_COUNT, CLIENT_VERSIONS from lbry.wallet.rpc.framing import NewlineFramer import lbry.wallet.server.version as VERSION @@ -119,9 +118,49 @@ class SessionGroup: self.semaphore = asyncio.Semaphore(20) +NAMESPACE = "wallet_server" +HISTOGRAM_BUCKETS = ( + .005, .01, .025, .05, .075, .1, .25, .5, .75, 1.0, 2.5, 5.0, 7.5, 10.0, 15.0, 20.0, 30.0, 60.0, float('inf') +) + class SessionManager: """Holds global state about all sessions.""" + version_info_metric = Info( + 'build', 'Wallet server build info (e.g. 
version, commit hash)', namespace=NAMESPACE + ) + version_info_metric.info({ + 'build': BUILD, + "commit": COMMIT_HASH, + "docker_tag": DOCKER_TAG, + 'version': lbry.__version__, + "min_version": util.version_string(VERSION.PROTOCOL_MIN), + "cpu_count": os.cpu_count() + }) + session_count_metric = Gauge("session_count", "Number of connected client sessions", namespace=NAMESPACE, + labelnames=("version",)) + request_count_metric = Counter("requests_count", "Number of requests received", namespace=NAMESPACE, + labelnames=("method", "version")) + + interrupt_count_metric = Counter("interrupt", "Number of interrupted queries", namespace=NAMESPACE) + db_operational_error_metric = Counter( + "operational_error", "Number of queries that raised operational errors", namespace=NAMESPACE + ) + db_error_metric = Counter( + "internal_error", "Number of queries raising unexpected errors", namespace=NAMESPACE + ) + executor_time_metric = Histogram( + "executor_time", "SQLite executor times", namespace=NAMESPACE, buckets=HISTOGRAM_BUCKETS + ) + pending_query_metric = Gauge( + "pending_queries_count", "Number of pending and running sqlite queries", namespace=NAMESPACE + ) + + client_version_metric = Counter( + "clients", "Number of connections received per client version", + namespace=NAMESPACE, labelnames=("version",) + ) + def __init__(self, env: 'Env', db: LBRYLevelDB, bp: LBRYBlockProcessor, daemon: 'Daemon', mempool: 'MemPool', shutdown_event: asyncio.Event): env.max_send = max(350000, env.max_send) @@ -677,7 +716,7 @@ class SessionBase(RPCSession): context = {'conn_id': f'{self.session_id}'} self.logger = util.ConnectionLogger(self.logger, context) self.group = self.session_mgr.add_session(self) - SESSIONS_COUNT.labels(version=self.client_version).inc() + self.session_mgr.session_count_metric.labels(version=self.client_version).inc() peer_addr_str = self.peer_address_str() self.logger.info(f'{self.kind} {peer_addr_str}, ' f'{self.session_mgr.session_count():,d} total') @@ -686,7 +725,7 @@ class SessionBase(RPCSession): """Handle client disconnection.""" super().connection_lost(exc) self.session_mgr.remove_session(self) - SESSIONS_COUNT.labels(version=self.client_version).dec() + self.session_mgr.session_count_metric.labels(version=self.client_version).dec() msg = '' if not self._can_send.is_set(): msg += ' whilst paused' @@ -710,7 +749,7 @@ class SessionBase(RPCSession): """Handle an incoming request. ElectrumX doesn't receive notifications from client sessions. 
""" - REQUESTS_COUNT.labels(method=request.method, version=self.client_version).inc() + self.session_mgr.request_count_metric.labels(method=request.method, version=self.client_version).inc() if isinstance(request, Request): handler = self.request_handlers.get(request.method) handler = partial(handler, self) @@ -946,7 +985,7 @@ class LBRYElectrumX(SessionBase): async def run_in_executor(self, query_name, func, kwargs): start = time.perf_counter() try: - SQLITE_PENDING_COUNT.inc() + self.session_mgr.pending_query_metric.inc() result = await asyncio.get_running_loop().run_in_executor( self.session_mgr.query_executor, func, kwargs ) @@ -955,18 +994,18 @@ class LBRYElectrumX(SessionBase): except reader.SQLiteInterruptedError as error: metrics = self.get_metrics_or_placeholder_for_api(query_name) metrics.query_interrupt(start, error.metrics) - SQLITE_INTERRUPT_COUNT.inc() + self.session_mgr.interrupt_count_metric.inc() raise RPCError(JSONRPC.QUERY_TIMEOUT, 'sqlite query timed out') except reader.SQLiteOperationalError as error: metrics = self.get_metrics_or_placeholder_for_api(query_name) metrics.query_error(start, error.metrics) - SQLITE_OPERATIONAL_ERROR_COUNT.inc() + self.session_mgr.db_operational_error_metric.inc() raise RPCError(JSONRPC.INTERNAL_ERROR, 'query failed to execute') except Exception: log.exception("dear devs, please handle this exception better") metrics = self.get_metrics_or_placeholder_for_api(query_name) metrics.query_error(start, {}) - SQLITE_INTERNAL_ERROR_COUNT.inc() + self.session_mgr.db_error_metric.inc() raise RPCError(JSONRPC.INTERNAL_ERROR, 'unknown server error') else: if self.env.track_metrics: @@ -975,8 +1014,8 @@ class LBRYElectrumX(SessionBase): metrics.query_response(start, metrics_data) return base64.b64encode(result).decode() finally: - SQLITE_PENDING_COUNT.dec() - SQLITE_EXECUTOR_TIMES.observe(time.perf_counter() - start) + self.session_mgr.pending_query_metric.dec() + self.session_mgr.executor_time_metric.observe(time.perf_counter() - start) async def run_and_cache_query(self, query_name, function, kwargs): metrics = self.get_metrics_or_placeholder_for_api(query_name) @@ -1181,7 +1220,10 @@ class LBRYElectrumX(SessionBase): return await self.address_status(hashX) async def hashX_unsubscribe(self, hashX, alias): - del self.hashX_subs[hashX] + try: + del self.hashX_subs[hashX] + except ValueError: + pass def address_to_hashX(self, address): try: @@ -1443,10 +1485,10 @@ class LBRYElectrumX(SessionBase): raise RPCError(BAD_REQUEST, f'unsupported client: {client_name}') if self.client_version != client_name[:17]: - SESSIONS_COUNT.labels(version=self.client_version).dec() + self.session_mgr.session_count_metric.labels(version=self.client_version).dec() self.client_version = client_name[:17] - SESSIONS_COUNT.labels(version=self.client_version).inc() - CLIENT_VERSIONS.labels(version=self.client_version).inc() + self.session_mgr.session_count_metric.labels(version=self.client_version).inc() + self.session_mgr.client_version_metric.labels(version=self.client_version).inc() # Find the highest common protocol version. Disconnect if # that protocol version in unsupported. 
diff --git a/lbry/wallet/transaction.py b/lbry/wallet/transaction.py index df9beddcf..0f32ffdd3 100644 --- a/lbry/wallet/transaction.py +++ b/lbry/wallet/transaction.py @@ -2,6 +2,7 @@ import struct import hashlib import logging import typing +import asyncio from binascii import hexlify, unhexlify from typing import List, Iterable, Optional, Tuple @@ -412,8 +413,10 @@ class Output(InputOutput): self.channel = None self.claim.clear_signature() - def generate_channel_private_key(self): - self.private_key = ecdsa.SigningKey.generate(curve=ecdsa.SECP256k1, hashfunc=hashlib.sha256) + async def generate_channel_private_key(self): + self.private_key = await asyncio.get_event_loop().run_in_executor( + None, ecdsa.SigningKey.generate, ecdsa.SECP256k1, None, hashlib.sha256 + ) self.claim.channel.public_key_bytes = self.private_key.get_verifying_key().to_der() self.script.generate() return self.private_key diff --git a/scripts/download_blob_from_peer.py b/scripts/download_blob_from_peer.py index 08f046abc..3418ab1f8 100644 --- a/scripts/download_blob_from_peer.py +++ b/scripts/download_blob_from_peer.py @@ -3,6 +3,7 @@ import os import asyncio import socket import ipaddress +import lbry.wallet from lbry.conf import Config from lbry.extras.daemon.storage import SQLiteStorage from lbry.blob.blob_manager import BlobManager diff --git a/setup.cfg b/setup.cfg index c5d268dbb..e8bc1920b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,7 +6,7 @@ source = lbry .tox/*/lib/python*/site-packages/lbry -[cryptography.*,coincurve.*,pbkdf2] +[cryptography.*,coincurve.*,pbkdf2, libtorrent] ignore_missing_imports = True [pylint] @@ -18,6 +18,7 @@ max-line-length=120 good-names=T,t,n,i,j,k,x,y,s,f,d,h,c,e,op,db,tx,io,cachedproperty,log,id,r,iv,ts,l valid-metaclass-classmethod-first-arg=mcs disable= + c-extension-no-member, fixme, broad-except, no-else-return, diff --git a/tests/integration/blockchain/test_blockchain_reorganization.py b/tests/integration/blockchain/test_blockchain_reorganization.py index 216030839..af9349e67 100644 --- a/tests/integration/blockchain/test_blockchain_reorganization.py +++ b/tests/integration/blockchain/test_blockchain_reorganization.py @@ -2,7 +2,6 @@ import logging import asyncio from binascii import hexlify from lbry.testcase import CommandTestCase -from lbry.wallet.server.prometheus import REORG_COUNT class BlockchainReorganizationTests(CommandTestCase): @@ -16,7 +15,8 @@ class BlockchainReorganizationTests(CommandTestCase): ) async def test_reorg(self): - REORG_COUNT.set(0) + bp = self.conductor.spv_node.server.bp + bp.reorg_count_metric.set(0) # invalidate current block, move forward 2 self.assertEqual(self.ledger.headers.height, 206) await self.assertBlockHash(206) @@ -26,7 +26,7 @@ class BlockchainReorganizationTests(CommandTestCase): self.assertEqual(self.ledger.headers.height, 207) await self.assertBlockHash(206) await self.assertBlockHash(207) - self.assertEqual(1, REORG_COUNT._samples()[0][2]) + self.assertEqual(1, bp.reorg_count_metric._samples()[0][2]) # invalidate current block, move forward 3 await self.blockchain.invalidate_block((await self.ledger.headers.hash(206)).decode()) @@ -36,7 +36,7 @@ class BlockchainReorganizationTests(CommandTestCase): await self.assertBlockHash(206) await self.assertBlockHash(207) await self.assertBlockHash(208) - self.assertEqual(2, REORG_COUNT._samples()[0][2]) + self.assertEqual(2, bp.reorg_count_metric._samples()[0][2]) async def test_reorg_change_claim_height(self): # sanity check diff --git 
a/tests/integration/blockchain/test_claim_commands.py b/tests/integration/blockchain/test_claim_commands.py index eda369674..4873e0cde 100644 --- a/tests/integration/blockchain/test_claim_commands.py +++ b/tests/integration/blockchain/test_claim_commands.py @@ -661,12 +661,15 @@ class ClaimCommands(ClaimTestCase): channel_id = self.get_claim_id(await self.channel_create()) stream_id = self.get_claim_id(await self.stream_create()) + await self.stream_update(stream_id, title='foo') + # type filtering r = await self.claim_list(claim_type='channel') self.assertEqual(1, len(r)) self.assertEqual('channel', r[0]['value_type']) - r = await self.claim_list(claim_type='stream') + # catch a bug where cli sends is_spent=False by default + r = await self.claim_list(claim_type='stream', is_spent=False) self.assertEqual(1, len(r)) self.assertEqual('stream', r[0]['value_type']) diff --git a/tests/integration/blockchain/test_internal_transaction_api.py b/tests/integration/blockchain/test_internal_transaction_api.py index 2bb4ac944..6eba5e229 100644 --- a/tests/integration/blockchain/test_internal_transaction_api.py +++ b/tests/integration/blockchain/test_internal_transaction_api.py @@ -31,7 +31,7 @@ class BasicTransactionTest(IntegrationTestCase): channel_txo = Output.pay_claim_name_pubkey_hash( l2d('1.0'), '@bar', channel, self.account.ledger.address_to_hash160(address1) ) - channel_txo.generate_channel_private_key() + await channel_txo.generate_channel_private_key() channel_txo.script.generate() channel_tx = await Transaction.create([], [channel_txo], [self.account], self.account) diff --git a/tests/integration/datanetwork/test_file_commands.py b/tests/integration/datanetwork/test_file_commands.py index 1ab06d088..df46c6fab 100644 --- a/tests/integration/datanetwork/test_file_commands.py +++ b/tests/integration/datanetwork/test_file_commands.py @@ -2,10 +2,60 @@ import asyncio import os from binascii import hexlify +from lbry.schema import Claim from lbry.testcase import CommandTestCase +from lbry.torrent.session import TorrentSession +from lbry.wallet import Transaction class FileCommands(CommandTestCase): + async def initialize_torrent(self, tx_to_update=None): + if not hasattr(self, 'seeder_session'): + self.seeder_session = TorrentSession(self.loop, None) + self.addCleanup(self.seeder_session.stop) + await self.seeder_session.bind(port=4040) + btih = await self.seeder_session.add_fake_torrent() + address = await self.account.receiving.get_or_create_usable_address() + if not tx_to_update: + claim = Claim() + claim.stream.update(bt_infohash=btih) + tx = await Transaction.claim_create( + 'torrent', claim, 1, address, [self.account], self.account + ) + else: + claim = tx_to_update.outputs[0].claim + claim.stream.update(bt_infohash=btih) + tx = await Transaction.claim_update( + tx_to_update.outputs[0], claim, 1, address, [self.account], self.account + ) + await tx.sign([self.account]) + await self.broadcast(tx) + await self.confirm_tx(tx.id) + self.client_session = self.daemon.file_manager.source_managers['torrent'].torrent_session + self.client_session._session.add_dht_node(('localhost', 4040)) + self.client_session.wait_start = False # fixme: this is super slow on tests + return tx, btih + + async def test_download_torrent(self): + tx, btih = await self.initialize_torrent() + self.assertNotIn('error', await self.out(self.daemon.jsonrpc_get('torrent'))) + self.assertItemCount(await self.daemon.jsonrpc_file_list(), 1) + # second call, see its there and move on + self.assertNotIn('error', await 
self.out(self.daemon.jsonrpc_get('torrent'))) + self.assertItemCount(await self.daemon.jsonrpc_file_list(), 1) + self.assertEqual((await self.daemon.jsonrpc_file_list())['items'][0].identifier, btih) + self.assertIn(btih, self.client_session._handles) + tx, new_btih = await self.initialize_torrent(tx) + self.assertNotEqual(btih, new_btih) + # claim now points to another torrent, update to it + self.assertNotIn('error', await self.out(self.daemon.jsonrpc_get('torrent'))) + self.assertEqual((await self.daemon.jsonrpc_file_list())['items'][0].identifier, new_btih) + self.assertIn(new_btih, self.client_session._handles) + self.assertNotIn(btih, self.client_session._handles) + self.assertItemCount(await self.daemon.jsonrpc_file_list(), 1) + await self.daemon.jsonrpc_file_delete(delete_all=True) + self.assertItemCount(await self.daemon.jsonrpc_file_list(), 0) + self.assertNotIn(new_btih, self.client_session._handles) async def create_streams_in_range(self, *args, **kwargs): self.stream_claim_ids = [] @@ -228,11 +278,11 @@ class FileCommands(CommandTestCase): await self.daemon.jsonrpc_get('lbry://foo') with open(original_path, 'wb') as handle: handle.write(b'some other stuff was there instead') - self.daemon.stream_manager.stop() - await self.daemon.stream_manager.start() + self.daemon.file_manager.stop() + await self.daemon.file_manager.start() await asyncio.wait_for(self.wait_files_to_complete(), timeout=5) # if this hangs, file didn't get set completed # check that internal state got through up to the file list API - stream = self.daemon.stream_manager.get_stream_by_stream_hash(file_info['stream_hash']) + stream = self.daemon.file_manager.get_filtered(stream_hash=file_info['stream_hash'])[0] file_info = (await self.file_list())[0] self.assertEqual(stream.file_name, file_info['file_name']) # checks if what the API shows is what he have at the very internal level. 
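The file-command hunks above move the tests from the old stream_manager API onto the new file_manager facade. A minimal sketch of that pattern, assuming a daemon fixture like the one used in these tests; the helper name and arguments below are illustrative, not part of the patch:

async def restart_and_lookup(daemon, stream_hash):
    # restart the new file manager rather than the old stream manager
    daemon.file_manager.stop()
    await daemon.file_manager.start()
    # the stream lookup moves from a dedicated accessor to a generic filter:
    #   old: daemon.stream_manager.get_stream_by_stream_hash(stream_hash)
    return daemon.file_manager.get_filtered(stream_hash=stream_hash)[0]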
@@ -255,7 +305,7 @@ class FileCommands(CommandTestCase): resp = await self.out(self.daemon.jsonrpc_get('lbry://foo', timeout=2)) self.assertNotIn('error', resp) self.assertTrue(os.path.isfile(path)) - self.daemon.stream_manager.stop() + self.daemon.file_manager.stop() await asyncio.sleep(0.01, loop=self.loop) # FIXME: this sleep should not be needed self.assertFalse(os.path.isfile(path)) @@ -348,8 +398,8 @@ class FileCommands(CommandTestCase): # restart the daemon and make sure the fee is still there - self.daemon.stream_manager.stop() - await self.daemon.stream_manager.start() + self.daemon.file_manager.stop() + await self.daemon.file_manager.start() self.assertItemCount(await self.daemon.jsonrpc_file_list(), 1) self.assertEqual((await self.daemon.jsonrpc_file_list())['items'][0].content_fee.raw, raw_content_fee) await self.daemon.jsonrpc_file_delete(claim_name='icanpay') diff --git a/tests/integration/datanetwork/test_streaming.py b/tests/integration/datanetwork/test_streaming.py index e6d572e94..856a3c090 100644 --- a/tests/integration/datanetwork/test_streaming.py +++ b/tests/integration/datanetwork/test_streaming.py @@ -21,8 +21,8 @@ def get_random_bytes(n: int) -> bytes: class RangeRequests(CommandTestCase): async def _restart_stream_manager(self): - self.daemon.stream_manager.stop() - await self.daemon.stream_manager.start() + self.daemon.file_manager.stop() + await self.daemon.file_manager.start() return async def _setup_stream(self, data: bytes, save_blobs: bool = True, save_files: bool = False, file_size=0): diff --git a/tests/integration/other/test_cli.py b/tests/integration/other/test_cli.py index 59b629747..459d2171a 100644 --- a/tests/integration/other/test_cli.py +++ b/tests/integration/other/test_cli.py @@ -6,7 +6,7 @@ from lbry.conf import Config from lbry.extras import cli from lbry.extras.daemon.components import ( DATABASE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, - HASH_ANNOUNCER_COMPONENT, STREAM_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, + HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT ) from lbry.extras.daemon.daemon import Daemon @@ -21,7 +21,7 @@ class CLIIntegrationTest(AsyncioTestCase): conf.api = 'localhost:5299' conf.components_to_skip = ( DATABASE_COMPONENT, BLOB_COMPONENT, WALLET_COMPONENT, DHT_COMPONENT, - HASH_ANNOUNCER_COMPONENT, STREAM_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, + HASH_ANNOUNCER_COMPONENT, FILE_MANAGER_COMPONENT, PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT, WALLET_SERVER_PAYMENTS_COMPONENT ) Daemon.component_attributes = {} @@ -34,4 +34,4 @@ class CLIIntegrationTest(AsyncioTestCase): with contextlib.redirect_stdout(actual_output): cli.main(["--api", "localhost:5299", "status"]) actual_output = actual_output.getvalue() - self.assertIn("connection_status", actual_output) \ No newline at end of file + self.assertIn("connection_status", actual_output) diff --git a/tests/unit/blob/test_blob_manager.py b/tests/unit/blob/test_blob_manager.py index c868890f1..788ab0953 100644 --- a/tests/unit/blob/test_blob_manager.py +++ b/tests/unit/blob/test_blob_manager.py @@ -16,7 +16,7 @@ class TestBlobManager(AsyncioTestCase): self.blob_manager = BlobManager(self.loop, tmp_dir, self.storage, self.config) await self.storage.open() - async def test_memory_blobs_arent_verifie_but_real_ones_are(self): + async def test_memory_blobs_arent_verified_but_real_ones_are(self): for 
save_blobs in (False, True): await self.setup_blob_manager(save_blobs=save_blobs) # add a blob file @@ -24,6 +24,7 @@ class TestBlobManager(AsyncioTestCase): blob_bytes = b'1' * ((2 * 2 ** 20) - 1) blob = self.blob_manager.get_blob(blob_hash, len(blob_bytes)) blob.save_verified_blob(blob_bytes) + await blob.verified.wait() self.assertTrue(blob.get_is_verified()) self.blob_manager.blob_completed(blob) self.assertEqual(self.blob_manager.is_blob_verified(blob_hash), save_blobs) diff --git a/tests/unit/blob_exchange/test_transfer_blob.py b/tests/unit/blob_exchange/test_transfer_blob.py index f7c011e3b..b6339f375 100644 --- a/tests/unit/blob_exchange/test_transfer_blob.py +++ b/tests/unit/blob_exchange/test_transfer_blob.py @@ -34,13 +34,13 @@ class BlobExchangeTestBase(AsyncioTestCase): self.addCleanup(shutil.rmtree, self.client_dir) self.addCleanup(shutil.rmtree, self.server_dir) self.server_config = Config(data_dir=self.server_dir, download_dir=self.server_dir, wallet=self.server_dir, - reflector_servers=[]) + fixed_peers=[]) self.server_storage = SQLiteStorage(self.server_config, os.path.join(self.server_dir, "lbrynet.sqlite")) self.server_blob_manager = BlobManager(self.loop, self.server_dir, self.server_storage, self.server_config) self.server = BlobServer(self.loop, self.server_blob_manager, 'bQEaw42GXsgCAGio1nxFncJSyRmnztSCjP') self.client_config = Config(data_dir=self.client_dir, download_dir=self.client_dir, wallet=self.client_dir, - reflector_servers=[]) + fixed_peers=[]) self.client_storage = SQLiteStorage(self.client_config, os.path.join(self.client_dir, "lbrynet.sqlite")) self.client_blob_manager = BlobManager(self.loop, self.client_dir, self.client_storage, self.client_config) self.client_peer_manager = PeerManager(self.loop) @@ -130,10 +130,14 @@ class TestBlobExchange(BlobExchangeTestBase): write_blob = blob._write_blob write_called_count = 0 - def wrap_write_blob(blob_bytes): + async def _wrap_write_blob(blob_bytes): nonlocal write_called_count write_called_count += 1 - write_blob(blob_bytes) + await write_blob(blob_bytes) + + def wrap_write_blob(blob_bytes): + return asyncio.create_task(_wrap_write_blob(blob_bytes)) + blob._write_blob = wrap_write_blob writer1 = blob.get_blob_writer(peer_port=1) @@ -166,6 +170,7 @@ class TestBlobExchange(BlobExchangeTestBase): self.assertDictEqual({1: mock_blob_bytes, 2: mock_blob_bytes}, results) self.assertEqual(1, write_called_count) + await blob.verified.wait() self.assertTrue(blob.get_is_verified()) self.assertDictEqual({}, blob.writers) diff --git a/tests/unit/comments/test_comment_signing.py b/tests/unit/comments/test_comment_signing.py index 9cdfd3d69..ceee8a8a2 100644 --- a/tests/unit/comments/test_comment_signing.py +++ b/tests/unit/comments/test_comment_signing.py @@ -18,31 +18,31 @@ class TestSigningComments(AsyncioTestCase): 'comment_id': hashlib.sha256(comment.encode()).hexdigest() } - def test01_successful_create_sign_and_validate_comment(self): - channel = get_channel('@BusterBluth') + async def test01_successful_create_sign_and_validate_comment(self): + channel = await get_channel('@BusterBluth') stream = get_stream('pop secret') comment = self.create_claim_comment_body('Cool stream', stream, channel) sign_comment(comment, channel) self.assertTrue(is_comment_signed_by_channel(comment, channel)) - def test02_fail_to_validate_spoofed_channel(self): - pdiddy = get_channel('@PDitty') - channel2 = get_channel('@TomHaverford') + async def test02_fail_to_validate_spoofed_channel(self): + pdiddy = await get_channel('@PDitty') + 
channel2 = await get_channel('@TomHaverford') stream = get_stream() comment = self.create_claim_comment_body('Woahh This is Sick!! Shout out 2 my boy Tommy H', stream, pdiddy) sign_comment(comment, channel2) self.assertFalse(is_comment_signed_by_channel(comment, pdiddy)) - def test03_successful_sign_abandon_comment(self): - rswanson = get_channel('@RonSwanson') + async def test03_successful_sign_abandon_comment(self): + rswanson = await get_channel('@RonSwanson') dsilver = get_stream('Welcome to the Pawnee, and give a big round for Ron Swanson, AKA Duke Silver') comment_body = self.create_claim_comment_body('COMPUTER, DELETE ALL VIDEOS OF RON.', dsilver, rswanson) sign_comment(comment_body, rswanson, abandon=True) self.assertTrue(is_comment_signed_by_channel(comment_body, rswanson, abandon=True)) - def test04_invalid_signature(self): - rswanson = get_channel('@RonSwanson') - jeanralphio = get_channel('@JeanRalphio') + async def test04_invalid_signature(self): + rswanson = await get_channel('@RonSwanson') + jeanralphio = await get_channel('@JeanRalphio') chair = get_stream('This is a nice chair. I made it with Mahogany wood and this electric saw') chair_comment = self.create_claim_comment_body( 'Hah. You use an electric saw? Us swansons have been making chairs with handsaws just three after birth.', diff --git a/tests/unit/components/test_component_manager.py b/tests/unit/components/test_component_manager.py index d8d2ed5a9..48d844168 100644 --- a/tests/unit/components/test_component_manager.py +++ b/tests/unit/components/test_component_manager.py @@ -16,6 +16,7 @@ class TestComponentManager(AsyncioTestCase): [ components.DatabaseComponent, components.ExchangeRateManagerComponent, + components.TorrentComponent, components.UPnPComponent ], [ @@ -24,9 +25,9 @@ class TestComponentManager(AsyncioTestCase): components.WalletComponent ], [ + components.FileManagerComponent, components.HashAnnouncerComponent, components.PeerProtocolServerComponent, - components.StreamManagerComponent, components.WalletServerPaymentsComponent ] ] @@ -119,6 +120,8 @@ class FakeComponent: class FakeDelayedWallet(FakeComponent): component_name = "wallet" depends_on = [] + ledger = None + default_wallet = None async def stop(self): await asyncio.sleep(1) @@ -135,8 +138,8 @@ class FakeDelayedBlobManager(FakeComponent): await asyncio.sleep(1) -class FakeDelayedStreamManager(FakeComponent): - component_name = "stream_manager" +class FakeDelayedFileManager(FakeComponent): + component_name = "file_manager" depends_on = [FakeDelayedBlobManager.component_name] async def start(self): @@ -153,7 +156,7 @@ class TestComponentManagerProperStart(AdvanceTimeTestCase): PEER_PROTOCOL_SERVER_COMPONENT, UPNP_COMPONENT, EXCHANGE_RATE_MANAGER_COMPONENT], wallet=FakeDelayedWallet, - stream_manager=FakeDelayedStreamManager, + file_manager=FakeDelayedFileManager, blob_manager=FakeDelayedBlobManager ) @@ -163,17 +166,17 @@ class TestComponentManagerProperStart(AdvanceTimeTestCase): await self.advance(0) self.assertTrue(self.component_manager.get_component('wallet').running) self.assertFalse(self.component_manager.get_component('blob_manager').running) - self.assertFalse(self.component_manager.get_component('stream_manager').running) + self.assertFalse(self.component_manager.get_component('file_manager').running) await self.advance(1) self.assertTrue(self.component_manager.get_component('wallet').running) self.assertTrue(self.component_manager.get_component('blob_manager').running) - 
self.assertFalse(self.component_manager.get_component('stream_manager').running) + self.assertFalse(self.component_manager.get_component('file_manager').running) await self.advance(1) self.assertTrue(self.component_manager.get_component('wallet').running) self.assertTrue(self.component_manager.get_component('blob_manager').running) - self.assertTrue(self.component_manager.get_component('stream_manager').running) + self.assertTrue(self.component_manager.get_component('file_manager').running) async def test_proper_stopping_of_components(self): asyncio.create_task(self.component_manager.start()) @@ -182,18 +185,18 @@ class TestComponentManagerProperStart(AdvanceTimeTestCase): await self.advance(1) self.assertTrue(self.component_manager.get_component('wallet').running) self.assertTrue(self.component_manager.get_component('blob_manager').running) - self.assertTrue(self.component_manager.get_component('stream_manager').running) + self.assertTrue(self.component_manager.get_component('file_manager').running) asyncio.create_task(self.component_manager.stop()) await self.advance(0) - self.assertFalse(self.component_manager.get_component('stream_manager').running) + self.assertFalse(self.component_manager.get_component('file_manager').running) self.assertTrue(self.component_manager.get_component('blob_manager').running) self.assertTrue(self.component_manager.get_component('wallet').running) await self.advance(1) - self.assertFalse(self.component_manager.get_component('stream_manager').running) + self.assertFalse(self.component_manager.get_component('file_manager').running) self.assertFalse(self.component_manager.get_component('blob_manager').running) self.assertTrue(self.component_manager.get_component('wallet').running) await self.advance(1) - self.assertFalse(self.component_manager.get_component('stream_manager').running) + self.assertFalse(self.component_manager.get_component('file_manager').running) self.assertFalse(self.component_manager.get_component('blob_manager').running) self.assertFalse(self.component_manager.get_component('wallet').running) diff --git a/tests/unit/stream/test_managed_stream.py b/tests/unit/stream/test_managed_stream.py index dbdfa5157..3542c60e4 100644 --- a/tests/unit/stream/test_managed_stream.py +++ b/tests/unit/stream/test_managed_stream.py @@ -76,7 +76,8 @@ class TestManagedStream(BlobExchangeTestBase): return q2, self.loop.create_task(_task()) mock_node.accumulate_peers = mock_accumulate_peers or _mock_accumulate_peers - await self.stream.save_file(node=mock_node) + self.stream.downloader.node = mock_node + await self.stream.save_file() await self.stream.finished_write_attempt.wait() self.assertTrue(os.path.isfile(self.stream.full_path)) if stop_when_done: @@ -109,7 +110,6 @@ class TestManagedStream(BlobExchangeTestBase): await self.setup_stream(2) mock_node = mock.Mock(spec=Node) - q = asyncio.Queue() bad_peer = make_kademlia_peer(b'2' * 48, "127.0.0.1", tcp_port=3334, allow_localhost=True) @@ -123,7 +123,8 @@ class TestManagedStream(BlobExchangeTestBase): mock_node.accumulate_peers = _mock_accumulate_peers - await self.stream.save_file(node=mock_node) + self.stream.downloader.node = mock_node + await self.stream.save_file() await self.stream.finished_writing.wait() self.assertTrue(os.path.isfile(self.stream.full_path)) with open(self.stream.full_path, 'rb') as f: diff --git a/tests/unit/stream/test_reflector.py b/tests/unit/stream/test_reflector.py index 4845948d1..b47cf31d0 100644 --- a/tests/unit/stream/test_reflector.py +++ b/tests/unit/stream/test_reflector.py 
@@ -39,7 +39,7 @@ class TestStreamAssembler(AsyncioTestCase): with open(file_path, 'wb') as f: f.write(self.cleartext) - self.stream = await self.stream_manager.create_stream(file_path) + self.stream = await self.stream_manager.create(file_path) async def _test_reflect_stream(self, response_chunk_size): reflector = ReflectorServer(self.server_blob_manager, response_chunk_size=response_chunk_size) diff --git a/tests/unit/stream/test_stream_manager.py b/tests/unit/stream/test_stream_manager.py index e33064503..8c17f61e4 100644 --- a/tests/unit/stream/test_stream_manager.py +++ b/tests/unit/stream/test_stream_manager.py @@ -5,6 +5,8 @@ from unittest import mock import asyncio import json from decimal import Decimal + +from lbry.file.file_manager import FileManager from tests.unit.blob_exchange.test_transfer_blob import BlobExchangeTestBase from lbry.testcase import get_fake_exchange_rate_manager from lbry.utils import generate_id @@ -110,10 +112,7 @@ async def get_mock_wallet(sd_hash, storage, balance=10.0, fee=None): async def mock_resolve(*args, **kwargs): result = {txo.meta['permanent_url']: txo} - claims = [ - StreamManager._convert_to_old_resolve_output(manager, result)[txo.meta['permanent_url']] - ] - await storage.save_claims(claims) + await storage.save_claim_from_output(ledger, txo) return result manager.ledger.resolve = mock_resolve @@ -138,11 +137,20 @@ class TestStreamManager(BlobExchangeTestBase): ) self.sd_hash = descriptor.sd_hash self.mock_wallet, self.uri = await get_mock_wallet(self.sd_hash, self.client_storage, balance, fee) - self.stream_manager = StreamManager(self.loop, self.client_config, self.client_blob_manager, self.mock_wallet, - self.client_storage, get_mock_node(self.server_from_client), - AnalyticsManager(self.client_config, - binascii.hexlify(generate_id()).decode(), - binascii.hexlify(generate_id()).decode())) + analytics_manager = AnalyticsManager( + self.client_config, + binascii.hexlify(generate_id()).decode(), + binascii.hexlify(generate_id()).decode() + ) + self.stream_manager = StreamManager( + self.loop, self.client_config, self.client_blob_manager, self.mock_wallet, + self.client_storage, get_mock_node(self.server_from_client), + analytics_manager + ) + self.file_manager = FileManager( + self.loop, self.client_config, self.mock_wallet, self.client_storage, analytics_manager + ) + self.file_manager.source_managers['stream'] = self.stream_manager self.exchange_rate_manager = get_fake_exchange_rate_manager() async def _test_time_to_first_bytes(self, check_post, error=None, after_setup=None): @@ -159,9 +167,9 @@ class TestStreamManager(BlobExchangeTestBase): self.stream_manager.analytics_manager._post = _check_post if error: with self.assertRaises(error): - await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) else: - await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) await asyncio.sleep(0, loop=self.loop) self.assertTrue(checked_analytics_event) @@ -179,7 +187,7 @@ class TestStreamManager(BlobExchangeTestBase): await self._test_time_to_first_bytes(check_post) async def test_fixed_peer_delay_dht_peers_found(self): - self.client_config.reflector_servers = [(self.server_from_client.address, self.server_from_client.tcp_port - 1)] + self.client_config.fixed_peers = [(self.server_from_client.address, 
self.server_from_client.tcp_port)] server_from_client = None self.server_from_client, server_from_client = server_from_client, self.server_from_client @@ -223,7 +231,7 @@ class TestStreamManager(BlobExchangeTestBase): await self._test_time_to_first_bytes(check_post, DownloadSDTimeoutError, after_setup=after_setup) async def test_override_fixed_peer_delay_dht_disabled(self): - self.client_config.reflector_servers = [(self.server_from_client.address, self.server_from_client.tcp_port - 1)] + self.client_config.fixed_peers = [(self.server_from_client.address, self.server_from_client.tcp_port)] self.client_config.components_to_skip = ['dht', 'hash_announcer'] self.client_config.fixed_peer_delay = 9001.0 self.server_from_client = None @@ -281,7 +289,7 @@ class TestStreamManager(BlobExchangeTestBase): self.stream_manager.analytics_manager._post = check_post self.assertDictEqual(self.stream_manager.streams, {}) - stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) stream_hash = stream.stream_hash self.assertDictEqual(self.stream_manager.streams, {stream.sd_hash: stream}) self.assertTrue(stream.running) @@ -302,7 +310,8 @@ class TestStreamManager(BlobExchangeTestBase): ) self.assertEqual(stored_status, "stopped") - await stream.save_file(node=self.stream_manager.node) + stream.downloader.node = self.stream_manager.node + await stream.save_file() await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.assertTrue(stream.finished) @@ -313,7 +322,7 @@ class TestStreamManager(BlobExchangeTestBase): ) self.assertEqual(stored_status, "finished") - await self.stream_manager.delete_stream(stream, True) + await self.stream_manager.delete(stream, True) self.assertDictEqual(self.stream_manager.streams, {}) self.assertFalse(os.path.isfile(os.path.join(self.client_dir, "test_file"))) stored_status = await self.client_storage.run_and_return_one_or_none( @@ -325,7 +334,7 @@ class TestStreamManager(BlobExchangeTestBase): async def _test_download_error_on_start(self, expected_error, timeout=None): error = None try: - await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager, timeout) + await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager, timeout) except Exception as err: if isinstance(err, asyncio.CancelledError): # TODO: remove when updated to 3.8 raise @@ -401,7 +410,7 @@ class TestStreamManager(BlobExchangeTestBase): last_blob_hash = json.loads(sdf.read())['blobs'][-2]['blob_hash'] self.server_blob_manager.delete_blob(last_blob_hash) self.client_config.blob_download_timeout = 0.1 - stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) await stream.started_writing.wait() self.assertEqual('running', stream.status) self.assertIsNotNone(stream.full_path) @@ -433,7 +442,7 @@ class TestStreamManager(BlobExchangeTestBase): self.stream_manager.analytics_manager._post = check_post self.assertDictEqual(self.stream_manager.streams, {}) - stream = await self.stream_manager.download_stream_from_uri(self.uri, self.exchange_rate_manager) + stream = await self.file_manager.download_from_uri(self.uri, self.exchange_rate_manager) await stream.finished_writing.wait() await asyncio.sleep(0, loop=self.loop) self.stream_manager.stop() diff --git a/tests/unit/test_conf.py 
b/tests/unit/test_conf.py index cae6e6a90..6abeef642 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -90,10 +90,15 @@ class ConfigurationTests(unittest.TestCase): def test_environment(self): c = TestConfig() + self.assertEqual(c.test_str, 'the default') c.set_environment({'LBRY_TEST_STR': 'from environ'}) self.assertEqual(c.test_str, 'from environ') + self.assertEqual(c.test_int, 9) + c.set_environment({'LBRY_TEST_INT': '1'}) + self.assertEqual(c.test_int, 1) + def test_persisted(self): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/unit/wallet/test_database.py b/tests/unit/wallet/test_database.py index 7ceed23d7..b796a9c1a 100644 --- a/tests/unit/wallet/test_database.py +++ b/tests/unit/wallet/test_database.py @@ -211,6 +211,7 @@ class TestQueries(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + await self.ledger.headers.open() self.wallet = Wallet() await self.ledger.db.open() diff --git a/tests/unit/wallet/test_headers.py b/tests/unit/wallet/test_headers.py index e014f6e46..da433724d 100644 --- a/tests/unit/wallet/test_headers.py +++ b/tests/unit/wallet/test_headers.py @@ -21,8 +21,8 @@ class TestHeaders(AsyncioTestCase): async def test_deserialize(self): self.maxDiff = None h = Headers(':memory:') - h.io.write(HEADERS) await h.open() + await h.connect(0, HEADERS) self.assertEqual(await h.get(0), { 'bits': 520159231, 'block_height': 0, @@ -52,8 +52,11 @@ class TestHeaders(AsyncioTestCase): self.assertEqual(headers.height, 19) async def test_connect_from_middle(self): - h = Headers(':memory:') - h.io.write(HEADERS[:block_bytes(10)]) + headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS[:block_bytes(10)]) + h = Headers(headers_temporary_file) await h.open() self.assertEqual(h.height, 9) await h.connect(len(h), HEADERS[block_bytes(10):block_bytes(20)]) @@ -115,6 +118,7 @@ class TestHeaders(AsyncioTestCase): async def test_bounds(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS) self.assertEqual(19, headers.height) with self.assertRaises(IndexError): @@ -126,6 +130,7 @@ class TestHeaders(AsyncioTestCase): async def test_repair(self): headers = Headers(':memory:') + await headers.open() await headers.connect(0, HEADERS[:block_bytes(11)]) self.assertEqual(10, headers.height) await headers.repair() @@ -147,24 +152,39 @@ class TestHeaders(AsyncioTestCase): await headers.repair(start_height=10) self.assertEqual(19, headers.height) - def test_do_not_estimate_unconfirmed(self): + async def test_do_not_estimate_unconfirmed(self): headers = Headers(':memory:') + await headers.open() self.assertIsNone(headers.estimated_timestamp(-1)) self.assertIsNone(headers.estimated_timestamp(0)) self.assertIsNotNone(headers.estimated_timestamp(1)) - async def test_misalignment_triggers_repair_on_open(self): + async def test_dont_estimate_whats_there(self): headers = Headers(':memory:') - headers.io.seek(0) - headers.io.write(HEADERS) + await headers.open() + estimated = headers.estimated_timestamp(10) + await headers.connect(0, HEADERS) + real_time = (await headers.get(10))['timestamp'] + after_downloading_header_estimated = headers.estimated_timestamp(10) + self.assertNotEqual(estimated, after_downloading_header_estimated) + self.assertEqual(after_downloading_header_estimated, real_time) + + async def test_misalignment_triggers_repair_on_open(self): + 
headers_temporary_file = tempfile.mktemp() + self.addCleanup(os.remove, headers_temporary_file) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.write(HEADERS) + headers = Headers(headers_temporary_file) with self.assertLogs(level='WARN') as cm: await headers.open() + await headers.close() self.assertEqual(cm.output, []) - headers.io.seek(0) - headers.io.truncate() - headers.io.write(HEADERS[:block_bytes(10)]) - headers.io.write(b'ops') - headers.io.write(HEADERS[block_bytes(10):]) + with open(headers_temporary_file, 'w+b') as headers_file: + headers_file.seek(0) + headers_file.truncate() + headers_file.write(HEADERS[:block_bytes(10)]) + headers_file.write(b'ops') + headers_file.write(HEADERS[block_bytes(10):]) await headers.open() self.assertEqual( cm.output, [ @@ -192,6 +212,7 @@ class TestHeaders(AsyncioTestCase): reader_task = asyncio.create_task(reader()) await writer() await reader_task + await headers.close() HEADERS = unhexlify( diff --git a/tests/unit/wallet/test_ledger.py b/tests/unit/wallet/test_ledger.py index 0244de987..bfe5cc71b 100644 --- a/tests/unit/wallet/test_ledger.py +++ b/tests/unit/wallet/test_ledger.py @@ -48,6 +48,8 @@ class LedgerTestCase(AsyncioTestCase): 'db': Database(':memory:'), 'headers': Headers(':memory:') }) + self.ledger.headers.checkpoints = {} + await self.ledger.headers.open() self.account = Account.generate(self.ledger, Wallet(), "lbryum") await self.ledger.db.open() diff --git a/tests/unit/wallet/test_schema_signing.py b/tests/unit/wallet/test_schema_signing.py index dbe31943e..08b61ce9d 100644 --- a/tests/unit/wallet/test_schema_signing.py +++ b/tests/unit/wallet/test_schema_signing.py @@ -21,9 +21,9 @@ def get_tx(): return Transaction().add_inputs([get_input()]) -def get_channel(claim_name='@foo'): +async def get_channel(claim_name='@foo'): channel_txo = Output.pay_claim_name_pubkey_hash(CENT, claim_name, Claim(), b'abc') - channel_txo.generate_channel_private_key() + await channel_txo.generate_channel_private_key() get_tx().add_outputs([channel_txo]) return channel_txo @@ -36,32 +36,32 @@ def get_stream(claim_name='foo'): class TestSigningAndValidatingClaim(AsyncioTestCase): - def test_successful_create_sign_and_validate(self): - channel = get_channel() + async def test_successful_create_sign_and_validate(self): + channel = await get_channel() stream = get_stream() stream.sign(channel) self.assertTrue(stream.is_signed_by(channel)) - def test_fail_to_validate_on_wrong_channel(self): + async def test_fail_to_validate_on_wrong_channel(self): stream = get_stream() - stream.sign(get_channel()) - self.assertFalse(stream.is_signed_by(get_channel())) + stream.sign(await get_channel()) + self.assertFalse(stream.is_signed_by(await get_channel())) - def test_fail_to_validate_altered_claim(self): - channel = get_channel() + async def test_fail_to_validate_altered_claim(self): + channel = await get_channel() stream = get_stream() stream.sign(channel) self.assertTrue(stream.is_signed_by(channel)) stream.claim.stream.title = 'hello' self.assertFalse(stream.is_signed_by(channel)) - def test_valid_private_key_for_cert(self): - channel = get_channel() + async def test_valid_private_key_for_cert(self): + channel = await get_channel() self.assertTrue(channel.is_channel_private_key(channel.private_key)) - def test_fail_to_load_wrong_private_key_for_cert(self): - channel = get_channel() - self.assertFalse(channel.is_channel_private_key(get_channel().private_key)) + async def test_fail_to_load_wrong_private_key_for_cert(self): + channel = 
await get_channel() + self.assertFalse(channel.is_channel_private_key((await get_channel()).private_key)) class TestValidatingOldSignatures(AsyncioTestCase): diff --git a/tox.ini b/tox.ini index 3b446a241..73d57410b 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ commands = --global-option=fetch \ --global-option=--version --global-option=3.30.1 --global-option=--all \ --global-option=build --global-option=--enable --global-option=fts5 + pip install lbry-libtorrent orchstr8 download blockchain: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.blockchain {posargs} datanetwork: coverage run -p --source={envsitepackagesdir}/lbry -m unittest discover -vv integration.datanetwork {posargs}
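The transaction.py hunk at the top of this section turns generate_channel_private_key into a coroutine that runs the ECDSA key generation in an executor, which is why get_channel and the signing tests above now await it. A minimal sketch of the new call pattern; the helper name, claim name, and amount are illustrative:

from lbry.schema import Claim
from lbry.wallet.transaction import Output

async def make_channel_output(pubkey_hash: bytes):
    # mirrors the get_channel() helper in tests/unit/wallet/test_schema_signing.py
    channel_txo = Output.pay_claim_name_pubkey_hash(1, '@example', Claim(), pubkey_hash)
    # generate_channel_private_key() is now a coroutine: the SECP256k1 key is
    # generated in the default executor so the event loop is not blocked, and the
    # claim's channel public key is populated before it returns.
    await channel_txo.generate_channel_private_key()
    return channel_txo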