From fbd1b2ec306c0015d24c1ce8c751dd3e160f94db Mon Sep 17 00:00:00 2001 From: Lex Berezhny Date: Tue, 12 Jun 2018 10:02:04 -0400 Subject: [PATCH] monday progress --- tests/integration/test_transactions.py | 18 ++- tests/unit/test_account.py | 100 ++++++++++--- tests/unit/test_ledger.py | 69 ++++++++- torba/baseaccount.py | 147 +++++++++---------- torba/basedatabase.py | 193 +++++++++++-------------- torba/baseledger.py | 100 ++++++------- torba/basetransaction.py | 32 ++-- torba/coin/bitcoincash.py | 3 +- torba/coin/bitcoinsegwit.py | 11 +- torba/manager.py | 15 +- 10 files changed, 386 insertions(+), 302 deletions(-) diff --git a/tests/integration/test_transactions.py b/tests/integration/test_transactions.py index fd16cc534..9cf9327aa 100644 --- a/tests/integration/test_transactions.py +++ b/tests/integration/test_transactions.py @@ -10,23 +10,29 @@ class BasicTransactionTests(IntegrationTestCase): async def test_sending_and_recieving(self): account1, account2 = self.account, self.wallet.generate_account(self.ledger) + await account1.ensure_address_gap().asFuture(asyncio.get_event_loop()) + self.assertEqual(await self.get_balance(account1), 0) self.assertEqual(await self.get_balance(account2), 0) - address = await account1.get_least_used_receiving_address().asFuture(asyncio.get_event_loop()) + address = await account1.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5) - await self.blockchain.generate(1) await self.on_transaction(sendtxid) + await self.blockchain.generate(1) + await asyncio.sleep(5) self.assertEqual(await self.get_balance(account1), int(5.5*COIN)) self.assertEqual(await self.get_balance(account2), 0) - address = await account2.get_least_used_receiving_address().asFuture(asyncio.get_event_loop()) - sendtxid = await self.blockchain.send_to_address(address.decode(), 5.5) + address = await account2.receiving.get_or_create_usable_address().asFuture(asyncio.get_event_loop()) + tx = await self.ledger.transaction_class.pay( + [self.ledger.transaction_class.output_class.pay_pubkey_hash(2, self.ledger.address_to_hash160(address))], + [account1], account1 + ).asFuture(asyncio.get_event_loop()) await self.broadcast(tx) await self.on_transaction(tx.id.decode()) await self.lbrycrd.generate(1) - self.assertEqual(await self.get_balance(account1), int(3.0*COIN)) - self.assertEqual(await self.get_balance(account2), int(2.5*COIN)) + self.assertEqual(await self.get_balance(account1), int(3.5*COIN)) + self.assertEqual(await self.get_balance(account2), int(2.0*COIN)) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 323420ba4..393def1d4 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -5,6 +5,66 @@ from twisted.internet import defer from torba.coin.bitcoinsegwit import MainNetLedger +class TestKeyChain(unittest.TestCase): + + def setUp(self): + self.ledger = MainNetLedger(db=':memory:') + return self.ledger.db.start() + + @defer.inlineCallbacks + def test_address_gap_algorithm(self): + account = self.ledger.account_class.generate(self.ledger, u"torba") + + # save records out of order to make sure we're really testing ORDER BY + # and not coincidentally getting records in the correct order + yield account.receiving.generate_keys(4, 7) + yield account.receiving.generate_keys(0, 3) + yield account.receiving.generate_keys(8, 11) + keys = yield account.receiving.get_addresses(None, True) + self.assertEqual( + [key['position'] for key in keys], + [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + ) + + # we have 12, but default gap is 20 + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 8) + keys = yield account.receiving.get_addresses(None, True) + self.assertEqual( + [key['position'] for key in keys], + [19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + ) + + # case #1: no new addresses needed + empty = yield account.receiving.ensure_address_gap() + self.assertEqual(len(empty), 0) + + # case #2: only one new addressed needed + keys = yield account.receiving.get_addresses(None, True) + yield self.ledger.db.set_address_history(keys[19]['address'], 'a:1:') + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 1) + + # case #3: 20 addresses needed + keys = yield account.receiving.get_addresses(None, True) + yield self.ledger.db.set_address_history(keys[0]['address'], 'a:1:') + new_keys = yield account.receiving.ensure_address_gap() + self.assertEqual(len(new_keys), 20) + + @defer.inlineCallbacks + def test_create_usable_address(self): + account = self.ledger.account_class.generate(self.ledger, u"torba") + + keys = yield account.receiving.get_addresses(None, True) + self.assertEqual(len(keys), 0) + + address = yield account.receiving.get_or_create_usable_address() + self.assertIsNotNone(address) + + keys = yield account.receiving.get_addresses(None, True) + self.assertEqual(len(keys), 20) + + class TestAccount(unittest.TestCase): def setUp(self): @@ -19,24 +79,16 @@ class TestAccount(unittest.TestCase): self.assertEqual(account.public_key.ledger, self.ledger) self.assertEqual(account.private_key.public_key, account.public_key) - keys = yield account.receiving.get_keys() addresses = yield account.receiving.get_addresses() - self.assertEqual(len(keys), 0) self.assertEqual(len(addresses), 0) - keys = yield account.change.get_keys() addresses = yield account.change.get_addresses() - self.assertEqual(len(keys), 0) self.assertEqual(len(addresses), 0) - yield account.ensure_enough_useable_addresses() + yield account.ensure_address_gap() - keys = yield account.receiving.get_keys() addresses = yield account.receiving.get_addresses() - self.assertEqual(len(keys), 20) self.assertEqual(len(addresses), 20) - keys = yield account.change.get_keys() addresses = yield account.change.get_addresses() - self.assertEqual(len(keys), 6) self.assertEqual(len(addresses), 6) @defer.inlineCallbacks @@ -57,19 +109,23 @@ class TestAccount(unittest.TestCase): b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' b'iW44g14WF52fYC5J483wqQ5ZP' ) - address = yield account.receiving.ensure_enough_useable_addresses() - self.assertEqual(address[0], b'1PGDB1CRy8UxPCrkcakRqroVnHxqzvUZhp') - private_key = yield self.ledger.get_private_key_for_address(b'1PGDB1CRy8UxPCrkcakRqroVnHxqzvUZhp') + address = yield account.receiving.ensure_address_gap() + self.assertEqual(address[0], b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP') + + self.maxDiff = None + private_key = yield self.ledger.get_private_key_for_address(b'1PmX9T3sCiDysNtWszJa44SkKcpGc2NaXP') self.assertEqual( private_key.extended_key_string(), - b'xprv9xNEfQ296VTRc5QF7AZZ1WTimGzMs54FepRXVxbyypJXCrUKjxsYSyk5EhHYNxU4ApsaBr8AQ4sYo86BbGh2dZSddGXU1CMGwExvnyckjQn' + b'xprv9xNEfQ296VTRaEUDZ8oKq74xw2U6kpj486vFUB4K1wT9U25GX4UwuzFgJN1YuRrqkQ5TTwCpkYnjNpSoH' + b'SBaEigNHPkoeYbuPMRo6mRUjxg' ) + invalid_key = yield self.ledger.get_private_key_for_address(b'BcQjRlhDOIrQez1WHfz3whnB33Bp34sUgX') self.assertIsNone(invalid_key) self.assertEqual( hexlify(private_key.wif()), - b'1c5664e848772b199644ab390b5c27d2f6664d9cdfdb62e1c7ac25151b00858b7a01' + b'1cc27be89ad47ef932562af80e95085eb0ab2ae3e5c019b1369b8b05ff2e94512f01' ) @defer.inlineCallbacks @@ -80,26 +136,24 @@ class TestAccount(unittest.TestCase): "h absent", 'encrypted': False, 'private_key': - 'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' - '6yz3jMbycrLrRMpeAJxR8qDg8', + b'xprv9s21ZrQH143K2dyhK7SevfRG72bYDRNv25yKPWWm6dqApNxm1Zb1m5gGcBWYfbsPjTr2v5joit8Af2Zp5P' + b'6yz3jMbycrLrRMpeAJxR8qDg8', 'public_key': - 'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' - 'iW44g14WF52fYC5J483wqQ5ZP', + b'xpub661MyMwAqRbcF84AR8yfHoMzf4S2ct6mPJtvBtvNeyN9hBHuZ6uGJszkTSn5fQUCdz3XU17eBzFeAUwV6f' + b'iW44g14WF52fYC5J483wqQ5ZP', 'receiving_gap': 10, + 'receiving_maximum_use_per_address': 2, 'change_gap': 10, + 'change_maximum_use_per_address': 2 } account = self.ledger.account_class.from_dict(self.ledger, account_data) - yield account.ensure_enough_useable_addresses() + yield account.ensure_address_gap() - keys = yield account.receiving.get_keys() addresses = yield account.receiving.get_addresses() - self.assertEqual(len(keys), 10) self.assertEqual(len(addresses), 10) - keys = yield account.change.get_keys() addresses = yield account.change.get_addresses() - self.assertEqual(len(keys), 10) self.assertEqual(len(addresses), 10) self.maxDiff = None diff --git a/tests/unit/test_ledger.py b/tests/unit/test_ledger.py index d1aaec30d..eb281888a 100644 --- a/tests/unit/test_ledger.py +++ b/tests/unit/test_ledger.py @@ -1,15 +1,72 @@ +from binascii import hexlify from twisted.trial import unittest from twisted.internet import defer -from torba.basedatabase import BaseSQLiteWalletStorage +from torba.coin.bitcoinsegwit import MainNetLedger + +from .test_transaction import get_transaction -class TestDatabase(unittest.TestCase): +class MockNetwork: + + def __init__(self, history, transaction): + self.history = history + self.transaction = transaction + self.address = None + self.get_history_called = [] + self.get_transaction_called = [] + + def get_history(self, address): + self.get_history_called.append(address) + self.address = address + return defer.succeed(self.history) + + def get_transaction(self, tx_hash): + self.get_transaction_called.append(tx_hash) + return defer.succeed(self.transaction[tx_hash]) + + +class TestSynchronization(unittest.TestCase): def setUp(self): - self.db = BaseSQLiteWalletStorage(':memory:') - return self.db.start() + self.ledger = MainNetLedger(db=':memory:') + return self.ledger.db.start() @defer.inlineCallbacks - def test_empty_db(self): - result = yield self.db. + def test_update_history(self): + account = self.ledger.account_class.generate(self.ledger, u"torba") + address = yield account.receiving.get_or_create_usable_address() + address_details = yield self.ledger.db.get_address(address) + self.assertEqual(address_details['history'], None) + + self.ledger.network = MockNetwork([ + {'tx_hash': b'abc', 'height': 1}, + {'tx_hash': b'def', 'height': 2}, + {'tx_hash': b'ghi', 'height': 3}, + ], { + b'abc': hexlify(get_transaction().raw), + b'def': hexlify(get_transaction().raw), + b'ghi': hexlify(get_transaction().raw), + }) + yield self.ledger.update_history(address) + self.assertEqual(self.ledger.network.get_history_called, [address]) + self.assertEqual(self.ledger.network.get_transaction_called, [b'abc', b'def', b'ghi']) + + address_details = yield self.ledger.db.get_address(address) + self.assertEqual(address_details['history'], b'abc:1:def:2:ghi:3:') + + self.ledger.network.get_history_called = [] + self.ledger.network.get_transaction_called = [] + yield self.ledger.update_history(address) + self.assertEqual(self.ledger.network.get_history_called, [address]) + self.assertEqual(self.ledger.network.get_transaction_called, []) + + self.ledger.network.history.append({'tx_hash': b'jkl', 'height': 4}) + self.ledger.network.transaction[b'jkl'] = hexlify(get_transaction().raw) + self.ledger.network.get_history_called = [] + self.ledger.network.get_transaction_called = [] + yield self.ledger.update_history(address) + self.assertEqual(self.ledger.network.get_history_called, [address]) + self.assertEqual(self.ledger.network.get_transaction_called, [b'jkl']) + address_details = yield self.ledger.db.get_address(address) + self.assertEqual(address_details['history'], b'abc:1:def:2:ghi:3:jkl:4:') diff --git a/torba/baseaccount.py b/torba/baseaccount.py index 8ef24e864..194f287ba 100644 --- a/torba/baseaccount.py +++ b/torba/baseaccount.py @@ -1,7 +1,7 @@ from typing import Dict -from binascii import unhexlify from twisted.internet import defer +import torba.baseledger from torba.mnemonic import Mnemonic from torba.bip32 import PrivateKey, PubKey, from_extended_key_string from torba.hash import double_sha256, aes_encrypt, aes_decrypt @@ -9,55 +9,59 @@ from torba.hash import double_sha256, aes_encrypt, aes_decrypt class KeyChain: - def __init__(self, account, parent_key, chain_number, minimum_usable_addresses): + def __init__(self, account, parent_key, chain_number, gap, maximum_use_per_address): + # type: ('BaseAccount', PubKey, int, int, int) -> None self.account = account self.db = account.ledger.db - self.main_key = parent_key.child(chain_number) # type: PubKey + self.main_key = parent_key.child(chain_number) self.chain_number = chain_number - self.minimum_usable_addresses = minimum_usable_addresses + self.gap = gap + self.maximum_use_per_address = maximum_use_per_address - def get_keys(self): - return self.db.get_keys(self.account, self.chain_number) + def get_addresses(self, limit=None, details=False): + return self.db.get_addresses(self.account, self.chain_number, limit, details) - def get_addresses(self): - return self.db.get_addresses(self.account, self.chain_number) + def get_usable_addresses(self, limit=None): + return self.db.get_usable_addresses( + self.account, self.chain_number, self.maximum_use_per_address, limit + ) @defer.inlineCallbacks - def ensure_enough_useable_addresses(self): - usable_address_count = yield self.db.get_usable_address_count( - self.account, self.chain_number - ) - - if usable_address_count >= self.minimum_usable_addresses: - defer.returnValue([]) - - new_addresses_needed = self.minimum_usable_addresses - usable_address_count - - start = yield self.db.get_last_address_index( - self.account, self.chain_number - ) - end = start + new_addresses_needed - + def generate_keys(self, start, end): new_keys = [] - for index in range(start+1, end+1): + for index in range(start, end+1): new_keys.append((index, self.main_key.child(index))) - yield self.db.add_keys( self.account, self.chain_number, new_keys ) - - defer.returnValue([ - key[1].address for key in new_keys - ]) + defer.returnValue([key[1].address for key in new_keys]) @defer.inlineCallbacks - def has_gap(self): - if len(self.addresses) < self.minimum_gap: - defer.returnValue(False) - for address in self.addresses[-self.minimum_gap:]: - if (yield self.ledger.is_address_old(address)): - defer.returnValue(False) - defer.returnValue(True) + def ensure_address_gap(self): + addresses = yield self.get_addresses(self.gap, True) + + existing_gap = 0 + for address in addresses: + if address['used_times'] == 0: + existing_gap += 1 + else: + break + + if existing_gap == self.gap: + defer.returnValue([]) + + start = addresses[0]['position']+1 if addresses else 0 + end = start + (self.gap - existing_gap) + new_keys = yield self.generate_keys(start, end-1) + defer.returnValue(new_keys) + + @defer.inlineCallbacks + def get_or_create_usable_address(self): + addresses = yield self.get_usable_addresses(1) + if addresses: + return addresses[0] + addresses = yield self.ensure_address_gap() + return addresses[0] class BaseAccount: @@ -67,26 +71,28 @@ class BaseAccount: public_key_class = PubKey def __init__(self, ledger, seed, encrypted, private_key, - public_key, receiving_gap=20, change_gap=6): - self.ledger = ledger # type: baseledger.BaseLedger - self.seed = seed # type: str - self.encrypted = encrypted # type: bool - self.private_key = private_key # type: PrivateKey - self.public_key = public_key # type: PubKey + public_key, receiving_gap=20, change_gap=6, + receiving_maximum_use_per_address=2, change_maximum_use_per_address=2): + # type: (torba.baseledger.BaseLedger, str, bool, PrivateKey, PubKey, int, int, int, int) -> None + self.ledger = ledger + self.seed = seed + self.encrypted = encrypted + self.private_key = private_key + self.public_key = public_key self.receiving, self.change = self.keychains = ( - KeyChain(self, public_key, 0, receiving_gap), - KeyChain(self, public_key, 1, change_gap) + KeyChain(self, public_key, 0, receiving_gap, receiving_maximum_use_per_address), + KeyChain(self, public_key, 1, change_gap, change_maximum_use_per_address) ) ledger.account_created(self) @classmethod - def generate(cls, ledger, password): # type: (baseledger.BaseLedger, str) -> BaseAccount + def generate(cls, ledger, password): # type: (torba.baseledger.BaseLedger, str) -> BaseAccount seed = cls.mnemonic_class().make_seed() return cls.from_seed(ledger, seed, password) @classmethod def from_seed(cls, ledger, seed, password): - # type: (baseledger.BaseLedger, str, str) -> BaseAccount + # type: (torba.baseledger.BaseLedger, str, str) -> BaseAccount private_key = cls.get_private_key_from_seed(ledger, seed, password) return cls( ledger=ledger, seed=seed, encrypted=False, @@ -96,13 +102,13 @@ class BaseAccount: @classmethod def get_private_key_from_seed(cls, ledger, seed, password): - # type: (baseledger.BaseLedger, str, str) -> PrivateKey + # type: (torba.baseledger.BaseLedger, str, str) -> PrivateKey return cls.private_key_class.from_seed( ledger, cls.mnemonic_class.mnemonic_to_seed(seed, password) ) @classmethod - def from_dict(cls, ledger, d): # type: (baseledger.BaseLedger, Dict) -> BaseAccount + def from_dict(cls, ledger, d): # type: (torba.baseledger.BaseLedger, Dict) -> BaseAccount if not d['encrypted']: private_key = from_extended_key_string(ledger, d['private_key']) public_key = private_key.public_key @@ -116,7 +122,9 @@ class BaseAccount: private_key=private_key, public_key=public_key, receiving_gap=d['receiving_gap'], - change_gap=d['change_gap'] + change_gap=d['change_gap'], + receiving_maximum_use_per_address=d['receiving_maximum_use_per_address'], + change_maximum_use_per_address=d['change_maximum_use_per_address'] ) def to_dict(self): @@ -127,8 +135,10 @@ class BaseAccount: 'private_key': self.private_key if self.encrypted else self.private_key.extended_key_string(), 'public_key': self.public_key.extended_key_string(), - 'receiving_gap': self.receiving.minimum_usable_addresses, - 'change_gap': self.change.minimum_usable_addresses, + 'receiving_gap': self.receiving.gap, + 'change_gap': self.change.gap, + 'receiving_maximum_use_per_address': self.receiving.maximum_use_per_address, + 'change_maximum_use_per_address': self.change.maximum_use_per_address } def decrypt(self, password): @@ -146,39 +156,22 @@ class BaseAccount: self.encrypted = True @defer.inlineCallbacks - def ensure_enough_useable_addresses(self): + def ensure_address_gap(self): addresses = [] for keychain in self.keychains: - new_addresses = yield keychain.ensure_enough_useable_addresses() + new_addresses = yield keychain.ensure_address_gap() addresses.extend(new_addresses) defer.returnValue(addresses) + def get_addresses(self, limit=None, details=False): + return self.ledger.db.get_addresses(self, None, limit, details) + + def get_unused_addresses(self): + return self.ledger.db.get_unused_addresses(self, None) + def get_private_key(self, chain, index): assert not self.encrypted, "Cannot get private key on encrypted wallet account." return self.private_key.child(chain).child(index) - def get_least_used_receiving_address(self, max_transactions=1000): - return self._get_least_used_address( - self.receiving_keys, - max_transactions - ) - - def get_least_used_change_address(self, max_transactions=100): - return self._get_least_used_address( - self.change_keys, - max_transactions - ) - - def _get_least_used_address(self, keychain, max_transactions): - ledger = self.ledger - address = ledger.get_least_used_address(self, keychain, max_transactions) - if address: - return address - address = keychain.generate_next_address() - ledger.subscribe_history(address) - return address - - @defer.inlineCallbacks def get_balance(self): - utxos = yield self.ledger.get_unspent_outputs(self) - defer.returnValue(sum(utxo.amount for utxo in utxos)) + return self.ledger.db.get_balance_for_account(self) diff --git a/torba/basedatabase.py b/torba/basedatabase.py index c06995d88..5f7d4f2a6 100644 --- a/torba/basedatabase.py +++ b/torba/basedatabase.py @@ -1,8 +1,12 @@ import logging +from typing import List, Union + import sqlite3 from twisted.internet import defer from twisted.enterprise import adbapi +import torba.baseaccount + log = logging.getLogger(__name__) @@ -34,17 +38,19 @@ class SQLiteMixin(object): return trans.execute(sql).fetchall() def _insert_sql(self, table, data): + # type: (str, dict) -> tuple[str, List] columns, values = [], [] for column, value in data.items(): columns.append(column) values.append(value) - sql = "REPLACE INTO %s (%s) VALUES (%s)".format( + sql = "REPLACE INTO {} ({}) VALUES ({})".format( table, ', '.join(columns), ', '.join(['?'] * len(values)) ) return sql, values @defer.inlineCallbacks def query_one_value_list(self, query, params): + # type: (str, Union[dict,tuple]) -> defer.Deferred[List] result = yield self.db.runQuery(query, params) if result: defer.returnValue([i[0] for i in result]) @@ -137,27 +143,7 @@ class BaseDatabase(SQLiteMixin): CREATE_TXI_TABLE ) - def get_missing_transactions(self, address, txids): - def _steps(t): - missing = [] - chunk_size = 100 - for i in range(0, len(txids), chunk_size): - chunk = txids[i:i + chunk_size] - t.execute( - "SELECT 1 FROM tx WHERE txid=?", - (sqlite3.Binary(txid) for txid in chunk) - ) - if not t.execute("SELECT 1 FROM tx WHERE txid=?", (sqlite3.Binary(tx.id),)).fetchone(): - t.execute(*self._insert_sql('tx', { - 'txid': sqlite3.Binary(tx.id), - 'raw': sqlite3.Binary(tx.raw), - 'height': height, - 'is_confirmed': is_confirmed, - 'is_verified': is_verified - })) - return self.db.runInteraction(_steps) - - def add_transaction(self, address, tx, height, is_confirmed, is_verified): + def add_transaction(self, address, hash, tx, height, is_confirmed, is_verified): def _steps(t): if not t.execute("SELECT 1 FROM tx WHERE txid=?", (sqlite3.Binary(tx.id),)).fetchone(): t.execute(*self._insert_sql('tx', { @@ -167,46 +153,37 @@ class BaseDatabase(SQLiteMixin): 'is_confirmed': is_confirmed, 'is_verified': is_verified })) - t.execute(*self._insert_sql( - "insert into txo values (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( - sqlite3.Binary(account.public_key.address), - sqlite3.Binary(txo.script.values['pubkey_hash']), - sqlite3.Binary(txo.txid), - txo.index, - txo.amount, - sqlite3.Binary(txo.script.source), - txo.script.is_claim_name, - txo.script.is_support_claim, - txo.script.is_update_claim - ) - - )) - txoid = t.execute( - "select rowid from txo where txid=? and pos=?", ( - sqlite3.Binary(txi.output_txid), txi.output_index - ) - ).fetchone()[0] - t.execute( - "insert into txi values (?, ?, ?)", ( - sqlite3.Binary(account.public_key.address), - sqlite3.Binary(txi.txid), - txoid - ) - ) + for txo in tx.outputs: + if txo.script.is_pay_pubkey_hash and txo.script.values['pubkey_hash'] == hash: + t.execute(*self._insert_sql("txo", { + 'txid': sqlite3.Binary(tx.id), + 'address': sqlite3.Binary(address), + 'position': txo.index, + 'amount': txo.amount, + 'script': sqlite3.Binary(txo.script.source) + })) + elif txo.script.is_pay_script_hash: + # TODO: implement script hash payments + print('Database.add_transaction pay script hash is not implemented!') + for txi in tx.inputs: + txoid = t.execute( + "SELECT txoid, address FROM txo WHERE txid = ? AND position = ?", + (sqlite3.Binary(txi.output_txid), txi.output_index) + ).fetchone() + if txoid: + t.execute(*self._insert_sql("txi", { + 'txid': sqlite3.Binary(tx.id), + 'address': sqlite3.Binary(address), + 'txoid': txoid, + })) return self.db.runInteraction(_steps) - @defer.inlineCallbacks - def has_transaction(self, txid): - result = yield self.db.runQuery( - "select rowid from tx where txid=?", (txid,) - ) - defer.returnValue(bool(result)) - @defer.inlineCallbacks def get_balance_for_account(self, account): result = yield self.db.runQuery( - "select sum(amount) from txo where account=:account and rowid not in (select txo from txi where account=:account)", + "SELECT SUM(amount) FROM txo NATURAL JOIN pubkey_address WHERE account=:account AND " + "txoid NOT IN (SELECT txoid FROM txi)", {'account': sqlite3.Binary(account.public_key.address)} ) if result: @@ -218,12 +195,9 @@ class BaseDatabase(SQLiteMixin): def get_utxos(self, account, output_class): utxos = yield self.db.runQuery( """ - SELECT - amount, script, txid - FROM txo - WHERE - account=:account AND - txoid NOT IN (SELECT txoid FROM txi WHERE account=:account) + SELECT amount, script, txid, position + FROM txo NATURAL JOIN pubkey_address + WHERE account=:account AND txoid NOT IN (SELECT txoid FROM txi) """, {'account': sqlite3.Binary(account.public_key.address)} ) @@ -231,7 +205,8 @@ class BaseDatabase(SQLiteMixin): output_class( values[0], output_class.script_class(values[1]), - values[2] + values[2], + index=values[3] ) for values in utxos ]) @@ -250,64 +225,58 @@ class BaseDatabase(SQLiteMixin): values.append(sqlite3.Binary(pubkey.pubkey_bytes)) return self.db.runOperation(sql, values) - def get_keys(self, account, chain): - return self.query_one_value_list( - "SELECT pubkey FROM pubkey_address WHERE account = ? AND chain = ?", - (sqlite3.Binary(account.public_key.address), chain) - ) + def get_addresses(self, account, chain, limit=None, details=False): + sql = ["SELECT {} FROM pubkey_address WHERE account = :account"] + params = {'account': sqlite3.Binary(account.public_key.address)} + if chain is not None: + sql.append("AND chain = :chain") + params['chain'] = chain + sql.append("ORDER BY position DESC") + if limit is not None: + sql.append("LIMIT {}".format(limit)) + if details: + return self.query_dict_value_list(' '.join(sql), ('address', 'position', 'used_times'), params) + else: + return self.query_one_value_list(' '.join(sql).format('address'), params) - def get_address_details(self, address): - return self.query_dict_value( - "SELECT {} FROM pubkey_address WHERE address = ?", - ('account', 'chain', 'position'), (sqlite3.Binary(address),) - ) - - def get_addresses(self, account, chain): - return self.query_one_value_list( - "SELECT address FROM pubkey_address WHERE account = ? AND chain = ?", - (sqlite3.Binary(account.public_key.address), chain) - ) - - def get_last_address_index(self, account, chain): - return self.query_one_value( - """ - SELECT position FROM pubkey_address - WHERE account = ? AND chain = ? - ORDER BY position DESC LIMIT 1""", - (sqlite3.Binary(account.public_key.address), chain), - default=0 - ) - - def _usable_address_sql(self, account, chain, exclude_used_times): - return """ - SELECT address FROM pubkey_address - WHERE - account = :account AND - chain = :chain AND - used_times <= :exclude_used_times - """, { + def _used_address_sql(self, account, chain, comparison_op, used_times, limit=None): + sql = [ + "SELECT address FROM pubkey_address", + "WHERE account = :account AND" + ] + params = { 'account': sqlite3.Binary(account.public_key.address), - 'chain': chain, - 'exclude_used_times': exclude_used_times + 'used_times': used_times } + if chain is not None: + sql.append("chain = :chain AND") + params['chain'] = chain + sql.append("used_times {} :used_times".format(comparison_op)) + sql.append("ORDER BY used_times ASC") + if limit is not None: + sql.append('LIMIT {}'.format(limit)) + return ' '.join(sql), params - def get_usable_addresses(self, account, chain, exclude_used_times=2): - return self.query_one_value_list(*self._usable_address_sql( - account, chain, exclude_used_times + def get_unused_addresses(self, account, chain): + # type: (torba.baseaccount.BaseAccount, int) -> defer.Deferred[List[str]] + return self.query_one_value_list(*self._used_address_sql( + account, chain, '=', 0 )) - def get_usable_address_count(self, account, chain, exclude_used_times=2): - return self.query_count(*self._usable_address_sql( - account, chain, exclude_used_times + def get_usable_addresses(self, account, chain, max_used_times, limit): + return self.query_one_value_list(*self._used_address_sql( + account, chain, '<=', max_used_times, limit )) - def get_address_history(self, address): - return self.query_one_value( - "SELECT history FROM pubkey_address WHERE address = ?", (sqlite3.Binary(address),) + def get_address(self, address): + return self.query_dict_value( + "SELECT {} FROM pubkey_address WHERE address= :address", + ('address', 'account', 'chain', 'position', 'pubkey', 'history', 'used_times'), + {'address': sqlite3.Binary(address)} ) - def set_address_status(self, address, status): + def set_address_history(self, address, history): return self.db.runOperation( - "replace into address_status (address, status) values (?, ?)", (address,status) + "UPDATE pubkey_address SET history = ?, used_times = ? WHERE address = ?", + (sqlite3.Binary(history), history.count(b':')//2, sqlite3.Binary(address)) ) - diff --git a/torba/baseledger.py b/torba/baseledger.py index dcad300db..1acc1c1a1 100644 --- a/torba/baseledger.py +++ b/torba/baseledger.py @@ -76,7 +76,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): raw_address = self.pubkey_address_prefix + h160 return Base58.encode(bytearray(raw_address + double_sha256(raw_address)[0:4])) - def account_created(self, account): + def account_created(self, account): # type: (baseaccount.BaseAccount) -> None self.accounts.add(account) @staticmethod @@ -94,7 +94,7 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): @property def path(self): - return os.path.join(self.config['path'], self.get_id()) + return os.path.join(self.config['wallet_path'], self.get_id()) def get_input_output_fee(self, io): """ Fee based on size of the input / output. """ @@ -104,29 +104,17 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): """ Fee for the transaction header and all outputs; without inputs. """ return self.fee_per_byte * tx.base_size - def get_keys(self, account, chain): - return self.db.get_keys(account, chain) - @defer.inlineCallbacks - def add_transaction(self, transaction, height): # type: (basetransaction.BaseTransaction, int) -> None - yield self.db.add_transaction(transaction, height, False, False) + def add_transaction(self, address, transaction, height): + # type: (bytes, basetransaction.BaseTransaction, int) -> None + yield self.db.add_transaction( + address, self.address_to_hash160(address), transaction, height, False, False + ) self._on_transaction_controller.add(transaction) - def has_address(self, address): - return address in self.accounts.addresses - - @defer.inlineCallbacks - def get_least_used_address(self, account, keychain, max_transactions=100): - used_addresses = yield self.db.get_used_addresses(account) - unused_set = set(keychain.addresses) - set(map(itemgetter(0), used_addresses)) - if unused_set: - defer.returnValue(unused_set.pop()) - if used_addresses and used_addresses[0][1] < max_transactions: - defer.returnValue(used_addresses[0][0]) - @defer.inlineCallbacks def get_private_key_for_address(self, address): - match = yield self.db.get_address_details(address) + match = yield self.db.get_address(address) if match: for account in self.accounts: if bytes(match['account']) == account.public_key.address: @@ -148,6 +136,19 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): # if output[1:] not in inputs: # yield output[0] + @defer.inlineCallbacks + def get_local_status(self, address): + address_details = yield self.db.get_address(address) + hash = hashlib.sha256(address_details['history']).digest() + defer.returnValue(hexlify(hash)) + + @defer.inlineCallbacks + def get_local_history(self, address): + address_details = yield self.db.get_address(address) + history = address_details['history'] or b'' + parts = history.split(b':')[:-1] + defer.returnValue(list(zip(parts[0::2], map(int, parts[1::2])))) + @defer.inlineCallbacks def start(self): if not os.path.exists(self.path): @@ -195,62 +196,57 @@ class BaseLedger(six.with_metaclass(LedgerRegistry)): ]) @defer.inlineCallbacks - def update_account(self, account): # type: (Account) -> defer.Defferred + def update_account(self, account): # type: (baseaccount.BaseAccount) -> defer.Defferred # Before subscribing, download history for any addresses that don't have any, # this avoids situation where we're getting status updates to addresses we know # need to update anyways. Continue to get history and create more addresses until # all missing addresses are created and history for them is fully restored. - yield account.ensure_enough_addresses() - addresses = yield account.get_unused_addresses(account) + yield account.ensure_address_gap() + addresses = yield account.get_unused_addresses() while addresses: yield defer.DeferredList([ self.update_history(a) for a in addresses ]) - addresses = yield account.ensure_enough_addresses() + addresses = yield account.ensure_address_gap() # By this point all of the addresses should be restored and we # can now subscribe all of them to receive updates. - yield defer.DeferredList([ - self.subscribe_history(address) - for address in account.addresses - ]) - - def _get_status_from_history(self, history): - hashes = [ - '{}:{}:'.format(hash.decode(), height).encode() - for hash, height in map(itemgetter('tx_hash', 'height'), history) - ] - if hashes: - return hexlify(hashlib.sha256(b''.join(hashes)).digest()) + all_addresses = yield account.get_addresses() + yield defer.DeferredList( + list(map(self.subscribe_history, all_addresses)) + ) @defer.inlineCallbacks - def update_history(self, address, remote_status=None): - history = yield self.network.get_history(address) - hashes = list(map(itemgetter('tx_hash'), history)) - for hash, height in map(itemgetter('tx_hash', 'height'), history): + def update_history(self, address): + remote_history = yield self.network.get_history(address) + local = yield self.get_local_history(address) - if not (yield self.db.has_transaction(hash)): - raw = yield self.network.get_transaction(hash) - transaction = self.transaction_class(unhexlify(raw)) - yield self.add_transaction(transaction, height) - if remote_status is None: - remote_status = self._get_status_from_history(history) - if remote_status: - yield self.db.set_address_status(address, remote_status) + history_parts = [] + for i, (hash, height) in enumerate(map(itemgetter('tx_hash', 'height'), remote_history)): + history_parts.append('{}:{}:'.format(hash.decode(), height)) + if i < len(local) and local[i] == (hash, height): + continue + raw = yield self.network.get_transaction(hash) + transaction = self.transaction_class(unhexlify(raw)) + yield self.add_transaction(address, transaction, height) + + yield self.db.set_address_history( + address, ''.join(history_parts).encode() + ) @defer.inlineCallbacks def subscribe_history(self, address): remote_status = yield self.network.subscribe_address(address) - local_status = yield self.db.get_address_status(address) + local_status = yield self.get_local_status(address) if local_status != remote_status: - yield self.update_history(address, remote_status) + yield self.update_history(address) @defer.inlineCallbacks def process_status(self, response): address, remote_status = response - local_status = yield self.db.get_address_status(address) + local_status = yield self.get_local_status(address) if local_status != remote_status: - yield self.update_history(address, remote_status) + yield self.update_history(address) def broadcast(self, tx): return self.network.broadcast(hexlify(tx.raw)) diff --git a/torba/basetransaction.py b/torba/basetransaction.py index 813c4e555..54e46cb7a 100644 --- a/torba/basetransaction.py +++ b/torba/basetransaction.py @@ -3,6 +3,8 @@ import logging from typing import List, Iterable, Generator from binascii import hexlify +from twisted.internet import defer + from torba.basescript import BaseInputScript, BaseOutputScript from torba.coinselection import CoinSelector from torba.constants import COIN @@ -20,10 +22,10 @@ NULL_HASH = b'\x00'*32 class InputOutput(object): - def __init__(self, txid): + def __init__(self, txid, index=None): self._txid = txid # type: bytes self.transaction = None # type: BaseTransaction - self.index = None # type: int + self.index = index # type: int @property def txid(self): @@ -53,7 +55,7 @@ class BaseInput(InputOutput): super(BaseInput, self).__init__(txid) if isinstance(output_or_txid_index, BaseOutput): self.output = output_or_txid_index # type: BaseOutput - self.output_txid = self.output.transaction.id + self.output_txid = self.output.txid self.output_index = self.output.index else: self.output = None # type: BaseOutput @@ -127,8 +129,8 @@ class BaseOutput(InputOutput): script_class = BaseOutputScript estimator_class = BaseOutputEffectiveAmountEstimator - def __init__(self, amount, script, txid=None): - super(BaseOutput, self).__init__(txid) + def __init__(self, amount, script, txid=None, index=None): + super(BaseOutput, self).__init__(txid, index) self.amount = amount # type: int self.script = script # type: BaseOutputScript @@ -275,11 +277,15 @@ class BaseTransaction: self.locktime = stream.read_uint32() @classmethod + @defer.inlineCallbacks def get_effective_amount_estimators(cls, funding_accounts): # type: (Iterable[BaseAccount]) -> Generator[BaseOutputEffectiveAmountEstimator] + estimators = [] for account in funding_accounts: - for utxo in account.coin.ledger.get_unspent_outputs(account): - yield utxo.get_estimator(account.coin) + utxos = yield account.ledger.get_unspent_outputs(account) + for utxo in utxos: + estimators.append(utxo.get_estimator(account.ledger)) + defer.returnValue(estimators) @classmethod def ensure_all_have_same_ledger(cls, funding_accounts, change_account=None): @@ -297,15 +303,17 @@ class BaseTransaction: return ledger @classmethod + @defer.inlineCallbacks def pay(cls, outputs, funding_accounts, change_account): + # type: (List[BaseOutput], List[BaseAccount], BaseAccount) -> BaseTransaction """ Efficiently spend utxos from funding_accounts to cover the new outputs. """ tx = cls().add_outputs(outputs) ledger = cls.ensure_all_have_same_ledger(funding_accounts, change_account) amount = ledger.get_transaction_base_fee(tx) + txos = yield cls.get_effective_amount_estimators(funding_accounts) selector = CoinSelector( - list(cls.get_effective_amount_estimators(funding_accounts)), - amount, + txos, amount, ledger.get_input_output_fee( cls.output_class.pay_pubkey_hash(COIN, NULL_HASH) ) @@ -317,14 +325,14 @@ class BaseTransaction: spent_sum = sum(s.effective_amount for s in spendables) if spent_sum > amount: - change_address = change_account.get_least_used_change_address() - change_hash160 = change_account.coin.address_to_hash160(change_address) + change_address = change_account.change.get_or_create_usable_address() + change_hash160 = change_account.ledger.address_to_hash160(change_address) change_amount = spent_sum - amount tx.add_outputs([cls.output_class.pay_pubkey_hash(change_amount, change_hash160)]) tx.add_inputs([s.txi for s in spendables]) tx.sign(funding_accounts) - return tx + defer.returnValue(tx) @classmethod def liquidate(cls, assets, funding_accounts, change_account): diff --git a/torba/coin/bitcoincash.py b/torba/coin/bitcoincash.py index f416ced41..1830f9cdb 100644 --- a/torba/coin/bitcoincash.py +++ b/torba/coin/bitcoincash.py @@ -1,6 +1,6 @@ -__coin__ = 'BitcoinCash' __node_daemon__ = 'bitcoind' __node_cli__ = 'bitcoin-cli' +__node_bin__ = 'bitcoin-abc-0.17.2/bin' __node_url__ = ( 'https://download.bitcoinabc.org/0.17.2/linux/bitcoin-abc-0.17.2-x86_64-linux-gnu.tar.gz' ) @@ -11,7 +11,6 @@ from torba.baseledger import BaseLedger, BaseHeaders from torba.basenetwork import BaseNetwork from torba.basescript import BaseInputScript, BaseOutputScript from torba.basetransaction import BaseTransaction, BaseInput, BaseOutput -from torba.basecoin import BaseCoin from torba.basedatabase import BaseSQLiteWalletStorage from torba.manager import BaseWalletManager diff --git a/torba/coin/bitcoinsegwit.py b/torba/coin/bitcoinsegwit.py index d38cc71c2..6739d87f2 100644 --- a/torba/coin/bitcoinsegwit.py +++ b/torba/coin/bitcoinsegwit.py @@ -1,6 +1,6 @@ -__coin__ = 'BitcoinSegwit' __node_daemon__ = 'bitcoind' __node_cli__ = 'bitcoin-cli' +__node_bin__ = 'bitcoin-0.16.0/bin' __node_url__ = ( 'https://bitcoin.org/bin/bitcoin-core-0.16.0/bitcoin-0.16.0-x86_64-linux-gnu.tar.gz' ) @@ -29,10 +29,15 @@ class UnverifiedHeaders(BaseHeaders): class RegTestLedger(MainNetLedger): - network_name = 'regtest' headers_class = UnverifiedHeaders + network_name = 'regtest' + + pubkey_address_prefix = int2byte(111) + script_address_prefix = int2byte(196) + extended_public_key_prefix = unhexlify('043587cf') + extended_private_key_prefix = unhexlify('04358394') + max_target = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff genesis_hash = '0f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206' genesis_bits = 0x207fffff target_timespan = 1 - verify_bits_to_target = False diff --git a/torba/manager.py b/torba/manager.py index 31b5ab107..104200352 100644 --- a/torba/manager.py +++ b/torba/manager.py @@ -2,8 +2,7 @@ import functools from typing import List, Dict, Type from twisted.internet import defer -from torba.baseaccount import AccountsView -from torba.baseledger import BaseLedger +from torba.baseledger import BaseLedger, LedgerRegistry from torba.basetransaction import BaseTransaction, NULL_HASH from torba.coinselection import CoinSelector from torba.constants import COIN @@ -29,9 +28,8 @@ class WalletManager(object): wallets.append(wallet) return manager - def get_or_create_ledger(self, coin_id, ledger_config=None): - coin_class = CoinRegistry.get_coin_class(coin_id) - ledger_class = coin_class.ledger_class + def get_or_create_ledger(self, ledger_id, ledger_config=None): + ledger_class = LedgerRegistry.get_ledger_class(ledger_id) ledger = self.ledgers.get(ledger_class) if ledger is None: ledger = self.create_ledger(ledger_class, ledger_config or {}) @@ -100,17 +98,16 @@ class WalletManager(object): amount = int(amount * COIN) account = self.default_account - coin = account.coin - ledger = coin.ledger + ledger = account.ledger tx_class = ledger.transaction_class # type: BaseTransaction in_class, out_class = tx_class.input_class, tx_class.output_class estimators = [ - txo.get_estimator(coin) for txo in account.get_unspent_utxos() + txo.get_estimator(ledger) for txo in account.get_unspent_utxos() ] tx_class.create() - cost_of_output = coin.get_input_output_fee( + cost_of_output = ledger.get_input_output_fee( out_class.pay_pubkey_hash(COIN, NULL_HASH) )