diff --git a/tests/unit/test_wallet.py b/tests/unit/test_wallet.py index 86521fdd5..5ca3bd904 100644 --- a/tests/unit/test_wallet.py +++ b/tests/unit/test_wallet.py @@ -1,3 +1,4 @@ +import tempfile from twisted.trial import unittest from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger @@ -36,11 +37,11 @@ class TestWalletCreation(unittest.TestCase): "h absent", 'encrypted': False, 'private_key': - b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' - b'6yz3jMbycrLrRMpeAJxR8qDg8', + 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' + '6yz3jMbycrLrRMpeAJxR8qDg8', 'public_key': - b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' - b'iW44g14WF52fYC5J483wqQ5ZP', + 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' + 'iW44g14WF52fYC5J483wqQ5ZP', 'receiving_gap': 10, 'receiving_maximum_use_per_address': 2, 'change_gap': 10, @@ -57,3 +58,24 @@ class TestWalletCreation(unittest.TestCase): self.assertIsInstance(account, BTCLedger.account_class) self.maxDiff = None self.assertDictEqual(wallet_dict, wallet.to_dict()) + + def test_read_write(self): + manager = WalletManager() + config = {'wallet_path': '/tmp/wallet'} + ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config) + + with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file: + wallet_file.write(b'{}') + wallet_file.seek(0) + + # create and write wallet to a file + wallet_storage = WalletStorage(wallet_file.name) + wallet = Wallet.from_storage(wallet_storage, manager) + account = wallet.generate_account(ledger) + wallet.save() + + # read wallet from file + wallet_storage = WalletStorage(wallet_file.name) + wallet = Wallet.from_storage(wallet_storage, manager) + + self.assertEqual(account.public_key.address, wallet.default_account.public_key.address) diff --git a/torba/baseaccount.py b/torba/baseaccount.py index f753732df..d2a5b0c82 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -133,8 +133,8 @@ class BaseAccount: 'seed': self.seed, 'encrypted': self.encrypted, 'private_key': self.private_key if self.encrypted else - self.private_key.extended_key_string(), - 'public_key': self.public_key.extended_key_string(), + self.private_key.extended_key_string().decode(), + 'public_key': self.public_key.extended_key_string().decode(), 'receiving_gap': self.receiving.gap, 'change_gap': self.change.gap, 'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address, diff --git a/torba/baseledger.py b/torba/baseledger.py index 072825b49..2e9e53463 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -2,7 +2,7 @@ import os import six import hashlib from binascii import hexlify, unhexlify -from typing import Dict, Type +from typing import Dict, Type, Iterable, Generator from operator import itemgetter from twisted.internet import defer @@ -126,6 +126,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): def get_unspent_outputs(self, account): return self.db.get_utxos(account, self.transaction_class.output_class) + @defer.inlineCallbacks + def get_effective_amount_estimators(self, funding_accounts): + # type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred + estimators = [] + for account in funding_accounts: + utxos = yield self.get_unspent_outputs(account) + for utxo in utxos: + estimators.append(utxo.get_estimator(self)) + defer.returnValue(estimators) + @defer.inlineCallbacks def get_local_status(self, address): address_details = yield self.db.get_address(address) diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 8ae403675..8914b9173 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -1,6 +1,6 @@ import six import logging -from typing import List, Iterable, Generator +from typing import List, Iterable from binascii import hexlify from twisted.internet import defer @@ -271,17 +271,6 @@ class BaseTransaction: ]) self.locktime = stream.read_uint32() - @classmethod - @defer.inlineCallbacks - def get_effective_amount_estimators(cls, funding_accounts): - # type: (Iterable[torba.baseaccount.BaseAccount]) -> Generator[BaseOutputEffectiveAmountEstimator] - estimators = [] - for account in funding_accounts: - utxos = yield account.ledger.get_unspent_outputs(account) - for utxo in utxos: - estimators.append(utxo.get_estimator(account.ledger)) - defer.returnValue(estimators) - @classmethod def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): # type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger @@ -306,7 +295,7 @@ class BaseTransaction: tx = cls().add_outputs(outputs) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) amount = tx.output_sum + ledger.get_transaction_base_fee(tx) - txos = yield cls.get_effective_amount_estimators(funding_accounts) + txos = yield ledger.get_effective_amount_estimators(funding_accounts) selector = CoinSelector( txos, amount, ledger.get_input_output_fee( diff --git a/torba/wallet.py b/torba/wallet.py index a93b77f25..9ac2d9f76 100644 --- a/torba/wallet.py +++ b/torba/wallet.py @@ -17,7 +17,7 @@ class Wallet: def __init__(self, name='Wallet', accounts=None, storage=None): # type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None self.name = name - self.accounts = accounts or [] + self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount] self.storage = storage or WalletStorage() def generate_account(self, ledger):