mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 01:11:35 +00:00
storage_db: fix tests, add modified flag to db class
This commit is contained in:
parent
dbca0a0e83
commit
d74f0c0947
5 changed files with 60 additions and 37 deletions
|
@ -282,27 +282,20 @@ class AddressSynchronizer(PrintError):
|
|||
def remove_transaction(self, tx_hash):
|
||||
def remove_from_spent_outpoints():
|
||||
# undo spends in spent_outpoints
|
||||
if tx is not None: # if we have the tx, this branch is faster
|
||||
if tx is not None:
|
||||
# if we have the tx, this branch is faster
|
||||
for txin in tx.inputs():
|
||||
if txin['type'] == 'coinbase':
|
||||
continue
|
||||
prevout_hash = txin['prevout_hash']
|
||||
prevout_n = txin['prevout_n']
|
||||
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
|
||||
if not self.spent_outpoints[prevout_hash]:
|
||||
self.spent_outpoints.pop(prevout_hash)
|
||||
else: # expensive but always works
|
||||
for prevout_hash, d in list(self.spent_outpoints.items()):
|
||||
for prevout_n, spending_txid in d.items():
|
||||
if spending_txid == tx_hash:
|
||||
self.spent_outpoints[prevout_hash].pop(prevout_n, None)
|
||||
if not self.spent_outpoints[prevout_hash]:
|
||||
self.spent_outpoints.pop(prevout_hash)
|
||||
# Remove this tx itself; if nothing spends from it.
|
||||
# It is not so clear what to do if other txns spend from it, but it will be
|
||||
# removed when those other txns are removed.
|
||||
if not self.spent_outpoints[tx_hash]:
|
||||
self.spent_outpoints.pop(tx_hash)
|
||||
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
|
||||
else:
|
||||
# expensive but always works
|
||||
for prevout_hash, prevout_n in list(self.db.list_spent_outpoints()):
|
||||
spending_txid = self.db.get_spent_outpoint(prevout_hash, prevout_n)
|
||||
if spending_txid == tx_hash:
|
||||
self.db.remove_spent_outpoint(prevout_hash, prevout_n)
|
||||
|
||||
with self.transaction_lock:
|
||||
self.print_error("removing tx from history", tx_hash)
|
||||
|
|
|
@ -26,6 +26,7 @@ import os
|
|||
import ast
|
||||
import json
|
||||
import copy
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
|
@ -45,7 +46,9 @@ FINAL_SEED_VERSION = 18 # electrum >= 2.7 will set this to prevent
|
|||
class JsonDB(PrintError):
|
||||
|
||||
def __init__(self, raw, *, manual_upgrades):
|
||||
self.lock = threading.RLock()
|
||||
self.data = {}
|
||||
self._modified = False
|
||||
self.manual_upgrades = manual_upgrades
|
||||
if raw:
|
||||
self.load_data(raw)
|
||||
|
@ -53,6 +56,20 @@ class JsonDB(PrintError):
|
|||
self.put('seed_version', FINAL_SEED_VERSION)
|
||||
self.load_transactions()
|
||||
|
||||
def set_modified(self, b):
|
||||
with self.lock:
|
||||
self._modified = b
|
||||
|
||||
def modified(self):
|
||||
return self._modified
|
||||
|
||||
def modifier(func):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
with self.lock:
|
||||
self._modified = True
|
||||
return func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def get(self, key, default=None):
|
||||
v = self.data.get(key)
|
||||
if v is None:
|
||||
|
@ -61,6 +78,7 @@ class JsonDB(PrintError):
|
|||
v = copy.deepcopy(v)
|
||||
return v
|
||||
|
||||
@modifier
|
||||
def put(self, key, value):
|
||||
try:
|
||||
json.dumps(key, cls=util.MyEncoder)
|
||||
|
@ -483,6 +501,7 @@ class JsonDB(PrintError):
|
|||
def get_txo_addr(self, tx_hash, address):
|
||||
return self.txo.get(tx_hash, {}).get(address, [])
|
||||
|
||||
@modifier
|
||||
def add_txi_addr(self, tx_hash, addr, ser, v):
|
||||
if tx_hash not in self.txi:
|
||||
self.txi[tx_hash] = {}
|
||||
|
@ -492,6 +511,7 @@ class JsonDB(PrintError):
|
|||
d[addr] = set()
|
||||
d[addr].add((ser, v))
|
||||
|
||||
@modifier
|
||||
def add_txo_addr(self, tx_hash, addr, n, v, is_coinbase):
|
||||
if tx_hash not in self.txo:
|
||||
self.txo[tx_hash] = {}
|
||||
|
@ -507,26 +527,43 @@ class JsonDB(PrintError):
|
|||
def get_txo_keys(self):
|
||||
return self.txo.keys()
|
||||
|
||||
@modifier
|
||||
def remove_txi(self, tx_hash):
|
||||
self.txi.pop(tx_hash, None)
|
||||
|
||||
@modifier
|
||||
def remove_txo(self, tx_hash):
|
||||
self.txo.pop(tx_hash, None)
|
||||
|
||||
def list_spent_outpoints(self):
|
||||
return [(h, n)
|
||||
for h in self.spent_outpoints.keys()
|
||||
for n in self.get_spent_outpoints(h)
|
||||
]
|
||||
|
||||
def get_spent_outpoints(self, prevout_hash):
|
||||
return self.spent_outpoints.get(prevout_hash, {}).keys()
|
||||
|
||||
def get_spent_outpoint(self, prevout_hash, prevout_n):
|
||||
return self.spent_outpoints.get(prevout_hash, {}).get(str(prevout_n))
|
||||
|
||||
@modifier
|
||||
def remove_spent_outpoint(self, prevout_hash, prevout_n):
|
||||
self.spent_outpoints[prevout_hash].pop(prevout_n, None) # FIXME
|
||||
if not self.spent_outpoints[prevout_hash]:
|
||||
self.spent_outpoints.pop(prevout_hash)
|
||||
|
||||
@modifier
|
||||
def set_spent_outpoint(self, prevout_hash, prevout_n, tx_hash):
|
||||
if prevout_hash not in self.spent_outpoints:
|
||||
self.spent_outpoints[prevout_hash] = {}
|
||||
self.spent_outpoints[prevout_hash][str(prevout_n)] = tx_hash
|
||||
|
||||
@modifier
|
||||
def add_transaction(self, tx_hash, tx):
|
||||
self.transactions[tx_hash] = str(tx)
|
||||
|
||||
@modifier
|
||||
def remove_transaction(self, tx_hash):
|
||||
self.transactions.pop(tx_hash, None)
|
||||
|
||||
|
@ -543,9 +580,11 @@ class JsonDB(PrintError):
|
|||
def get_addr_history(self, addr):
|
||||
return self.history.get(addr, [])
|
||||
|
||||
@modifier
|
||||
def set_addr_history(self, addr, hist):
|
||||
self.history[addr] = hist
|
||||
|
||||
@modifier
|
||||
def remove_addr_history(self, addr):
|
||||
self.history.pop(addr, None)
|
||||
|
||||
|
@ -562,18 +601,22 @@ class JsonDB(PrintError):
|
|||
txpos=txpos,
|
||||
header_hash=header_hash)
|
||||
|
||||
@modifier
|
||||
def add_verified_tx(self, txid, info):
|
||||
self.verified_tx[txid] = (info.height, info.timestamp, info.txpos, info.header_hash)
|
||||
|
||||
@modifier
|
||||
def remove_verified_tx(self, txid):
|
||||
self.verified_tx.pop(txid, None)
|
||||
|
||||
@modifier
|
||||
def update_tx_fees(self, d):
|
||||
return self.tx_fees.update(d)
|
||||
|
||||
def get_tx_fee(self, txid):
|
||||
return self.tx_fees.get(txid)
|
||||
|
||||
@modifier
|
||||
def remove_tx_fee(self, txid):
|
||||
self.tx_fees.pop(txid, None)
|
||||
|
||||
|
|
|
@ -49,10 +49,9 @@ STO_EV_PLAINTEXT, STO_EV_USER_PW, STO_EV_XPUB_PW = range(0, 3)
|
|||
class WalletStorage(PrintError):
|
||||
|
||||
def __init__(self, path, *, manual_upgrades=False):
|
||||
self.db_lock = threading.RLock()
|
||||
self.lock = threading.RLock()
|
||||
self.path = standardize_path(path)
|
||||
self._file_exists = self.path and os.path.exists(self.path)
|
||||
self.modified = False
|
||||
|
||||
DB_Class = JsonDB
|
||||
self.path = path
|
||||
|
@ -70,23 +69,21 @@ class WalletStorage(PrintError):
|
|||
self.db = DB_Class('', manual_upgrades=False)
|
||||
|
||||
def put(self, key,value):
|
||||
with self.db_lock:
|
||||
self.modified |= self.db.put(key, value)
|
||||
self.db.put(key, value)
|
||||
|
||||
def get(self, key, default=None):
|
||||
with self.db_lock:
|
||||
return self.db.get(key, default)
|
||||
return self.db.get(key, default)
|
||||
|
||||
@profiler
|
||||
def write(self):
|
||||
with self.db_lock:
|
||||
with self.lock:
|
||||
self._write()
|
||||
|
||||
def _write(self):
|
||||
if threading.currentThread().isDaemon():
|
||||
self.print_error('warning: daemon thread cannot write db')
|
||||
return
|
||||
if not self.modified:
|
||||
if not self.db.modified():
|
||||
return
|
||||
self.db.commit()
|
||||
s = self.encrypt_before_writing(self.db.dump())
|
||||
|
@ -103,7 +100,7 @@ class WalletStorage(PrintError):
|
|||
os.chmod(self.path, mode)
|
||||
self._file_exists = True
|
||||
self.print_error("saved", self.path)
|
||||
self.modified = False
|
||||
self.db.set_modified(False)
|
||||
|
||||
def file_exists(self):
|
||||
return self._file_exists
|
||||
|
@ -209,8 +206,7 @@ class WalletStorage(PrintError):
|
|||
self.pubkey = None
|
||||
self._encryption_version = STO_EV_PLAINTEXT
|
||||
# make sure next storage.write() saves changes
|
||||
with self.db_lock:
|
||||
self.modified = True
|
||||
self.db.set_modified(True)
|
||||
|
||||
def requires_upgrade(self):
|
||||
return self.db.requires_upgrade()
|
||||
|
|
|
@ -1719,7 +1719,7 @@ class TestWalletHistory_EvilGapLimit(TestCaseForTestnet):
|
|||
w.storage.put('stored_height', 1316917 + 100)
|
||||
for txid in self.transactions:
|
||||
tx = Transaction(self.transactions[txid])
|
||||
w.transactions[tx.txid()] = tx
|
||||
w.add_transaction(tx.txid(), tx)
|
||||
# txn A is an external incoming txn paying to addr (3) and (15)
|
||||
# txn B is an external incoming txn paying to addr (4) and (25)
|
||||
# txn C is an internal transfer txn from addr (25) -- to -- (1) and (25)
|
||||
|
|
|
@ -1201,7 +1201,6 @@ class Abstract_Wallet(AddressSynchronizer):
|
|||
self._update_password_for_keystore(old_pw, new_pw)
|
||||
encrypt_keystore = self.can_have_keystore_encryption()
|
||||
self.storage.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
|
||||
|
||||
self.storage.write()
|
||||
|
||||
def sign_message(self, address, message, password):
|
||||
|
@ -1385,7 +1384,6 @@ class Imported_Wallet(Simple_Wallet):
|
|||
self.addresses[address] = {}
|
||||
self.add_address(address)
|
||||
self.save_addresses()
|
||||
self.save_transactions(write=write_to_disk)
|
||||
return good_addr, bad_addr
|
||||
|
||||
def import_address(self, address: str) -> str:
|
||||
|
@ -1398,7 +1396,6 @@ class Imported_Wallet(Simple_Wallet):
|
|||
def delete_address(self, address):
|
||||
if address not in self.addresses:
|
||||
return
|
||||
|
||||
transactions_to_remove = set() # only referred to by this address
|
||||
transactions_new = set() # txs that are not only referred to by address
|
||||
with self.lock:
|
||||
|
@ -1412,20 +1409,15 @@ class Imported_Wallet(Simple_Wallet):
|
|||
transactions_new.add(tx_hash)
|
||||
transactions_to_remove -= transactions_new
|
||||
self.db.remove_history(address)
|
||||
|
||||
for tx_hash in transactions_to_remove:
|
||||
self.remove_transaction(tx_hash)
|
||||
self.db.remove_tx_fee(tx_hash)
|
||||
self.db.remove_verified_tx(tx_hash)
|
||||
self.unverified_tx.pop(tx_hash, None)
|
||||
self.db.remove_transaction(tx_hash)
|
||||
self.save_verified_tx()
|
||||
self.save_transactions()
|
||||
|
||||
self.set_label(address, None)
|
||||
self.remove_payment_request(address, {})
|
||||
self.set_frozen_state([address], False)
|
||||
|
||||
pubkey = self.get_public_key(address)
|
||||
self.addresses.pop(address)
|
||||
if pubkey:
|
||||
|
@ -1442,7 +1434,6 @@ class Imported_Wallet(Simple_Wallet):
|
|||
self.keystore.delete_imported_key(pubkey)
|
||||
self.save_keystore()
|
||||
self.save_addresses()
|
||||
|
||||
self.storage.write()
|
||||
|
||||
def get_address_index(self, address):
|
||||
|
|
Loading…
Add table
Reference in a new issue