diff --git a/lbrynet/conf.py b/lbrynet/conf.py index 35b20fcf6..bbd3f272e 100644 --- a/lbrynet/conf.py +++ b/lbrynet/conf.py @@ -4,65 +4,15 @@ import sys import typing import json import logging -import envparse import base58 import yaml +from contextlib import contextmanager from appdirs import user_data_dir, user_config_dir from lbrynet import utils from lbrynet.p2p.Error import InvalidCurrencyError log = logging.getLogger(__name__) -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - -ENV_NAMESPACE = 'LBRY_' - -LBRYCRD_WALLET = 'lbrycrd' -LBRYUM_WALLET = 'lbryum' -PTC_WALLET = 'ptc' -TORBA_WALLET = 'torba' - -PROTOCOL_PREFIX = 'lbry' -APP_NAME = 'LBRY' - -LINUX = 1 -DARWIN = 2 -WINDOWS = 3 -ANDROID = 4 -KB = 2 ** 10 -MB = 2 ** 20 - -DEFAULT_CONCURRENT_ANNOUNCERS = 10 - -DEFAULT_DHT_NODES = [ - ('lbrynet1.lbry.io', 4444), # US EAST - ('lbrynet2.lbry.io', 4444), # US WEST - ('lbrynet3.lbry.io', 4444), # EU - ('lbrynet4.lbry.io', 4444) # ASIA -] - -settings_decoders = { - '.json': json.loads, - '.yml': yaml.load -} - -settings_encoders = { - '.json': json.dumps, - '.yml': yaml.safe_dump -} - -if 'ANDROID_ARGUMENT' in os.environ: - # https://github.com/kivy/kivy/blob/master/kivy/utils.py#L417-L421 - platform = ANDROID -elif 'darwin' in sys.platform: - platform = DARWIN -elif 'win' in sys.platform: - platform = WINDOWS -else: - platform = LINUX - -ICON_PATH = 'icons' if platform is WINDOWS else 'app.icns' - def get_windows_directories() -> typing.Tuple[str, str, str]: from lbrynet.winpaths import get_path, FOLDERID, UserHandle @@ -91,7 +41,6 @@ def get_darwin_directories() -> typing.Tuple[str, str, str]: def get_linux_directories() -> typing.Tuple[str, str, str]: - download_dir = None try: with open(os.path.join(user_config_dir(), 'user-dirs.dirs'), 'r') as xdg: down_dir = re.search(r'XDG_DOWNLOAD_DIR=(.+)', xdg.read()).group(1) @@ -113,152 +62,398 @@ def get_linux_directories() -> typing.Tuple[str, str, str]: return user_data_dir('lbry/lbrynet'), user_data_dir('lbry/lbryum'), download_dir -def server_port(server_and_port): - server, port = server_and_port.split(':') - return server, int(port) +NOT_SET = type(str('NoValue'), (object,), {}) +T = typing.TypeVar('T') -def server_list(servers): - return [server_port(server) for server in servers] +class Setting(typing.Generic[T]): + def __init__(self, default: typing.Optional[T]): + self.default = default -def server_list_reverse(servers): - return [f"{server}:{port}" for server, port in servers] + def __set_name__(self, owner, name): + self.name = name + def __get__(self, obj: typing.Optional['Configuration'], owner) -> T: + if obj is None: + return self + for location in obj.search_order: + if self.name in location: + return location[self.name] + return self.default -class Env(envparse.Env): - """An Env parser that automatically namespaces the variables with LBRY""" + def __set__(self, obj: 'Configuration', val: typing.Union[T, NOT_SET]): + if val == NOT_SET: + for location in obj.modify_order: + if self.name in location: + del location[self.name] + else: + self.validate(val) + for location in obj.modify_order: + location[self.name] = val - def __init__(self, **schema): - self.original_schema = schema - my_schema = { - self._convert_key(key): self._convert_value(value) - for key, value in schema.items() - } - super().__init__(**my_schema) + def validate(self, val): + raise NotImplementedError() - def __call__(self, key, *args, **kwargs): - my_key = self._convert_key(key) - return super().__call__(my_key, *args, **kwargs) + def deserialize(self, value): + return value - @staticmethod - def _convert_key(key): - return ENV_NAMESPACE + key.upper() - - @staticmethod - def _convert_value(value): - """ Allow value to be specified as a tuple or list. - - If you do this, the tuple/list must be of the - form (cast, default) or (cast, default, subcast) - """ - - if isinstance(value, (tuple, list)): - new_value = {'cast': value[0], 'default': value[1]} - if len(value) == 3: - new_value['subcast'] = value[2] - return new_value + def serialize(self, value): return value -TYPE_DEFAULT = 'default' -TYPE_PERSISTED = 'persisted' -TYPE_ENV = 'env' -TYPE_CLI = 'cli' -TYPE_RUNTIME = 'runtime' +class String(Setting[str]): + def validate(self, val): + assert isinstance(val, str), \ + f"Setting '{self.name}' must be a string." -FIXED_SETTINGS = { - 'ANALYTICS_ENDPOINT': 'https://api.segment.io/v1', - 'ANALYTICS_TOKEN': 'Ax5LZzR1o3q3Z3WjATASDwR5rKyHH0qOIRIbLmMXn2H=', - 'API_ADDRESS': 'lbryapi', - 'APP_NAME': APP_NAME, - 'BLOBFILES_DIR': 'blobfiles', - 'CRYPTSD_FILE_EXTENSION': '.cryptsd', - 'CURRENCIES': { - 'BTC': {'type': 'crypto'}, - 'LBC': {'type': 'crypto'}, - 'USD': {'type': 'fiat'}, - }, - 'DB_REVISION_FILE_NAME': 'db_revision', - 'ICON_PATH': ICON_PATH, - 'LOGGLY_TOKEN': 'BQEzZmMzLJHgAGxkBF00LGD0YGuyATVgAmqxAQEuAQZ2BQH4', - 'LOG_FILE_NAME': 'lbrynet.log', - 'LOG_POST_URL': 'https://lbry.io/log-upload', - 'MAX_BLOB_REQUEST_SIZE': 64 * KB, - 'MAX_HANDSHAKE_SIZE': 64 * KB, - 'MAX_REQUEST_SIZE': 64 * KB, - 'MAX_RESPONSE_INFO_SIZE': 64 * KB, - 'MAX_BLOB_INFOS_TO_REQUEST': 20, - 'PROTOCOL_PREFIX': PROTOCOL_PREFIX, - 'SLACK_WEBHOOK': ('nUE0pUZ6Yl9bo29epl5moTSwnl5wo20ip2IlqzywMKZiIQSFZR5' - 'AHx4mY0VmF0WQZ1ESEP9kMHZlp1WzJwWOoKN3ImR1M2yUAaMyqGZ='), - 'WALLET_TYPES': [LBRYUM_WALLET, LBRYCRD_WALLET], - 'HEADERS_FILE_SHA256_CHECKSUM': (366295, 'b0c8197153a33ccbc52fb81a279588b6015b68b7726f73f6a2b81f7e25bfe4b9') -} -ADJUSTABLE_SETTINGS = { - 'data_dir': (str, ''), # these blank defaults will be updated to OS specific defaults - 'wallet_dir': (str, ''), - 'lbryum_wallet_dir': (str, ''), # to be deprecated - 'download_directory': (str, ''), +class Integer(Setting[int]): + def validate(self, val): + assert isinstance(val, int), \ + f"Setting '{self.name}' must be an integer." - # By default, daemon will block all cross origin requests - # but if this is set, this value will be used for the - # Access-Control-Allow-Origin. For example - # set to '*' to allow all requests, or set to 'http://localhost:8080' - # if you're running a test UI on that port - 'allowed_origin': (str, ''), + +class Float(Setting[float]): + def validate(self, val): + assert isinstance(val, float), \ + f"Setting '{self.name}' must be a decimal." + + +class Toggle(Setting[bool]): + def validate(self, val): + assert isinstance(val, bool), \ + f"Setting '{self.name}' must be a true/false value." + + +class Path(String): + def __init__(self): + super().__init__('') + + def __get__(self, obj, owner): + value = super().__get__(obj, owner) + if isinstance(value, str): + return os.path.expanduser(os.path.expandvars(value)) + return value + + +class MaxKeyFee(Setting[dict]): + + def validate(self, value): + assert isinstance(value, dict), \ + f"Setting '{self.name}' must be of the format \"{'currency': 'USD', 'amount': 50.0}\"." + assert set(value) == {'currency', 'amount'}, \ + f"Setting '{self.name}' must contain a 'currency' and an 'amount' field." + currency = str(value["currency"]).upper() + if currency not in CURRENCIES: + raise InvalidCurrencyError(currency) + + serialize = staticmethod(json.dumps) + deserialize = staticmethod(json.loads) + + +class Servers(Setting[list]): + + def validate(self, val): + assert isinstance(val, (tuple, list)), \ + f"Setting '{self.name}' must be a tuple or list of servers." + for idx, server in enumerate(val): + assert isinstance(server, (tuple, list)) and len(server) == 2, \ + f"Server defined '{server}' at index {idx} in setting " \ + f"'{self.name}' must be a tuple or list of two items." + assert isinstance(server[0], str), \ + f"Server defined '{server}' at index {idx} in setting " \ + f"'{self.name}' must be have hostname as string in first position." + assert isinstance(server[1], int), \ + f"Server defined '{server}' at index {idx} in setting " \ + f"'{self.name}' must be have port as int in second position." + + def deserialize(self, value): + servers = [] + if isinstance(value, list): + for server in value: + if isinstance(server, str) and server.count(':') == 1: + host, port = server.split(':') + try: + servers.append((host, int(port))) + except ValueError: + pass + return servers + + def serialize(self, value): + if value: + return [f"{host}:{port}" for host, port in value] + return value + + +class Strings(Setting[list]): + + def validate(self, val): + assert isinstance(val, (tuple, list)), \ + f"Setting '{self.name}' must be a tuple or list of strings." + for idx, string in enumerate(val): + assert isinstance(string, str), \ + f"Value of '{string}' at index {idx} in setting " \ + f"'{self.name}' must be a string." + + +class EnvironmentAccess: + PREFIX = 'LBRY_' + + def __init__(self, environ: dict): + self.environ = environ + + def __contains__(self, item: str): + return f'{self.PREFIX}{item.upper()}' in self.environ + + def __getitem__(self, item: str): + return self.environ[f'{self.PREFIX}{item.upper()}'] + + +class ArgumentAccess: + + def __init__(self, args: dict): + self.args = args + + def __contains__(self, item: str): + return getattr(self.args, item, None) is not None + + def __getitem__(self, item: str): + return getattr(self.args, item) + + +class ConfigFileAccess: + + def __init__(self, config: 'Configuration', path: str): + self.configuration = config + self.path = path + self.data = {} + if self.exists: + self.load() + + @property + def exists(self): + return self.path and os.path.exists(self.path) + + def load(self): + cls = type(self.configuration) + with open(self.path, 'r') as config_file: + raw = config_file.read() + serialized = yaml.load(raw) or {} + for key, value in serialized.items(): + attr = getattr(cls, key, None) + if attr is not None: + self.data[key] = attr.deserialize(value) + + def save(self): + cls = type(self.configuration) + serialized = {} + for key, value in self.data.items(): + attr = getattr(cls, key) + serialized[key] = attr.serialize(value) + with open(self.path, 'w') as config_file: + config_file.write(yaml.safe_dump(serialized, default_flow_style=False)) + + def __contains__(self, item: str): + return item in self.data + + def __getitem__(self, item: str): + return self.data[item] + + def __setitem__(self, key, value): + self.data[key] = value + + def __delitem__(self, key): + del self.data[key] + + +class Configuration: + + config = Path() + + data_dir = Path() + wallet_dir = Path() + lbryum_wallet_dir = Path() + download_dir = Path() # Changing this value is not-advised as it could potentially # expose the lbrynet daemon to the outside world which would # give an attacker access to your wallet and you could lose # all of your credits. - 'api_host': (str, 'localhost'), - 'api_port': (int, 5279), + api_host = String('localhost') + api_port = Integer(5279) + + share_usage_data = Toggle(True) # whether to share usage stats and diagnostic info with LBRY + + def __init__(self): + self.runtime = {} # set internally or by various API calls + self.arguments = {} # from command line arguments + self.environment = {} # from environment variables + self.persisted = {} # from config file + self.set_default_paths() + self._updating_config = False + + @contextmanager + def update_config(self): + if not isinstance(self.persisted, ConfigFileAccess): + raise TypeError("Config file cannot be updated.") + self._updating_config = True + yield self + self._updating_config = False + self.persisted.save() + + @property + def modify_order(self): + locations = [self.runtime] + if self._updating_config: + locations.append(self.persisted) + return locations + + @property + def search_order(self): + return [ + self.runtime, + self.arguments, + self.environment, + self.persisted + ] + + def set_default_paths(self): + if 'win' in sys.platform: + get_directories = get_windows_directories + elif 'darwin' in sys.platform: + get_directories = get_darwin_directories + elif 'linux' in sys.platform: + get_directories = get_linux_directories + else: + return + cls = type(self) + cls.data_dir.default, cls.wallet_dir.default, cls.download_dir.default = get_directories() + cls.config.default = os.path.join( + self.data_dir, 'daemon_settings.yml' + ) + + @classmethod + def create_from_arguments(cls, args): + conf = cls() + conf.set_arguments(args) + conf.set_environment() + conf.set_persisted() + return conf + + def set_arguments(self, args): + self.arguments = ArgumentAccess(args) + + def set_environment(self, environ=None): + self.environment = EnvironmentAccess(environ or os.environ) + + def set_persisted(self, config_file_path=None): + if config_file_path is None: + config_file_path = self.config + + ext = os.path.splitext(config_file_path)[1] + assert ext in ('.yml', '.yaml'),\ + f"File extension '{ext}' is not supported, " \ + f"configuration file must be in YAML (.yaml)." + + self.persisted = ConfigFileAccess(self, config_file_path) + + +class CommandLineConfiguration(Configuration): + pass + + +class ServerConfiguration(Configuration): + # claims set to expire within this many blocks will be # automatically renewed after startup (if set to 0, renews # will not be made automatically) - 'auto_renew_claim_height_delta': (int, 0), - 'cache_time': (int, 150), - 'data_rate': (float, .0001), # points/megabyte - 'delete_blobs_on_remove': (bool, True), - 'dht_node_port': (int, 4444), - 'download_timeout': (int, 180), - 'download_mirrors': (list, ['blobs.lbry.io']), - 'is_generous_host': (bool, True), - 'announce_head_blobs_only': (bool, True), - 'concurrent_announcers': (int, DEFAULT_CONCURRENT_ANNOUNCERS), - 'known_dht_nodes': (list, DEFAULT_DHT_NODES, server_list, server_list_reverse), - 'max_connections_per_stream': (int, 5), - 'seek_head_blob_first': (bool, True), + auto_renew_claim_height_delta = Integer(0) + cache_time = Integer(150) + data_rate = Float(.0001) # points/megabyte + delete_blobs_on_remove = Toggle(True) + dht_node_port = Integer(4444) + download_timeout = Integer(180) + download_mirrors = Servers([ + ('blobs.lbry.io', 80) + ]) + is_generous_host = Toggle(True) + announce_head_blobs_only = Toggle(True) + concurrent_announcers = Integer(10) + known_dht_nodes = Servers([ + ('lbrynet1.lbry.io', 4444), # US EAST + ('lbrynet2.lbry.io', 4444), # US WEST + ('lbrynet3.lbry.io', 4444), # EU + ('lbrynet4.lbry.io', 4444) # ASIA + ]) + max_connections_per_stream = Integer(5) + seek_head_blob_first = Toggle(True) # TODO: writing json on the cmd line is a pain, come up with a nicer # parser for this data structure. maybe 'USD:25' - 'max_key_fee': (json.loads, {'currency': 'USD', 'amount': 50.0}), - 'disable_max_key_fee': (bool, False), - 'min_info_rate': (float, .02), # points/1000 infos - 'min_valuable_hash_rate': (float, .05), # points/1000 infos - 'min_valuable_info_rate': (float, .05), # points/1000 infos - 'peer_port': (int, 3333), - 'pointtrader_server': (str, 'http://127.0.0.1:2424'), - 'reflector_port': (int, 5566), + max_key_fee = MaxKeyFee({'currency': 'USD', 'amount': 50.0}) + disable_max_key_fee = Toggle(False) + min_info_rate = Float(.02) # points/1000 infos + min_valuable_hash_rate = Float(.05) # points/1000 infos + min_valuable_info_rate = Float(.05) # points/1000 infos + peer_port = Integer(3333) + pointtrader_server = String('http://127.0.0.1:2424') + reflector_port = Integer(5566) # if reflect_uploads is True, send files to reflector after publishing (as well as a periodic check in the # event the initial upload failed or was disconnected part way through, provided the auto_re_reflect_interval > 0) - 'reflect_uploads': (bool, True), - 'auto_re_reflect_interval': (int, 86400), # set to 0 to disable - 'reflector_servers': (list, [('reflector.lbry.io', 5566)], server_list, server_list_reverse), - 'run_reflector_server': (bool, False), # adds `reflector` to components_to_skip unless True - 'sd_download_timeout': (int, 3), - 'share_usage_data': (bool, True), # whether to share usage stats and diagnostic info with LBRY - 'peer_search_timeout': (int, 60), - 'use_upnp': (bool, True), - 'use_keyring': (bool, False), - 'wallet': (str, LBRYUM_WALLET), - 'blockchain_name': (str, 'lbrycrd_main'), - 'lbryum_servers': (list, [('lbryumx1.lbry.io', 50001), ('lbryumx2.lbry.io', - 50001)], server_list, server_list_reverse), - 's3_headers_depth': (int, 96 * 10), # download headers from s3 when the local height is more than 10 chunks behind - 'components_to_skip': (list, []) # components which will be skipped during start-up of daemon + reflect_uploads = Toggle(True) + auto_re_reflect_interval = Integer(86400) # set to 0 to disable + reflector_servers = Servers([ + ('reflector.lbry.io', 5566) + ]) + run_reflector_server = Toggle(False) # adds `reflector` to components_to_skip unless True + sd_download_timeout = Integer(3) + peer_search_timeout = Integer(60) + use_upnp = Toggle(True) + use_keyring = Toggle(False) + blockchain_name = String('lbrycrd_main') + lbryum_servers = Servers([ + ('lbryumx1.lbry.io', 50001), + ('lbryumx2.lbry.io', 50001) + ]) + s3_headers_depth = Integer(96 * 10) # download headers from s3 when the local height is more than 10 chunks behind + components_to_skip = Strings([]) # components which will be skipped during start-up of daemon + + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +KB = 2 ** 10 +MB = 2 ** 20 + + +ANALYTICS_ENDPOINT = 'https://api.segment.io/v1' +ANALYTICS_TOKEN = 'Ax5LZzR1o3q3Z3WjATASDwR5rKyHH0qOIRIbLmMXn2H=' +API_ADDRESS = 'lbryapi' +APP_NAME = 'LBRY' +BLOBFILES_DIR = 'blobfiles' +CRYPTSD_FILE_EXTENSION = '.cryptsd' +CURRENCIES = { + 'BTC': {'type': 'crypto'}, + 'LBC': {'type': 'crypto'}, + 'USD': {'type': 'fiat'}, } +DB_REVISION_FILE_NAME = 'db_revision' +ICON_PATH = 'icons' if 'win' in sys.platform else 'app.icns' +LOGGLY_TOKEN = 'BQEzZmMzLJHgAGxkBF00LGD0YGuyATVgAmqxAQEuAQZ2BQH4' +LOG_FILE_NAME = 'lbrynet.log' +LOG_POST_URL = 'https://lbry.io/log-upload' +MAX_BLOB_REQUEST_SIZE = 64 * KB +MAX_HANDSHAKE_SIZE = 64 * KB +MAX_REQUEST_SIZE = 64 * KB +MAX_RESPONSE_INFO_SIZE = 64 * KB +MAX_BLOB_INFOS_TO_REQUEST = 20 +PROTOCOL_PREFIX = 'lbry' +SLACK_WEBHOOK = ( + 'nUE0pUZ6Yl9bo29epl5moTSwnl5wo20ip2IlqzywMKZiIQSFZR5' + 'AHx4mY0VmF0WQZ1ESEP9kMHZlp1WzJwWOoKN3ImR1M2yUAaMyqGZ=' +) +HEADERS_FILE_SHA256_CHECKSUM = ( + 366295, 'b0c8197153a33ccbc52fb81a279588b6015b68b7726f73f6a2b81f7e25bfe4b9' +) optional_str = typing.Optional[str] @@ -277,21 +472,6 @@ class Config: # copy the default adjustable settings self._adjustable_defaults = {k: v for k, v in adjustable_defaults.items()} - # set the os specific default directories - if platform is WINDOWS: - self.default_data_dir, self.default_wallet_dir, self.default_download_dir = get_windows_directories() - elif platform is DARWIN: - self.default_data_dir, self.default_wallet_dir, self.default_download_dir = get_darwin_directories() - elif platform is LINUX: - self.default_data_dir, self.default_wallet_dir, self.default_download_dir = get_linux_directories() - else: - assert None not in [data_dir, wallet_dir, download_dir] - if data_dir: - self.default_data_dir = data_dir - if wallet_dir: - self.default_wallet_dir = wallet_dir - if download_dir: - self.default_download_dir = download_dir self._data = { TYPE_DEFAULT: {}, # defaults @@ -459,7 +639,7 @@ class Config: return self._data[possible_data_type][name] raise KeyError(f'{name} is not a valid setting') - def set(self, name, value, data_types=(TYPE_RUNTIME,)): + def set(self, name, value, data_types): """Set a config value Args: @@ -480,7 +660,7 @@ class Config: self._assert_valid_data_type(data_type) self._data[data_type][name] = value - def update(self, updated_settings, data_types=(TYPE_RUNTIME,)): + def update(self, updated_settings): for k, v in updated_settings.items(): try: self.set(k, v, data_types=data_types) diff --git a/lbrynet/extras/cli.py b/lbrynet/extras/cli.py index a85e8f8b6..91524bb5d 100644 --- a/lbrynet/extras/cli.py +++ b/lbrynet/extras/cli.py @@ -5,12 +5,6 @@ import asyncio import argparse import typing -# Set SSL_CERT_FILE env variable for Twisted SSL verification on Windows -# This needs to happen before anything else -if 'win' in sys.platform: - import certifi - os.environ['SSL_CERT_FILE'] = certifi.where() - from twisted.internet import asyncioreactor if 'twisted.internet.reactor' not in sys.modules: asyncioreactor.install() diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index 94f18e2c6..ce76f13c8 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -1,121 +1,156 @@ import os import json import sys +import types import tempfile import shutil -from unittest import skipIf -from twisted.trial import unittest -from twisted.internet import defer +import unittest +import argparse from lbrynet import conf from lbrynet.p2p.Error import InvalidCurrencyError -class SettingsTest(unittest.TestCase): - def setUp(self): - os.environ['LBRY_TEST'] = 'test_string' +class TestConfig(conf.Configuration): + test = conf.String('the default') + test_int = conf.Integer(9) + test_toggle = conf.Toggle(False) + servers = conf.Servers([('localhost', 80)]) - def tearDown(self): - del os.environ['LBRY_TEST'] - def get_mock_config_instance(self): - settings = {'test': (str, '')} - env = conf.Env(**settings) - self.tmp_dir = tempfile.mkdtemp() - self.addCleanup(lambda : defer.succeed(shutil.rmtree(self.tmp_dir))) - return conf.Config({}, settings, environment=env, data_dir=self.tmp_dir, wallet_dir=self.tmp_dir, download_dir=self.tmp_dir) +class ConfigurationTests(unittest.TestCase): - def test_envvar_is_read(self): - settings = self.get_mock_config_instance() - self.assertEqual('test_string', settings['test']) + @unittest.skipIf('linux' not in sys.platform, 'skipping linux only test') + def test_linux_defaults(self): + c = TestConfig() + self.assertEqual(c.data_dir, os.path.expanduser('~/.local/share/lbry/lbrynet')) + self.assertEqual(c.wallet_dir, os.path.expanduser('~/.local/share/lbry/lbryum')) + self.assertEqual(c.download_dir, os.path.expanduser('~/Downloads')) + self.assertEqual(c.config, os.path.expanduser('~/.local/share/lbry/lbrynet/daemon_settings.yml')) - def test_setting_can_be_overridden(self): - settings = self.get_mock_config_instance() - settings['test'] = 'my_override' - self.assertEqual('my_override', settings['test']) + def test_search_order(self): + c = TestConfig() + c.runtime = {'test': 'runtime'} + c.arguments = {'test': 'arguments'} + c.environment = {'test': 'environment'} + c.persisted = {'test': 'persisted'} + self.assertEqual(c.test, 'runtime') + c.runtime = {} + self.assertEqual(c.test, 'arguments') + c.arguments = {} + self.assertEqual(c.test, 'environment') + c.environment = {} + self.assertEqual(c.test, 'persisted') + c.persisted = {} + self.assertEqual(c.test, 'the default') - def test_setting_can_be_updated(self): - settings = self.get_mock_config_instance() - settings.update({'test': 'my_update'}) - self.assertEqual('my_update', settings['test']) + def test_arguments(self): + parser = argparse.ArgumentParser() + parser.add_argument("--test") + args = parser.parse_args(['--test', 'blah']) + c = TestConfig.create_from_arguments(args) + self.assertEqual(c.test, 'blah') + c.arguments = {} + self.assertEqual(c.test, 'the default') - def test_setting_is_in_dict(self): - settings = self.get_mock_config_instance() - setting_dict = settings.get_current_settings_dict() - self.assertEqual({'test': 'test_string'}, setting_dict) + def test_environment(self): + c = TestConfig() + self.assertEqual(c.test, 'the default') + c.set_environment({'LBRY_TEST': 'from environ'}) + self.assertEqual(c.test, 'from environ') - def test_invalid_setting_raises_exception(self): - settings = self.get_mock_config_instance() - self.assertRaises(KeyError, settings.set, 'invalid_name', 123) + def test_persisted(self): + with tempfile.TemporaryDirectory() as temp_dir: - def test_invalid_data_type_raises_exception(self): - settings = self.get_mock_config_instance() - self.assertIsNone(settings.set('test', 123)) - self.assertRaises(KeyError, settings.set, 'test', 123, ('fake_data_type',)) + c = TestConfig.create_from_arguments( + types.SimpleNamespace(config=os.path.join(temp_dir, 'settings.yml')) + ) - def test_setting_precedence(self): - settings = self.get_mock_config_instance() - settings.set('test', 'cli_test_string', data_types=(conf.TYPE_CLI,)) - self.assertEqual('cli_test_string', settings['test']) - settings.set('test', 'this_should_not_take_precedence', data_types=(conf.TYPE_ENV,)) - self.assertEqual('cli_test_string', settings['test']) - settings.set('test', 'runtime_takes_precedence', data_types=(conf.TYPE_RUNTIME,)) - self.assertEqual('runtime_takes_precedence', settings['test']) + # settings.yml doesn't exist on file system + self.assertFalse(c.persisted.exists) + self.assertEqual(c.test, 'the default') - def test_max_key_fee_set(self): - fixed_default = {'CURRENCIES':{'BTC':{'type':'crypto'}}} - adjustable_settings = {'max_key_fee': (json.loads, {'currency':'USD', 'amount':1})} - env = conf.Env(**adjustable_settings) - settings = conf.Config(fixed_default, adjustable_settings, environment=env) + self.assertEqual(c.modify_order, [c.runtime]) + with c.update_config(): + self.assertEqual(c.modify_order, [c.runtime, c.persisted]) + c.test = 'new value' + self.assertEqual(c.modify_order, [c.runtime]) - with self.assertRaises(InvalidCurrencyError): - settings.set('max_key_fee', {'currency':'USD', 'amount':1}) + # share_usage_data has been saved to settings file + self.assertTrue(c.persisted.exists) + with open(c.config, 'r') as fd: + self.assertEqual(fd.read(), 'test: new value\n') - valid_setting = {'currency':'BTC', 'amount':1} - settings.set('max_key_fee', valid_setting) - out = settings.get('max_key_fee') - self.assertEqual(out, valid_setting) + # load the settings file and check share_usage_data is false + c = TestConfig.create_from_arguments( + types.SimpleNamespace(config=os.path.join(temp_dir, 'settings.yml')) + ) + self.assertTrue(c.persisted.exists) + self.assertEqual(c.test, 'new value') - def test_data_dir(self): - # check if these directories are returned as string and not unicode - # otherwise there will be problems when calling os.path.join on - # unicode directory names with string file names - settings = conf.Config({}, {}) - self.assertEqual(str, type(settings.download_dir)) - self.assertEqual(str, type(settings.data_dir)) - self.assertEqual(str, type(settings.wallet_dir)) + # setting in runtime overrides config + self.assertNotIn('test', c.runtime) + c.test = 'from runtime' + self.assertIn('test', c.runtime) + self.assertEqual(c.test, 'from runtime') - @skipIf('win' in sys.platform, 'fix me!') - def test_load_save_config_file(self): - # setup settings - adjustable_settings = {'lbryum_servers': (list, [])} - env = conf.Env(**adjustable_settings) - settings = conf.Config({}, adjustable_settings, environment=env) - conf.settings = settings - # setup tempfile - conf_entry = b"lbryum_servers: ['localhost:50001', 'localhost:50002']\n" - with tempfile.NamedTemporaryFile(suffix='.yml') as conf_file: - conf_file.write(conf_entry) - conf_file.seek(0) - conf.conf_file = conf_file.name - # load and save settings from conf file - settings.load_conf_file_settings() - settings.save_conf_file_settings() - # test if overwritten entry equals original entry - # use decoded versions, because format might change without - # changing the interpretation - decoder = conf.settings_decoders['.yml'] - conf_decoded = decoder(conf_entry) - conf_entry_new = conf_file.read() - conf_decoded_new = decoder(conf_entry_new) - self.assertEqual(conf_decoded, conf_decoded_new) + # NOT_SET only clears it in runtime location + c.test = conf.NOT_SET + self.assertNotIn('test', c.runtime) + self.assertEqual(c.test, 'new value') - def test_load_file(self): - settings = self.get_mock_config_instance() + # clear it in persisted as well + self.assertIn('test', c.persisted) + with c.update_config(): + c.test = conf.NOT_SET + self.assertNotIn('test', c.persisted) + self.assertEqual(c.test, 'the default') + with open(c.config, 'r') as fd: + self.assertEqual(fd.read(), '{}\n') - # invalid extensions - for filename in ('monkey.yymmll', 'monkey'): - settings.file_name = filename - with open(os.path.join(self.tmp_dir, filename), "w"): - pass - with self.assertRaises(ValueError): - settings.load_conf_file_settings() + def test_validation(self): + c = TestConfig() + with self.assertRaisesRegex(AssertionError, 'must be a string'): + c.test = 9 + with self.assertRaisesRegex(AssertionError, 'must be an integer'): + c.test_int = 'hi' + with self.assertRaisesRegex(AssertionError, 'must be a true/false'): + c.test_toggle = 'hi' + + def test_file_extension_validation(self): + with self.assertRaisesRegex(AssertionError, "'.json' is not supported"): + TestConfig.create_from_arguments( + types.SimpleNamespace(config=os.path.join('settings.json')) + ) + + def test_serialize_deserialize(self): + with tempfile.TemporaryDirectory() as temp_dir: + c = TestConfig.create_from_arguments( + types.SimpleNamespace(config=os.path.join(temp_dir, 'settings.yml')) + ) + self.assertEqual(c.servers, [('localhost', 80)]) + with c.update_config(): + c.servers = [('localhost', 8080)] + with open(c.config, 'r+') as fd: + self.assertEqual(fd.read(), 'servers:\n- localhost:8080\n') + fd.write('servers:\n - localhost:5566\n') + c = TestConfig.create_from_arguments( + types.SimpleNamespace(config=os.path.join(temp_dir, 'settings.yml')) + ) + self.assertEqual(c.servers, [('localhost', 5566)]) + + def test_max_key_fee(self): + with tempfile.TemporaryDirectory() as temp_dir: + config = os.path.join(temp_dir, 'settings.yml') + with open(config, 'w') as fd: + fd.write('max_key_fee: \'{"currency":"USD", "amount":1}\'\n') + c = conf.ServerConfiguration.create_from_arguments( + types.SimpleNamespace(config=config) + ) + self.assertEqual(c.max_key_fee['currency'], 'USD') + self.assertEqual(c.max_key_fee['amount'], 1) + with self.assertRaises(InvalidCurrencyError): + c.max_key_fee = {'currency': 'BCH', 'amount': 1} + with c.update_config(): + c.max_key_fee = {'currency': 'BTC', 'amount': 1} + with open(config, 'r') as fd: + self.assertEqual(fd.read(), 'max_key_fee: \'{"currency": "BTC", "amount": 1}\'\n')