mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-04 21:05:11 +00:00
storage: factor out 'JsonDB'
This commit is contained in:
parent
d2abaf54e8
commit
53130da682
1 changed files with 86 additions and 71 deletions
|
@ -67,15 +67,84 @@ def get_derivation_used_for_hw_device_encryption():
|
|||
# storage encryption version
|
||||
STO_EV_PLAINTEXT, STO_EV_USER_PW, STO_EV_XPUB_PW = range(0, 3)
|
||||
|
||||
class WalletStorage(PrintError):
|
||||
|
||||
def __init__(self, path, manual_upgrades=False):
|
||||
self.print_error("wallet path", path)
|
||||
self.manual_upgrades = manual_upgrades
|
||||
self.lock = threading.RLock()
|
||||
class JsonDB(PrintError):
|
||||
|
||||
def __init__(self, path):
|
||||
self.db_lock = threading.RLock()
|
||||
self.data = {}
|
||||
self.path = path
|
||||
self.modified = False
|
||||
|
||||
def get(self, key, default=None):
|
||||
with self.db_lock:
|
||||
v = self.data.get(key)
|
||||
if v is None:
|
||||
v = default
|
||||
else:
|
||||
v = copy.deepcopy(v)
|
||||
return v
|
||||
|
||||
def put(self, key, value):
|
||||
try:
|
||||
json.dumps(key, cls=util.MyEncoder)
|
||||
json.dumps(value, cls=util.MyEncoder)
|
||||
except:
|
||||
self.print_error("json error: cannot save", key)
|
||||
return
|
||||
with self.db_lock:
|
||||
if value is not None:
|
||||
if self.data.get(key) != value:
|
||||
self.modified = True
|
||||
self.data[key] = copy.deepcopy(value)
|
||||
elif key in self.data:
|
||||
self.modified = True
|
||||
self.data.pop(key)
|
||||
|
||||
@profiler
|
||||
def write(self):
|
||||
with self.db_lock:
|
||||
self._write()
|
||||
|
||||
def _write(self):
|
||||
if threading.currentThread().isDaemon():
|
||||
self.print_error('warning: daemon thread cannot write db')
|
||||
return
|
||||
if not self.modified:
|
||||
return
|
||||
s = json.dumps(self.data, indent=4, sort_keys=True, cls=util.MyEncoder)
|
||||
s = self.encrypt_before_writing(s)
|
||||
|
||||
temp_path = "%s.tmp.%s" % (self.path, os.getpid())
|
||||
with open(temp_path, "w", encoding='utf-8') as f:
|
||||
f.write(s)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
mode = os.stat(self.path).st_mode if os.path.exists(self.path) else stat.S_IREAD | stat.S_IWRITE
|
||||
# perform atomic write on POSIX systems
|
||||
try:
|
||||
os.rename(temp_path, self.path)
|
||||
except:
|
||||
os.remove(self.path)
|
||||
os.rename(temp_path, self.path)
|
||||
os.chmod(self.path, mode)
|
||||
self.print_error("saved", self.path)
|
||||
self.modified = False
|
||||
|
||||
def encrypt_before_writing(self, plaintext: str) -> str:
|
||||
return plaintext
|
||||
|
||||
def file_exists(self):
|
||||
return self.path and os.path.exists(self.path)
|
||||
|
||||
|
||||
class WalletStorage(JsonDB):
|
||||
|
||||
def __init__(self, path, manual_upgrades=False):
|
||||
self.print_error("wallet path", path)
|
||||
JsonDB.__init__(self, path)
|
||||
self.manual_upgrades = manual_upgrades
|
||||
self.pubkey = None
|
||||
if self.file_exists():
|
||||
with open(self.path, "r", encoding='utf-8') as f:
|
||||
|
@ -160,9 +229,6 @@ class WalletStorage(PrintError):
|
|||
except:
|
||||
return STO_EV_PLAINTEXT
|
||||
|
||||
def file_exists(self):
|
||||
return self.path and os.path.exists(self.path)
|
||||
|
||||
@staticmethod
|
||||
def get_eckey_from_password(password):
|
||||
secret = hashlib.pbkdf2_hmac('sha512', password.encode('utf-8'), b'', iterations=1024)
|
||||
|
@ -189,6 +255,17 @@ class WalletStorage(PrintError):
|
|||
s = s.decode('utf8')
|
||||
self.load_data(s)
|
||||
|
||||
def encrypt_before_writing(self, plaintext: str) -> str:
|
||||
s = plaintext
|
||||
if self.pubkey:
|
||||
s = bytes(s, 'utf8')
|
||||
c = zlib.compress(s)
|
||||
enc_magic = self._get_encryption_magic()
|
||||
public_key = ecc.ECPubkey(bfh(self.pubkey))
|
||||
s = public_key.encrypt_message(c, enc_magic)
|
||||
s = s.decode('utf8')
|
||||
return s
|
||||
|
||||
def check_password(self, password):
|
||||
"""Raises an InvalidPassword exception on invalid password"""
|
||||
if not self.is_encrypted():
|
||||
|
@ -211,71 +288,9 @@ class WalletStorage(PrintError):
|
|||
self.pubkey = None
|
||||
self._encryption_version = STO_EV_PLAINTEXT
|
||||
# make sure next storage.write() saves changes
|
||||
with self.lock:
|
||||
with self.db_lock:
|
||||
self.modified = True
|
||||
|
||||
def get(self, key, default=None):
|
||||
with self.lock:
|
||||
v = self.data.get(key)
|
||||
if v is None:
|
||||
v = default
|
||||
else:
|
||||
v = copy.deepcopy(v)
|
||||
return v
|
||||
|
||||
def put(self, key, value):
|
||||
try:
|
||||
json.dumps(key, cls=util.MyEncoder)
|
||||
json.dumps(value, cls=util.MyEncoder)
|
||||
except:
|
||||
self.print_error("json error: cannot save", key)
|
||||
return
|
||||
with self.lock:
|
||||
if value is not None:
|
||||
if self.data.get(key) != value:
|
||||
self.modified = True
|
||||
self.data[key] = copy.deepcopy(value)
|
||||
elif key in self.data:
|
||||
self.modified = True
|
||||
self.data.pop(key)
|
||||
|
||||
@profiler
|
||||
def write(self):
|
||||
with self.lock:
|
||||
self._write()
|
||||
|
||||
def _write(self):
|
||||
if threading.currentThread().isDaemon():
|
||||
self.print_error('warning: daemon thread cannot write wallet')
|
||||
return
|
||||
if not self.modified:
|
||||
return
|
||||
s = json.dumps(self.data, indent=4, sort_keys=True, cls=util.MyEncoder)
|
||||
if self.pubkey:
|
||||
s = bytes(s, 'utf8')
|
||||
c = zlib.compress(s)
|
||||
enc_magic = self._get_encryption_magic()
|
||||
public_key = ecc.ECPubkey(bfh(self.pubkey))
|
||||
s = public_key.encrypt_message(c, enc_magic)
|
||||
s = s.decode('utf8')
|
||||
|
||||
temp_path = "%s.tmp.%s" % (self.path, os.getpid())
|
||||
with open(temp_path, "w", encoding='utf-8') as f:
|
||||
f.write(s)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
|
||||
mode = os.stat(self.path).st_mode if os.path.exists(self.path) else stat.S_IREAD | stat.S_IWRITE
|
||||
# perform atomic write on POSIX systems
|
||||
try:
|
||||
os.rename(temp_path, self.path)
|
||||
except:
|
||||
os.remove(self.path)
|
||||
os.rename(temp_path, self.path)
|
||||
os.chmod(self.path, mode)
|
||||
self.print_error("saved", self.path)
|
||||
self.modified = False
|
||||
|
||||
def requires_split(self):
|
||||
d = self.get('accounts', {})
|
||||
return len(d) > 1
|
||||
|
|
Loading…
Add table
Reference in a new issue