refactoring and added basic test for reading/writing wallet file

This commit is contained in:
Lex Berezhny 2018-06-17 23:22:15 -04:00
parent 43cd9c4100
commit 833ef98ff5
5 changed files with 42 additions and 21 deletions

View file

@ -1,3 +1,4 @@
import tempfile
from twisted.trial import unittest from twisted.trial import unittest
from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger from torba.coin.bitcoinsegwit import MainNetLedger as BTCLedger
@ -36,11 +37,11 @@ class TestWalletCreation(unittest.TestCase):
"h absent", "h absent",
'encrypted': False, 'encrypted': False,
'private_key': 'private_key':
b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P'
b'6yz3jMbycrLrRMpeAJxR8qDg8', '6yz3jMbycrLrRMpeAJxR8qDg8',
'public_key': 'public_key':
b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f'
b'iW44g14WF52fYC5J483wqQ5ZP', 'iW44g14WF52fYC5J483wqQ5ZP',
'receiving_gap': 10, 'receiving_gap': 10,
'receiving_maximum_use_per_address': 2, 'receiving_maximum_use_per_address': 2,
'change_gap': 10, 'change_gap': 10,
@ -57,3 +58,24 @@ class TestWalletCreation(unittest.TestCase):
self.assertIsInstance(account, BTCLedger.account_class) self.assertIsInstance(account, BTCLedger.account_class)
self.maxDiff = None self.maxDiff = None
self.assertDictEqual(wallet_dict, wallet.to_dict()) self.assertDictEqual(wallet_dict, wallet.to_dict())
def test_read_write(self):
manager = WalletManager()
config = {'wallet_path': '/tmp/wallet'}
ledger = manager.get_or_create_ledger(BTCLedger.get_id(), config)
with tempfile.NamedTemporaryFile(suffix='.json') as wallet_file:
wallet_file.write(b'{}')
wallet_file.seek(0)
# create and write wallet to a file
wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(wallet_storage, manager)
account = wallet.generate_account(ledger)
wallet.save()
# read wallet from file
wallet_storage = WalletStorage(wallet_file.name)
wallet = Wallet.from_storage(wallet_storage, manager)
self.assertEqual(account.public_key.address, wallet.default_account.public_key.address)

View file

@ -133,8 +133,8 @@ class BaseAccount:
'seed': self.seed, 'seed': self.seed,
'encrypted': self.encrypted, 'encrypted': self.encrypted,
'private_key': self.private_key if self.encrypted else 'private_key': self.private_key if self.encrypted else
self.private_key.extended_key_string(), self.private_key.extended_key_string().decode(),
'public_key': self.public_key.extended_key_string(), 'public_key': self.public_key.extended_key_string().decode(),
'receiving_gap': self.receiving.gap, 'receiving_gap': self.receiving.gap,
'change_gap': self.change.gap, 'change_gap': self.change.gap,
'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address, 'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address,

View file

@ -2,7 +2,7 @@ import os
import six import six
import hashlib import hashlib
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
from typing import Dict, Type from typing import Dict, Type, Iterable, Generator
from operator import itemgetter from operator import itemgetter
from twisted.internet import defer from twisted.internet import defer
@ -126,6 +126,16 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)):
def get_unspent_outputs(self, account): def get_unspent_outputs(self, account):
return self.db.get_utxos(account, self.transaction_class.output_class) return self.db.get_utxos(account, self.transaction_class.output_class)
@defer.inlineCallbacks
def get_effective_amount_estimators(self, funding_accounts):
# type: (Iterable[baseaccount.BaseAccount]) -> defer.Deferred
estimators = []
for account in funding_accounts:
utxos = yield self.get_unspent_outputs(account)
for utxo in utxos:
estimators.append(utxo.get_estimator(self))
defer.returnValue(estimators)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_local_status(self, address): def get_local_status(self, address):
address_details = yield self.db.get_address(address) address_details = yield self.db.get_address(address)

View file

@ -1,6 +1,6 @@
import six import six
import logging import logging
from typing import List, Iterable, Generator from typing import List, Iterable
from binascii import hexlify from binascii import hexlify
from twisted.internet import defer from twisted.internet import defer
@ -271,17 +271,6 @@ class BaseTransaction:
]) ])
self.locktime = stream.read_uint32() self.locktime = stream.read_uint32()
@classmethod
@defer.inlineCallbacks
def get_effective_amount_estimators(cls, funding_accounts):
# type: (Iterable[torba.baseaccount.BaseAccount]) -> Generator[BaseOutputEffectiveAmountEstimator]
estimators = []
for account in funding_accounts:
utxos = yield account.ledger.get_unspent_outputs(account)
for utxo in utxos:
estimators.append(utxo.get_estimator(account.ledger))
defer.returnValue(estimators)
@classmethod @classmethod
def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None):
# type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger # type: (Iterable[torba.baseaccount.BaseAccount], torba.baseaccount.BaseAccount) -> torba.baseledger.BaseLedger
@ -306,7 +295,7 @@ class BaseTransaction:
tx = cls().add_outputs(outputs) tx = cls().add_outputs(outputs)
ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account)
amount = tx.output_sum + ledger.get_transaction_base_fee(tx) amount = tx.output_sum + ledger.get_transaction_base_fee(tx)
txos = yield cls.get_effective_amount_estimators(funding_accounts) txos = yield ledger.get_effective_amount_estimators(funding_accounts)
selector = CoinSelector( selector = CoinSelector(
txos, amount, txos, amount,
ledger.get_input_output_fee( ledger.get_input_output_fee(

View file

@ -17,7 +17,7 @@ class Wallet:
def __init__(self, name='Wallet', accounts=None, storage=None): def __init__(self, name='Wallet', accounts=None, storage=None):
# type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None # type: (str, List[torba.baseaccount.BaseAccount], WalletStorage) -> None
self.name = name self.name = name
self.accounts = accounts or [] self.accounts = accounts or [] # type: List[torba.baseaccount.BaseAccount]
self.storage = storage or WalletStorage() self.storage = storage or WalletStorage()
def generate_account(self, ledger): def generate_account(self, ledger):