+ tx.position, + tx.net_account_balance, + txo.is_my_account

This commit is contained in:
Lex Berezhny 2018-09-25 18:02:50 -04:00
parent 5977f42a7e
commit c29b4c476d
4 changed files with 151 additions and 51 deletions

View file

@ -54,6 +54,50 @@ class TestSizeAndFeeEstimation(unittest.TestCase):
self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size) self.assertEqual(tx.get_base_fee(self.ledger), FEE_PER_BYTE * tx.base_size)
class TestAccountBalanceImpactFromTransaction(unittest.TestCase):
def test_is_my_account_not_set(self):
tx = get_transaction()
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
_ = tx.net_account_balance
tx.inputs[0].is_my_account = True
with self.assertRaisesRegex(ValueError, "Cannot access net_account_balance"):
_ = tx.net_account_balance
tx.outputs[0].is_my_account = True
# all inputs/outputs are set now so it should work
_ = tx.net_account_balance
def test_paying_from_my_account_to_other_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = True
tx.outputs[0].is_my_account = False
tx.outputs[1].is_my_account = True
self.assertEqual(tx.net_account_balance, -200*CENT)
def test_paying_from_other_account_to_my_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = False
tx.outputs[0].is_my_account = True
tx.outputs[1].is_my_account = False
self.assertEqual(tx.net_account_balance, 190*CENT)
def test_paying_from_my_account_to_my_account(self):
tx = ledger_class.transaction_class() \
.add_inputs([get_input(300*CENT)]) \
.add_outputs([get_output(190*CENT, NULL_HASH),
get_output(100*CENT, NULL_HASH)])
tx.inputs[0].is_my_account = True
tx.outputs[0].is_my_account = True
tx.outputs[1].is_my_account = True
self.assertEqual(tx.net_account_balance, -10*CENT) # lost to fee
class TestTransactionSerialization(unittest.TestCase): class TestTransactionSerialization(unittest.TestCase):
def test_genesis_transaction(self): def test_genesis_transaction(self):
@ -217,7 +261,7 @@ class TransactionIOBalancing(unittest.TestCase):
save_tx = 'insert' save_tx = 'insert'
for utxo in utxos: for utxo in utxos:
yield self.ledger.db.save_transaction_io( yield self.ledger.db.save_transaction_io(
save_tx, self.funding_tx, 1, True, save_tx, self.funding_tx, True,
self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']), self.ledger.hash160_to_address(utxo.script.values['pubkey_hash']),
utxo.script.values['pubkey_hash'], '' utxo.script.values['pubkey_hash'], ''
) )

View file

@ -143,6 +143,7 @@ class BaseDatabase(SQLiteMixin):
txid text primary key, txid text primary key,
raw blob not null, raw blob not null,
height integer not null, height integer not null,
position integer not null,
is_verified boolean not null default 0 is_verified boolean not null default 0
); );
""" """
@ -185,19 +186,20 @@ class BaseDatabase(SQLiteMixin):
'script': sqlite3.Binary(txo.script.source) 'script': sqlite3.Binary(txo.script.source)
} }
def save_transaction_io(self, save_tx, tx, height, is_verified, address, txhash, history): def save_transaction_io(self, save_tx, tx, is_verified, address, txhash, history):
def _steps(t): def _steps(t):
if save_tx == 'insert': if save_tx == 'insert':
self.execute(t, *self._insert_sql('tx', { self.execute(t, *self._insert_sql('tx', {
'txid': tx.id, 'txid': tx.id,
'raw': sqlite3.Binary(tx.raw), 'raw': sqlite3.Binary(tx.raw),
'height': height, 'height': tx.height,
'position': tx.position,
'is_verified': is_verified 'is_verified': is_verified
})) }))
elif save_tx == 'update': elif save_tx == 'update':
self.execute(t, *self._update_sql("tx", { self.execute(t, *self._update_sql("tx", {
'height': height, 'is_verified': is_verified 'height': tx.height, 'position': tx.position, 'is_verified': is_verified
}, 'txid = ?', (tx.id,))) }, 'txid = ?', (tx.id,)))
existing_txos = [r[0] for r in self.execute( existing_txos = [r[0] for r in self.execute(
@ -260,19 +262,19 @@ class BaseDatabase(SQLiteMixin):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transaction(self, txid): def get_transaction(self, txid):
result = yield self.run_query( result = yield self.run_query(
"SELECT raw, height, is_verified FROM tx WHERE txid = ?", (txid,) "SELECT raw, height, position, is_verified FROM tx WHERE txid = ?", (txid,)
) )
if result: if result:
return result[0] return result[0]
else: else:
return None, None, False return None, None, None, False
@defer.inlineCallbacks @defer.inlineCallbacks
def get_transactions(self, account, offset=0, limit=100): def get_transactions(self, account, offset=0, limit=100):
offset, limit = min(offset, 0), max(limit, 100) account_id = account.public_key.address
tx_rows = yield self.run_query( tx_rows = yield self.run_query(
""" """
SELECT txid, raw, height FROM tx WHERE txid IN ( SELECT txid, raw, height, position FROM tx WHERE txid IN (
SELECT txo.txid FROM txo SELECT txo.txid FROM txo
JOIN pubkey_address USING (address) JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account WHERE pubkey_address.account = :account
@ -281,47 +283,67 @@ class BaseDatabase(SQLiteMixin):
JOIN txo USING (txoid) JOIN txo USING (txoid)
JOIN pubkey_address USING (address) JOIN pubkey_address USING (address)
WHERE pubkey_address.account = :account WHERE pubkey_address.account = :account
) ORDER BY height DESC LIMIT :offset, :limit ) ORDER BY height DESC, position DESC LIMIT :offset, :limit
""", {'account': account.public_key.address, 'offset': offset, 'limit': limit} """, {
'account': account_id,
'offset': min(offset, 0),
'limit': max(limit, 100)
}
) )
txids, txs = [], [] txids, txs = [], []
for row in tx_rows: for row in tx_rows:
txids.append(row[0]) txids.append(row[0])
txs.append(account.ledger.transaction_class(raw=row[1], height=row[2])) txs.append(account.ledger.transaction_class(
raw=row[1], height=row[2], position=row[3]
))
txo_rows = yield self.run_query( txo_rows = yield self.run_query(
""" """
SELECT txoid, pubkey_address.chain SELECT txoid, chain, account
FROM txo JOIN pubkey_address USING (address) FROM txo JOIN pubkey_address USING (address)
WHERE txid IN ({}) WHERE txid IN ({})
""".format(', '.join(['?']*len(txids))), txids """.format(', '.join(['?']*len(txids))), txids
) )
txos = dict(txo_rows) txos = {}
for row in txo_rows:
txos[row[0]] = {
'is_change': row[1] == 1,
'is_my_account': row[2] == account_id
}
txi_rows = yield self.run_query( referenced_txo_rows = yield self.run_query(
""" """
SELECT txoid, txo.amount, txo.script, txo.txid, txo.position SELECT txoid, txo.amount, txo.script, txo.txid, txo.position, chain, account
FROM txi JOIN txo USING (txoid) FROM txi
JOIN txo USING (txoid)
JOIN pubkey_address USING (address)
WHERE txi.txid IN ({}) WHERE txi.txid IN ({})
""".format(', '.join(['?']*len(txids))), txids """.format(', '.join(['?']*len(txids))), txids
) )
txis = {} referenced_txos = {}
output_class = account.ledger.transaction_class.output_class output_class = account.ledger.transaction_class.output_class
for row in txi_rows: for row in referenced_txo_rows:
txis[row[0]] = output_class( referenced_txos[row[0]] = output_class(
row[1], amount=row[1],
output_class.script_class(row[2]), script=output_class.script_class(row[2]),
TXRefImmutable.from_id(row[3]), tx_ref=TXRefImmutable.from_id(row[3]),
position=row[4] position=row[4],
is_change=row[5] == 1,
is_my_account=row[6] == account_id
) )
for tx in txs: for tx in txs:
for txi in tx.inputs: for txi in tx.inputs:
if txi.txo_ref.id in txis: if txi.txo_ref.id in referenced_txos:
txi.txo_ref = TXORefResolvable(txis[txi.txo_ref.id]) txi.txo_ref = TXORefResolvable(referenced_txos[txi.txo_ref.id])
for txo in tx.outputs: for txo in tx.outputs:
if txo.id in txos: txo_meta = txos.get(txo.id)
txo.is_change = txos[txo.id] == 1 if txo_meta is not None:
txo.is_change = txo_meta['is_change']
txo.is_my_account = txo_meta['is_my_account']
else:
txo.is_change = False
txo.is_my_account = False
return txs return txs

View file

@ -40,7 +40,7 @@ class LedgerRegistry(type):
return mcs.ledgers[ledger_id] return mcs.ledgers[ledger_id]
class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'height', 'is_verified'))): class TransactionEvent(namedtuple('TransactionEvent', ('address', 'tx', 'is_verified'))):
pass pass
@ -87,7 +87,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.on_transaction.listen( self.on_transaction.listen(
lambda e: log.info( lambda e: log.info(
'(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s', '(%s) on_transaction: address=%s, height=%s, is_verified=%s, tx.id=%s',
self.get_id(), e.address, e.height, e.is_verified, e.tx.id self.get_id(), e.address, e.tx.height, e.is_verified, e.tx.id
) )
) )
@ -207,12 +207,13 @@ class BaseLedger(metaclass=LedgerRegistry):
return hexlify(working_branch[::-1]) return hexlify(working_branch[::-1])
@defer.inlineCallbacks @defer.inlineCallbacks
def is_valid_transaction(self, tx, height): def validate_transaction_and_set_position(self, tx, height):
if not height <= len(self.headers): if not height <= len(self.headers):
return False return False
merkle = yield self.network.get_merkle(tx.id, height) merkle = yield self.network.get_merkle(tx.id, height)
merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash) merkle_root = self.get_root_of_merkle_tree(merkle['merkle'], merkle['pos'], tx.hash)
header = self.headers[height] header = self.headers[height]
tx.position = merkle['pos']
return merkle_root == header['merkle_root'] return merkle_root == header['merkle_root']
@defer.inlineCallbacks @defer.inlineCallbacks
@ -365,23 +366,23 @@ class BaseLedger(metaclass=LedgerRegistry):
try: try:
# see if we have a local copy of transaction, otherwise fetch it from server # see if we have a local copy of transaction, otherwise fetch it from server
raw, _, is_verified = yield self.db.get_transaction(hex_id) raw, _, position, is_verified = yield self.db.get_transaction(hex_id)
save_tx = None save_tx = None
if raw is None: if raw is None:
_raw = yield self.network.get_transaction(hex_id) _raw = yield self.network.get_transaction(hex_id)
tx = self.transaction_class(unhexlify(_raw)) tx = self.transaction_class(unhexlify(_raw), height=remote_height)
save_tx = 'insert' save_tx = 'insert'
else: else:
tx = self.transaction_class(raw) tx = self.transaction_class(raw, height=remote_height)
if remote_height > 0 and not is_verified: if remote_height > 0 and (not is_verified or position is None):
is_verified = yield self.is_valid_transaction(tx, remote_height) is_verified = yield self.validate_transaction_and_set_position(tx, remote_height)
is_verified = 1 if is_verified else 0 is_verified = 1 if is_verified else 0
if save_tx is None: if save_tx is None:
save_tx = 'update' save_tx = 'update'
yield self.db.save_transaction_io( yield self.db.save_transaction_io(
save_tx, tx, remote_height, is_verified, address, self.address_to_hash160(address), save_tx, tx, is_verified, address, self.address_to_hash160(address),
''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history) ''.join('{}:{}:'.format(tx_id, tx_height) for tx_id, tx_height in synced_history)
) )
@ -390,7 +391,7 @@ class BaseLedger(metaclass=LedgerRegistry):
self.get_id(), hex_id, address, remote_height, is_verified self.get_id(), hex_id, address, remote_height, is_verified
) )
self._on_transaction_controller.add(TransactionEvent(address, tx, remote_height, is_verified)) self._on_transaction_controller.add(TransactionEvent(address, tx, is_verified))
except Exception: except Exception:
log.exception('Failed to synchronize transaction:') log.exception('Failed to synchronize transaction:')

View file

@ -137,6 +137,13 @@ class BaseInput(InputOutput):
raise ValueError('Cannot resolve output to get amount.') raise ValueError('Cannot resolve output to get amount.')
return self.txo_ref.txo.amount return self.txo_ref.txo.amount
@property
def is_my_account(self) -> int:
""" True if the output this input spends is yours. """
if self.txo_ref.txo is None:
raise ValueError('Cannot resolve output to determine ownership.')
return self.txo_ref.txo.is_my_account
@classmethod @classmethod
def deserialize_from(cls, stream): def deserialize_from(cls, stream):
tx_ref = TXRefImmutable.from_hash(stream.read(32)) tx_ref = TXRefImmutable.from_hash(stream.read(32))
@ -181,14 +188,17 @@ class BaseOutput(InputOutput):
script_class = BaseOutputScript script_class = BaseOutputScript
estimator_class = BaseOutputEffectiveAmountEstimator estimator_class = BaseOutputEffectiveAmountEstimator
__slots__ = 'amount', 'script', 'is_change' __slots__ = 'amount', 'script', 'is_change', 'is_my_account'
def __init__(self, amount: int, script: BaseOutputScript, def __init__(self, amount: int, script: BaseOutputScript,
tx_ref: TXRef = None, position: int = None) -> None: tx_ref: TXRef = None, position: int = None,
is_change: Optional[bool] = None, is_my_account: Optional[bool] = None
) -> None:
super().__init__(tx_ref, position) super().__init__(tx_ref, position)
self.amount = amount self.amount = amount
self.script = script self.script = script
self.is_change = None self.is_change = is_change
self.is_my_account = is_my_account
@property @property
def ref(self): def ref(self):
@ -227,14 +237,16 @@ class BaseTransaction:
input_class = BaseInput input_class = BaseInput
output_class = BaseOutput output_class = BaseOutput
def __init__(self, raw=None, version=1, locktime=0, height=None) -> None: def __init__(self, raw=None, version: int=1, locktime: int=0,
height: int=-1, position: int=-1) -> None:
self._raw = raw self._raw = raw
self.ref = TXRefMutable(self) self.ref = TXRefMutable(self)
self.version = version # type: int self.version = version
self.locktime = locktime # type: int self.locktime = locktime
self._inputs = [] # type: List[BaseInput] self._inputs: List[BaseInput] = []
self._outputs = [] # type: List[BaseOutput] self._outputs: List[BaseOutput] = []
self.height = height self.height = height
self.position = position
if raw is not None: if raw is not None:
self._deserialize() self._deserialize()
@ -257,11 +269,11 @@ class BaseTransaction:
self.ref.reset() self.ref.reset()
@property @property
def inputs(self): # type: () -> ReadOnlyList[BaseInput] def inputs(self) -> ReadOnlyList[BaseInput]:
return ReadOnlyList(self._inputs) return ReadOnlyList(self._inputs)
@property @property
def outputs(self): # type: () -> ReadOnlyList[BaseOutput] def outputs(self) -> ReadOnlyList[BaseOutput]:
return ReadOnlyList(self._outputs) return ReadOnlyList(self._outputs)
def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction': def _add(self, new_ios: Iterable[InputOutput], existing_ios: List) -> 'BaseTransaction':
@ -301,18 +313,39 @@ class BaseTransaction:
return sum(o.amount for o in self.outputs) return sum(o.amount for o in self.outputs)
@property @property
def fee(self): def net_account_balance(self) -> int:
balance = 0
for txi in self.inputs:
if txi.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
elif txi.is_my_account:
balance -= txi.amount
for txo in self.outputs:
if txo.is_my_account is None:
raise ValueError(
"Cannot access net_account_balance if inputs/outputs do not "
"have is_my_account set properly."
)
elif txo.is_my_account:
balance += txo.amount
return balance
@property
def fee(self) -> int:
return self.input_sum - self.output_sum return self.input_sum - self.output_sum
def get_base_fee(self, ledger): def get_base_fee(self, ledger) -> int:
""" Fee for base tx excluding inputs and outputs. """ """ Fee for base tx excluding inputs and outputs. """
return self.base_size * ledger.fee_per_byte return self.base_size * ledger.fee_per_byte
def get_effective_input_sum(self, ledger): def get_effective_input_sum(self, ledger) -> int:
""" Sum of input values *minus* the cost involved to spend them. """ """ Sum of input values *minus* the cost involved to spend them. """
return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs) return sum(txi.amount - txi.get_fee(ledger) for txi in self._inputs)
def get_total_output_sum(self, ledger): def get_total_output_sum(self, ledger) -> int:
""" Sum of output values *plus* the cost involved to spend them. """ """ Sum of output values *plus* the cost involved to spend them. """
return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs) return sum(txo.amount + txo.get_fee(ledger) for txo in self._outputs)