storage_db: fix tests, add modified flag to db class

This commit is contained in:
ThomasV 2019-02-28 11:55:15 +01:00
parent dbca0a0e83
commit d74f0c0947
5 changed files with 60 additions and 37 deletions

View file

@ -282,27 +282,20 @@ class AddressSynchronizer(PrintError):
def remove_transaction(self, tx_hash): def remove_transaction(self, tx_hash):
def remove_from_spent_outpoints(): def remove_from_spent_outpoints():
# undo spends in spent_outpoints # undo spends in spent_outpoints
if tx is not None: # if we have the tx, this branch is faster if tx is not None:
# if we have the tx, this branch is faster
for txin in tx.inputs(): for txin in tx.inputs():
if txin['type'] == 'coinbase': if txin['type'] == 'coinbase':
continue continue
prevout_hash = txin['prevout_hash'] prevout_hash = txin['prevout_hash']
prevout_n = txin['prevout_n'] prevout_n = txin['prevout_n']
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME self.db.remove_spent_outpoint(prevout_hash, prevout_n)
if not self.spent_outpoints[prevout_hash]: else:
self.spent_outpoints.pop(prevout_hash) # expensive but always works
else: # expensive but always works for prevout_hash, prevout_n in list(self.db.list_spent_outpoints()):
for prevout_hash, d in list(self.spent_outpoints.items()): spending_txid = self.db.get_spent_outpoint(prevout_hash, prevout_n)
for prevout_n, spending_txid in d.items(): if spending_txid == tx_hash:
if spending_txid == tx_hash: self.db.remove_spent_outpoint(prevout_hash, prevout_n)
self.spent_outpoints[prevout_hash].pop(prevout_n, None)
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
# Remove this tx itself; if nothing spends from it.
# It is not so clear what to do if other txns spend from it, but it will be
# removed when those other txns are removed.
if not self.spent_outpoints[tx_hash]:
self.spent_outpoints.pop(tx_hash)
with self.transaction_lock: with self.transaction_lock:
self.print_error("removing tx from history", tx_hash) self.print_error("removing tx from history", tx_hash)

View file

@ -26,6 +26,7 @@ import os
import ast import ast
import json import json
import copy import copy
import threading
from collections import defaultdict from collections import defaultdict
from typing import Dict from typing import Dict
@ -45,7 +46,9 @@ FINAL_SEED_VERSION = 18 # electrum >= 2.7 will set this to prevent
class JsonDB(PrintError): class JsonDB(PrintError):
def __init__(self, raw, *, manual_upgrades): def __init__(self, raw, *, manual_upgrades):
self.lock = threading.RLock()
self.data = {} self.data = {}
self._modified = False
self.manual_upgrades = manual_upgrades self.manual_upgrades = manual_upgrades
if raw: if raw:
self.load_data(raw) self.load_data(raw)
@ -53,6 +56,20 @@ class JsonDB(PrintError):
self.put('seed_version', FINAL_SEED_VERSION) self.put('seed_version', FINAL_SEED_VERSION)
self.load_transactions() self.load_transactions()
def set_modified(self, b):
with self.lock:
self._modified = b
def modified(self):
return self._modified
def modifier(func):
def wrapper(self, *args, **kwargs):
with self.lock:
self._modified = True
return func(self, *args, **kwargs)
return wrapper
def get(self, key, default=None): def get(self, key, default=None):
v = self.data.get(key) v = self.data.get(key)
if v is None: if v is None:
@ -61,6 +78,7 @@ class JsonDB(PrintError):
v = copy.deepcopy(v) v = copy.deepcopy(v)
return v return v
@modifier
def put(self, key, value): def put(self, key, value):
try: try:
json.dumps(key, cls=util.MyEncoder) json.dumps(key, cls=util.MyEncoder)
@ -483,6 +501,7 @@ class JsonDB(PrintError):
def get_txo_addr(self, tx_hash, address): def get_txo_addr(self, tx_hash, address):
return self.txo.get(tx_hash, {}).get(address, []) return self.txo.get(tx_hash, {}).get(address, [])
@modifier
def add_txi_addr(self, tx_hash, addr, ser, v): def add_txi_addr(self, tx_hash, addr, ser, v):
if tx_hash not in self.txi: if tx_hash not in self.txi:
self.txi[tx_hash] = {} self.txi[tx_hash] = {}
@ -492,6 +511,7 @@ class JsonDB(PrintError):
d[addr] = set() d[addr] = set()
d[addr].add((ser, v)) d[addr].add((ser, v))
@modifier
def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase): def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase):
if tx_hash not in self.txo: if tx_hash not in self.txo:
self.txo[tx_hash] = {} self.txo[tx_hash] = {}
@ -507,26 +527,43 @@ class JsonDB(PrintError):
def get_txo_keys(self): def get_txo_keys(self):
return self.txo.keys() return self.txo.keys()
@modifier
def remove_txi(self, tx_hash): def remove_txi(self, tx_hash):
self.txi.pop(tx_hash, None) self.txi.pop(tx_hash, None)
@modifier
def remove_txo(self, tx_hash): def remove_txo(self, tx_hash):
self.txo.pop(tx_hash, None) self.txo.pop(tx_hash, None)
def list_spent_outpoints(self):
return [(h, n)
for h in self.spent_outpoints.keys()
for n in self.get_spent_outpoints(h)
]
def get_spent_outpoints(self, prevout_hash): def get_spent_outpoints(self, prevout_hash):
return self.spent_outpoints.get(prevout_hash, {}).keys() return self.spent_outpoints.get(prevout_hash, {}).keys()
def get_spent_outpoint(self, prevout_hash, prevout_n): def get_spent_outpoint(self, prevout_hash, prevout_n):
return self.spent_outpoints.get(prevout_hash, {}).get(str(prevout_n)) return self.spent_outpoints.get(prevout_hash, {}).get(str(prevout_n))
@modifier
def remove_spent_outpoint(self, prevout_hash, prevout_n):
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
if not self.spent_outpoints[prevout_hash]:
self.spent_outpoints.pop(prevout_hash)
@modifier
def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash): def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash):
if prevout_hash not in self.spent_outpoints: if prevout_hash not in self.spent_outpoints:
self.spent_outpoints[prevout_hash] = {} self.spent_outpoints[prevout_hash] = {}
self.spent_outpoints[prevout_hash][str(prevout_n)] = tx_hash self.spent_outpoints[prevout_hash][str(prevout_n)] = tx_hash
@modifier
def add_transaction(self, tx_hash, tx): def add_transaction(self, tx_hash, tx):
self.transactions[tx_hash] = str(tx) self.transactions[tx_hash] = str(tx)
@modifier
def remove_transaction(self, tx_hash): def remove_transaction(self, tx_hash):
self.transactions.pop(tx_hash, None) self.transactions.pop(tx_hash, None)
@ -543,9 +580,11 @@ class JsonDB(PrintError):
def get_addr_history(self, addr): def get_addr_history(self, addr):
return self.history.get(addr, []) return self.history.get(addr, [])
@modifier
def set_addr_history(self, addr, hist): def set_addr_history(self, addr, hist):
self.history[addr] = hist self.history[addr] = hist
@modifier
def remove_addr_history(self, addr): def remove_addr_history(self, addr):
self.history.pop(addr, None) self.history.pop(addr, None)
@ -562,18 +601,22 @@ class JsonDB(PrintError):
txpos=txpos, txpos=txpos,
header_hash=header_hash) header_hash=header_hash)
@modifier
def add_verified_tx(self, txid, info): def add_verified_tx(self, txid, info):
self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash) self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash)
@modifier
def remove_verified_tx(self, txid): def remove_verified_tx(self, txid):
self.verified_tx.pop(txid, None) self.verified_tx.pop(txid, None)
@modifier
def update_tx_fees(self, d): def update_tx_fees(self, d):
return self.tx_fees.update(d) return self.tx_fees.update(d)
def get_tx_fee(self, txid): def get_tx_fee(self, txid):
return self.tx_fees.get(txid) return self.tx_fees.get(txid)
@modifier
def remove_tx_fee(self, txid): def remove_tx_fee(self, txid):
self.tx_fees.pop(txid, None) self.tx_fees.pop(txid, None)

View file

@ -49,10 +49,9 @@ STO_EV_PLAINTEXT, STO_EV_USER_PW, STO_EV_XPUB_PW = range(0, 3)
class WalletStorage(PrintError): class WalletStorage(PrintError):
def __init__(self, path, *, manual_upgrades=False): def __init__(self, path, *, manual_upgrades=False):
self.db_lock = threading.RLock() self.lock = threading.RLock()
self.path = standardize_path(path) self.path = standardize_path(path)
self._file_exists = self.path and os.path.exists(self.path) self._file_exists = self.path and os.path.exists(self.path)
self.modified = False
DB_Class = JsonDB DB_Class = JsonDB
self.path = path self.path = path
@ -70,23 +69,21 @@ class WalletStorage(PrintError):
self.db = DB_Class('', manual_upgrades=False) self.db = DB_Class('', manual_upgrades=False)
def put(self, key,value): def put(self, key,value):
with self.db_lock: self.db.put(key, value)
self.modified |= self.db.put(key, value)
def get(self, key, default=None): def get(self, key, default=None):
with self.db_lock: return self.db.get(key, default)
return self.db.get(key, default)
@profiler @profiler
def write(self): def write(self):
with self.db_lock: with self.lock:
self._write() self._write()
def _write(self): def _write(self):
if threading.currentThread().isDaemon(): if threading.currentThread().isDaemon():
self.print_error('warning: daemon thread cannot write db') self.print_error('warning: daemon thread cannot write db')
return return
if not self.modified: if not self.db.modified():
return return
self.db.commit() self.db.commit()
s = self.encrypt_before_writing(self.db.dump()) s = self.encrypt_before_writing(self.db.dump())
@ -103,7 +100,7 @@ class WalletStorage(PrintError):
os.chmod(self.path, mode) os.chmod(self.path, mode)
self._file_exists = True self._file_exists = True
self.print_error("saved", self.path) self.print_error("saved", self.path)
self.modified = False self.db.set_modified(False)
def file_exists(self): def file_exists(self):
return self._file_exists return self._file_exists
@ -209,8 +206,7 @@ class WalletStorage(PrintError):
self.pubkey = None self.pubkey = None
self._encryption_version = STO_EV_PLAINTEXT self._encryption_version = STO_EV_PLAINTEXT
# make sure next storage.write() saves changes # make sure next storage.write() saves changes
with self.db_lock: self.db.set_modified(True)
self.modified = True
def requires_upgrade(self): def requires_upgrade(self):
return self.db.requires_upgrade() return self.db.requires_upgrade()

View file

@ -1719,7 +1719,7 @@ class TestWalletHistory_EvilGapLimit(TestCaseForTestnet):
w.storage.put('stored_height', 1316917 + 100) w.storage.put('stored_height', 1316917 + 100)
for txid in self.transactions: for txid in self.transactions:
tx = Transaction(self.transactions[txid]) tx = Transaction(self.transactions[txid])
w.transactions[tx.txid()] = tx w.add_transaction(tx.txid(), tx)
# txn A is an external incoming txn paying to addr (3) and (15) # txn A is an external incoming txn paying to addr (3) and (15)
# txn B is an external incoming txn paying to addr (4) and (25) # txn B is an external incoming txn paying to addr (4) and (25)
# txn C is an internal transfer txn from addr (25) -- to -- (1) and (25) # txn C is an internal transfer txn from addr (25) -- to -- (1) and (25)

View file

@ -1201,7 +1201,6 @@ class Abstract_Wallet(AddressSynchronizer):
self._update_password_for_keystore(old_pw, new_pw) self._update_password_for_keystore(old_pw, new_pw)
encrypt_keystore = self.can_have_keystore_encryption() encrypt_keystore = self.can_have_keystore_encryption()
self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore) self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
self.storage.write() self.storage.write()
def sign_message(self, address, message, password): def sign_message(self, address, message, password):
@ -1385,7 +1384,6 @@ class Imported_Wallet(Simple_Wallet):
self.addresses[address] = {} self.addresses[address] = {}
self.add_address(address) self.add_address(address)
self.save_addresses() self.save_addresses()
self.save_transactions(write=write_to_disk)
return good_addr, bad_addr return good_addr, bad_addr
def import_address(self, address: str) -> str: def import_address(self, address: str) -> str:
@ -1398,7 +1396,6 @@ class Imported_Wallet(Simple_Wallet):
def delete_address(self, address): def delete_address(self, address):
if address not in self.addresses: if address not in self.addresses:
return return
transactions_to_remove = set() # only referred to by this address transactions_to_remove = set() # only referred to by this address
transactions_new = set() # txs that are not only referred to by address transactions_new = set() # txs that are not only referred to by address
with self.lock: with self.lock:
@ -1412,20 +1409,15 @@ class Imported_Wallet(Simple_Wallet):
transactions_new.add(tx_hash) transactions_new.add(tx_hash)
transactions_to_remove -= transactions_new transactions_to_remove -= transactions_new
self.db.remove_history(address) self.db.remove_history(address)
for tx_hash in transactions_to_remove: for tx_hash in transactions_to_remove:
self.remove_transaction(tx_hash) self.remove_transaction(tx_hash)
self.db.remove_tx_fee(tx_hash) self.db.remove_tx_fee(tx_hash)
self.db.remove_verified_tx(tx_hash) self.db.remove_verified_tx(tx_hash)
self.unverified_tx.pop(tx_hash, None) self.unverified_tx.pop(tx_hash, None)
self.db.remove_transaction(tx_hash) self.db.remove_transaction(tx_hash)
self.save_verified_tx()
self.save_transactions()
self.set_label(address, None) self.set_label(address, None)
self.remove_payment_request(address, {}) self.remove_payment_request(address, {})
self.set_frozen_state([address], False) self.set_frozen_state([address], False)
pubkey = self.get_public_key(address) pubkey = self.get_public_key(address)
self.addresses.pop(address) self.addresses.pop(address)
if pubkey: if pubkey:
@ -1442,7 +1434,6 @@ class Imported_Wallet(Simple_Wallet):
self.keystore.delete_imported_key(pubkey) self.keystore.delete_imported_key(pubkey)
self.save_keystore() self.save_keystore()
self.save_addresses() self.save_addresses()
self.storage.write() self.storage.write()
def get_address_index(self, address): def get_address_index(self, address):