diff --git a/lbrynet/extras/daemon/Daemon.py b/lbrynet/extras/daemon/Daemon.py index 9d836cf5d..71a51aab3 100644 --- a/lbrynet/extras/daemon/Daemon.py +++ b/lbrynet/extras/daemon/Daemon.py @@ -2177,13 +2177,13 @@ class Daemon(metaclass=JSONRPCServerType): ) account: LBCAccount = await self.ledger.get_account_for_address(data['holding_address']) if not account: - new_account = LBCAccount.from_dict(self.ledger, self.default_wallet, { + account = LBCAccount.from_dict(self.ledger, self.default_wallet, { 'name': f"Holding Account For Channel {data['name']}", 'public_key': data['holding_public_key'], 'address_generator': {'name': 'single-address'} }) if self.ledger.network.is_connected: - asyncio.create_task(self.ledger.subscribe_account(new_account)) + await self.ledger.subscribe_account(account) account.add_channel_private_key(channel_private_key) self.default_wallet.save() return f"Added channel signing key for {data['name']}." diff --git a/lbrynet/testcase.py b/lbrynet/testcase.py index 3ecbe5b12..b4fa27207 100644 --- a/lbrynet/testcase.py +++ b/lbrynet/testcase.py @@ -4,7 +4,7 @@ import tempfile import logging from binascii import unhexlify -from torba.testcase import IntegrationTestCase +from torba.testcase import IntegrationTestCase, WalletNode import lbrynet.wallet @@ -71,18 +71,10 @@ class CommandTestCase(IntegrationTestCase): logging.getLogger('lbrynet.daemon').setLevel(self.VERBOSITY) logging.getLogger('lbrynet.stream').setLevel(self.VERBOSITY) - conf = Config() - conf.data_dir = self.wallet_node.data_path - conf.wallet_dir = self.wallet_node.data_path - conf.download_dir = self.wallet_node.data_path - conf.share_usage_data = False - conf.use_upnp = False - conf.reflect_streams = True - conf.blockchain_name = 'lbrycrd_regtest' - conf.lbryum_servers = [('127.0.0.1', 50001)] - conf.reflector_servers = [('127.0.0.1', 5566)] - conf.known_dht_nodes = [] - conf.blob_lru_cache_size = self.blob_lru_cache_size + self.daemons = [] + self.extra_wallet_nodes = [] + self.extra_wallet_node_port = 5280 + self.daemon = await self.add_daemon(self.wallet_node) await self.account.ensure_address_gap() address = (await self.account.receiving.get_addresses(limit=1, only_usable=True))[0] @@ -90,23 +82,6 @@ class CommandTestCase(IntegrationTestCase): await self.confirm_tx(sendtxid) await self.generate(5) - def wallet_maker(component_manager): - self.wallet_component = WalletComponent(component_manager) - self.wallet_component.wallet_manager = self.manager - self.wallet_component._running = True - return self.wallet_component - - conf.components_to_skip = [ - DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT, - PEER_PROTOCOL_SERVER_COMPONENT - ] - self.daemon = Daemon(conf, ComponentManager( - conf, skip_components=conf.components_to_skip, wallet=wallet_maker, - exchange_rate_manager=ExchangeRateManagerComponent - )) - await self.daemon.initialize() - self.manager.old_db = self.daemon.storage - server_tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, server_tmp_dir) self.server_config = Config() @@ -125,8 +100,53 @@ class CommandTestCase(IntegrationTestCase): async def asyncTearDown(self): await super().asyncTearDown() - self.wallet_component._running = False - await self.daemon.stop(shutdown_runner=False) + for wallet_node in self.extra_wallet_nodes: + await wallet_node.stop(cleanup=True) + for daemon in self.daemons: + daemon.component_manager.get_component('wallet')._running = False + await daemon.stop(shutdown_runner=False) + + async def add_daemon(self, wallet_node=None, seed=None): + if wallet_node is None: + wallet_node = WalletNode( + self.wallet_node.manager_class, + self.wallet_node.ledger_class, + port=self.extra_wallet_node_port + ) + self.extra_wallet_node_port += 1 + await wallet_node.start(self.conductor.spv_node, seed=seed) + self.extra_wallet_nodes.append(wallet_node) + + conf = Config() + conf.data_dir = wallet_node.data_path + conf.wallet_dir = wallet_node.data_path + conf.download_dir = wallet_node.data_path + conf.share_usage_data = False + conf.use_upnp = False + conf.reflect_streams = True + conf.blockchain_name = 'lbrycrd_regtest' + conf.lbryum_servers = [('127.0.0.1', 50001)] + conf.reflector_servers = [('127.0.0.1', 5566)] + conf.known_dht_nodes = [] + conf.blob_lru_cache_size = self.blob_lru_cache_size + conf.components_to_skip = [ + DHT_COMPONENT, UPNP_COMPONENT, HASH_ANNOUNCER_COMPONENT, + PEER_PROTOCOL_SERVER_COMPONENT + ] + + def wallet_maker(component_manager): + wallet_component = WalletComponent(component_manager) + wallet_component.wallet_manager = wallet_node.manager + wallet_component._running = True + return wallet_component + + daemon = Daemon(conf, ComponentManager( + conf, skip_components=conf.components_to_skip, wallet=wallet_maker, + exchange_rate_manager=ExchangeRateManagerComponent + )) + await daemon.initialize() + wallet_node.manager.old_db = daemon.storage + return daemon async def confirm_tx(self, txid): """ Wait for tx to be in mempool, then generate a block, wait for tx to be in a block. """ diff --git a/tests/integration/test_claim_commands.py b/tests/integration/test_claim_commands.py index e2bae4487..286be70e1 100644 --- a/tests/integration/test_claim_commands.py +++ b/tests/integration/test_claim_commands.py @@ -306,31 +306,51 @@ class ChannelCommands(CommandTestCase): txo = (await account2.get_channels())[0] self.assertIsNotNone(txo.private_key) - async def test_channel_export_import_without_password(self): + async def test_channel_export_import_into_new_account(self): tx = await self.channel_create('@foo', '1.0') claim_id = tx['outputs'][0]['claim_id'] channel_private_key = (await self.account.get_channels())[0].private_key + exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id)) - _account2 = await self.out(self.daemon.jsonrpc_account_create("Account 2")) - account2_id, account2 = _account2["id"], self.daemon.get_account_or_error(_account2['id']) + daemon2 = await self.add_daemon() - # before exporting/importing channel - self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 0) + # before importing channel + self.assertEqual(1, len(daemon2.default_wallet.accounts)) - # exporting from default account - serialized_channel_info = await self.out(self.daemon.jsonrpc_channel_export(claim_id)) + # importing channel which will create a new single key account + await daemon2.jsonrpc_channel_import(exported_data) - other_address = await account2.receiving.get_or_create_usable_address() - await self.out(self.channel_update(claim_id, claim_address=other_address)) + # after import + self.assertEqual(2, len(daemon2.default_wallet.accounts)) + new_account = daemon2.default_wallet.accounts[1] + await daemon2.ledger._update_tasks.done.wait() + channels = await new_account.get_channels() + self.assertEqual(1, len(channels)) + self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string()) - # importing into second account - await self.daemon.jsonrpc_channel_import(serialized_channel_info, password=None, account_id=account2_id) + async def test_channel_export_import_into_existing_account(self): + tx = await self.channel_create('@foo', '1.0') + claim_id = tx['outputs'][0]['claim_id'] + channel_private_key = (await self.account.get_channels())[0].private_key + exported_data = await self.out(self.daemon.jsonrpc_channel_export(claim_id)) - # after exporting/importing channel - self.assertEqual(len(await self.daemon.jsonrpc_channel_list(account_id=account2_id)), 1) - txo_channel_account2 = (await account2.get_channels())[0] + daemon2 = await self.add_daemon(seed=self.account.seed) + await daemon2.ledger._update_tasks.done.wait() # will sync channel previously created - self.assertEqual(channel_private_key, txo_channel_account2.private_key) + # before importing channel key, has channel without key + self.assertEqual(1, len(daemon2.default_wallet.accounts)) + channels = await daemon2.default_account.get_channels() + self.assertEqual(1, len(channels)) + self.assertIsNone(channels[0].private_key) + + # importing channel will add it to existing account + await daemon2.jsonrpc_channel_import(exported_data) + + # after import, still just one account but with private key now + self.assertEqual(1, len(daemon2.default_wallet.accounts)) + channels = await daemon2.default_account.get_channels() + self.assertEqual(1, len(channels)) + self.assertEqual(channel_private_key.to_string(), channels[0].private_key.to_string()) class StreamCommands(CommandTestCase):