wallet: make labels private, and access to need lock

e.g. labels plugin iterated over wallet.labels on asyncio thread while user could trigger an edit from Qt thread
This commit is contained in:
SomberNight 2020-10-13 18:30:24 +02:00
parent da4f11dbd3
commit 4b6c86ecbe
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
13 changed files with 71 additions and 54 deletions

View file

@ -723,7 +723,7 @@ class Commands:
if balance:
item += (format_satoshis(sum(wallet.get_addr_balance(addr))),)
if labels:
item += (repr(wallet.labels.get(addr, '')),)
item += (repr(wallet.get_label(addr)),)
out.append(item)
return out

View file

@ -240,7 +240,7 @@ class AddressesDialog(Factory.Popup):
n = 0
cards = []
for address in _list:
label = wallet.labels.get(address, '')
label = wallet.get_label(address)
balance = sum(wallet.get_addr_balance(address))
is_used_and_empty = wallet.is_used(address) and balance == 0
if self.show_used == 1 and (balance or is_used_and_empty):

View file

@ -297,7 +297,7 @@ class TxDialog(Factory.Popup):
def label_dialog(self):
from .label_dialog import LabelDialog
key = self.tx.txid()
text = self.app.wallet.get_label(key)
text = self.app.wallet.get_label_for_txid(key)
def callback(text):
self.app.wallet.set_label(key, text)
self.update()

View file

@ -160,7 +160,7 @@ class AddressList(MyTreeView):
addresses_beyond_gap_limit = self.wallet.get_all_known_addresses_beyond_gap_limit()
for address in addr_list:
num = self.wallet.get_address_history_len(address)
label = self.wallet.labels.get(address, '')
label = self.wallet.get_label(address)
c, u, x = self.wallet.get_addr_balance(address)
balance = c + u + x
is_used_and_empty = self.wallet.is_used(address) and balance == 0

View file

@ -237,7 +237,7 @@ class HistoryModel(CustomModel, Logger):
def update_label(self, index):
tx_item = index.internalPointer().get_data()
tx_item['label'] = self.parent.wallet.get_label(get_item_key(tx_item))
tx_item['label'] = self.parent.wallet.get_label_for_txid(get_item_key(tx_item))
topLeft = bottomRight = self.createIndex(index.row(), HistoryColumns.DESCRIPTION)
self.dataChanged.emit(topLeft, bottomRight, [Qt.DisplayRole])
self.parent.utxo_list.update()
@ -633,7 +633,7 @@ class HistoryList(MyTreeView, AcceptFileDragDrop):
def show_transaction(self, tx_item, tx):
tx_hash = tx_item['txid']
label = self.wallet.get_label(tx_hash) or None # prefer 'None' if not defined (force tx dialog to hide Description field if missing)
label = self.wallet.get_label_for_txid(tx_hash) or None # prefer 'None' if not defined (force tx dialog to hide Description field if missing)
self.parent.show_transaction(tx, tx_desc=label)
def add_copy_menu(self, menu, idx):

View file

@ -3188,7 +3188,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger):
if fee is None:
self.show_error(_("Can't bump fee: unknown fee for original transaction."))
return
tx_label = self.wallet.get_label(txid)
tx_label = self.wallet.get_label_for_txid(txid)
tx_size = tx.estimated_size()
old_fee_rate = fee / tx_size # sat/vbyte
d = WindowModalDialog(self, _('Bump Fee'))

View file

@ -98,7 +98,7 @@ class UTXOList(MyTreeView):
name = utxo.prevout.to_str()
name_short = utxo.prevout.txid.hex()[:16] + '...' + ":%d" % utxo.prevout.out_idx
self._utxo_dict[name] = utxo
label = self.wallet.get_label(utxo.prevout.txid.hex())
label = self.wallet.get_label_for_txid(utxo.prevout.txid.hex())
amount = self.parent.format_amount(utxo.value_sats(), whitespaces=True)
labels = [name_short, address, label, amount, '%d'%height]
utxo_item = [QStandardItem(x) for x in labels]
@ -178,7 +178,7 @@ class UTXOList(MyTreeView):
# "Details"
tx = self.wallet.db.get_transaction(txid)
if tx:
label = self.wallet.get_label(txid) or None # Prefer None if empty (None hides the Description: field in the window)
label = self.wallet.get_label_for_txid(txid) or None # Prefer None if empty (None hides the Description: field in the window)
menu.addAction(_("Details"), lambda: self.parent.show_transaction(tx, tx_desc=label))
# "Copy ..."
idx = self.indexAt(position)

View file

@ -108,7 +108,7 @@ class ElectrumGui:
else:
time_str = 'unconfirmed'
label = self.wallet.get_label(hist_item.txid)
label = self.wallet.get_label_for_txid(hist_item.txid)
messages.append(format_str % (time_str, label, format_satoshis(delta, whitespaces=True),
format_satoshis(hist_item.balance, whitespaces=True)))
@ -140,7 +140,7 @@ class ElectrumGui:
self.print_list(messages, "%19s %25s "%("Key", "Value"))
def print_addresses(self):
messages = map(lambda addr: "%30s %30s "%(addr, self.wallet.labels.get(addr,"")), self.wallet.get_addresses())
messages = map(lambda addr: "%30s %30s "%(addr, self.wallet.get_label(addr)), self.wallet.get_addresses())
self.print_list(messages, "%19s %25s "%("Address", "Label"))
def print_order(self):
@ -209,7 +209,7 @@ class ElectrumGui:
return
if self.str_description:
self.wallet.labels[tx.txid()] = self.str_description
self.wallet.set_label(tx.txid(), self.str_description)
print(_("Please wait..."))
try:

View file

@ -136,7 +136,7 @@ class ElectrumGui:
else:
time_str = 'unconfirmed'
label = self.wallet.get_label(hist_item.txid)
label = self.wallet.get_label_for_txid(hist_item.txid)
if len(label) > 40:
label = label[0:37] + '...'
self.history.append(format_str % (time_str, label, format_satoshis(hist_item.delta, whitespaces=True),
@ -177,7 +177,7 @@ class ElectrumGui:
def print_addresses(self):
fmt = "%-35s %-30s"
messages = map(lambda addr: fmt % (addr, self.wallet.labels.get(addr,"")), self.wallet.get_addresses())
messages = map(lambda addr: fmt % (addr, self.wallet.get_label(addr)), self.wallet.get_addresses())
self.print_list(messages, fmt % ("Address", "Label"))
def print_edit_line(self, y, label, text, index, size):
@ -314,7 +314,7 @@ class ElectrumGui:
elif out == "Edit label":
s = self.get_string(6 + self.pos, 18)
if s:
self.wallet.labels[key] = s
self.wallet.set_label(key, s)
def run_banner_tab(self, c):
self.show_message(repr(c))
@ -379,7 +379,7 @@ class ElectrumGui:
return
if self.str_description:
self.wallet.labels[tx.txid()] = self.str_description
self.wallet.set_label(tx.txid(), self.str_description)
self.show_message(_("Please wait..."), getchar=False)
try:

View file

@ -678,7 +678,7 @@ class LNWallet(LNWorker):
item = {
'channel_id': bh2u(chan.channel_id),
'type': 'channel_opening',
'label': self.wallet.get_label(funding_txid) or (_('Open channel') + ' ' + chan.get_id_for_log()),
'label': self.wallet.get_label_for_txid(funding_txid) or (_('Open channel') + ' ' + chan.get_id_for_log()),
'txid': funding_txid,
'amount_msat': chan.balance(LOCAL, ctn=0),
'direction': 'received',
@ -693,7 +693,7 @@ class LNWallet(LNWorker):
item = {
'channel_id': bh2u(chan.channel_id),
'txid': closing_txid,
'label': self.wallet.get_label(closing_txid) or (_('Close channel') + ' ' + chan.get_id_for_log()),
'label': self.wallet.get_label_for_txid(closing_txid) or (_('Close channel') + ' ' + chan.get_id_for_log()),
'type': 'channel_closure',
'amount_msat': -chan.balance_minus_outgoing_htlcs(LOCAL),
'direction': 'sent',
@ -724,7 +724,7 @@ class LNWallet(LNWorker):
'amount_msat': 0,
#'amount_msat': amount_msat, # must not be added
'type': 'swap',
'label': self.wallet.get_label(txid) or label,
'label': self.wallet.get_label_for_txid(txid) or label,
}
return out

View file

@ -36,18 +36,18 @@ class LabelsPlugin(BasePlugin):
self.target_host = 'labels.electrum.org'
self.wallets = {}
def encode(self, wallet, msg):
def encode(self, wallet: 'Abstract_Wallet', msg: str) -> str:
password, iv, wallet_id = self.wallets[wallet]
encrypted = aes_encrypt_with_iv(password, iv, msg.encode('utf8'))
return base64.b64encode(encrypted).decode()
def decode(self, wallet, message):
def decode(self, wallet: 'Abstract_Wallet', message: str) -> str:
password, iv, wallet_id = self.wallets[wallet]
decoded = base64.b64decode(message)
decrypted = aes_decrypt_with_iv(password, iv, decoded)
return decrypted.decode('utf8')
def get_nonce(self, wallet):
def get_nonce(self, wallet: 'Abstract_Wallet'):
# nonce is the nonce to be used with the next change
nonce = wallet.db.get('wallet_nonce')
if nonce is None:
@ -55,12 +55,12 @@ class LabelsPlugin(BasePlugin):
self.set_nonce(wallet, nonce)
return nonce
def set_nonce(self, wallet, nonce):
def set_nonce(self, wallet: 'Abstract_Wallet', nonce):
self.logger.info(f"set {wallet.basename()} nonce to {nonce}")
wallet.db.put("wallet_nonce", nonce)
@hook
def set_label(self, wallet, item, label):
def set_label(self, wallet: 'Abstract_Wallet', item, label):
if wallet not in self.wallets:
return
if not item:
@ -99,7 +99,7 @@ class LabelsPlugin(BasePlugin):
except Exception as e:
raise Exception('Could not decode: ' + await result.text()) from e
async def push_thread(self, wallet):
async def push_thread(self, wallet: 'Abstract_Wallet'):
wallet_data = self.wallets.get(wallet, None)
if not wallet_data:
raise Exception('Wallet {} not loaded'.format(wallet))
@ -107,7 +107,7 @@ class LabelsPlugin(BasePlugin):
bundle = {"labels": [],
"walletId": wallet_id,
"walletNonce": self.get_nonce(wallet)}
for key, value in wallet.labels.items():
for key, value in wallet.get_all_labels().items():
try:
encoded_key = self.encode(wallet, key)
encoded_value = self.encode(wallet, value)
@ -118,7 +118,7 @@ class LabelsPlugin(BasePlugin):
'externalId': encoded_key})
await self.do_post("/labels", bundle)
async def pull_thread(self, wallet, force):
async def pull_thread(self, wallet: 'Abstract_Wallet', force: bool):
wallet_data = self.wallets.get(wallet, None)
if not wallet_data:
raise Exception('Wallet {} not loaded'.format(wallet))
@ -148,8 +148,8 @@ class LabelsPlugin(BasePlugin):
result[key] = value
for key, value in result.items():
if force or not wallet.labels.get(key):
wallet.labels[key] = value
if force or not wallet.get_label(key):
wallet._set_label(key, value)
self.logger.info(f"received {len(response)} labels")
self.set_nonce(wallet, response["nonce"] + 1)
@ -160,21 +160,21 @@ class LabelsPlugin(BasePlugin):
@ignore_exceptions
@log_exceptions
async def pull_safe_thread(self, wallet, force):
async def pull_safe_thread(self, wallet: 'Abstract_Wallet', force: bool):
try:
await self.pull_thread(wallet, force)
except ErrorConnectingServer as e:
self.logger.info(repr(e))
def pull(self, wallet, force):
def pull(self, wallet: 'Abstract_Wallet', force: bool):
if not wallet.network: raise Exception(_('You are offline.'))
return asyncio.run_coroutine_threadsafe(self.pull_thread(wallet, force), wallet.network.asyncio_loop).result()
def push(self, wallet):
def push(self, wallet: 'Abstract_Wallet'):
if not wallet.network: raise Exception(_('You are offline.'))
return asyncio.run_coroutine_threadsafe(self.push_thread(wallet), wallet.network.asyncio_loop).result()
def start_wallet(self, wallet):
def start_wallet(self, wallet: 'Abstract_Wallet'):
if not wallet.network: return # 'offline' mode
nonce = self.get_nonce(wallet)
self.logger.info(f"wallet {wallet.basename()} nonce is {nonce}")

View file

@ -263,7 +263,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
# saved fields
self.use_change = db.get('use_change', True)
self.multiple_change = db.get('multiple_change', False)
self.labels = db.get_dict('labels')
self._labels = db.get_dict('labels')
self.frozen_addresses = set(db.get('frozen_addresses', []))
self.frozen_coins = set(db.get('frozen_coins', [])) # set of txid:vout strings
self.fiat_value = db.get_dict('fiat_value')
@ -423,20 +423,28 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
def is_deterministic(self) -> bool:
return self.keystore.is_deterministic()
def _set_label(self, key: str, value: Optional[str]) -> None:
with self.lock:
if value is None:
self._labels.pop(key, None)
else:
self._labels[key] = value
def set_label(self, name: str, text: str = None) -> bool:
if not name:
return False
changed = False
old_text = self.labels.get(name)
if text:
text = text.replace("\n", " ")
if old_text != text:
self.labels[name] = text
changed = True
else:
if old_text is not None:
self.labels.pop(name)
changed = True
with self.lock:
old_text = self._labels.get(name)
if text:
text = text.replace("\n", " ")
if old_text != text:
self._labels[name] = text
changed = True
else:
if old_text is not None:
self._labels.pop(name)
changed = True
if changed:
run_hook('set_label', self, name, text)
return changed
@ -447,7 +455,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
self.set_label(key, value)
def export_labels(self, path):
write_json_file(path, self.labels)
write_json_file(path, self.get_all_labels())
def set_fiat_value(self, txid, ccy, text, fx, value_sat):
if not self.db.get_transaction(txid):
@ -568,7 +576,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
and bool(tx_we_already_have_in_db))
if tx.is_complete():
if tx_we_already_have_in_db:
label = self.get_label(tx_hash)
label = self.get_label_for_txid(tx_hash)
if tx_mined_status.height > 0:
if tx_mined_status.conf:
status = _("{} confirmations").format(tx_mined_status.conf)
@ -680,7 +688,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
'bc_value': Satoshis(hist_item.delta),
'bc_balance': Satoshis(hist_item.balance),
'date': timestamp_to_datetime(hist_item.tx_mined_status.timestamp),
'label': self.get_label(hist_item.txid),
'label': self.get_label_for_txid(hist_item.txid),
'txpos_in_block': hist_item.tx_mined_status.txpos,
}
@ -823,7 +831,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
for invoice in self.get_relevant_invoices_for_tx(tx):
if invoice.message:
labels.append(invoice.message)
if labels and not self.labels.get(tx_hash, ''):
if labels and not self._labels.get(tx_hash, ''):
self.set_label(tx_hash, "; ".join(labels))
return bool(labels)
@ -994,19 +1002,28 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
item['capital_gain'] = Fiat(cg, fx.ccy)
return item
def get_label(self, tx_hash: str) -> str:
return self.labels.get(tx_hash, '') or self.get_default_label(tx_hash)
def get_label(self, key: str) -> str:
# key is typically: address / txid / LN-payment-hash-hex
return self._labels.get(key) or ''
def get_default_label(self, tx_hash) -> str:
def get_label_for_txid(self, tx_hash: str) -> str:
return self._labels.get(tx_hash) or self._get_default_label_for_txid(tx_hash)
def _get_default_label_for_txid(self, tx_hash: str) -> str:
# if no inputs are ismine, concat labels of output addresses
if not self.db.get_txi_addresses(tx_hash):
labels = []
for addr in self.db.get_txo_addresses(tx_hash):
label = self.labels.get(addr)
label = self._labels.get(addr)
if label:
labels.append(label)
return ', '.join(labels)
return ''
def get_all_labels(self) -> Dict[str, str]:
with self.lock:
return copy.copy(self._labels)
def get_tx_status(self, tx_hash, tx_mined_info: TxMinedInfo):
extra = []
height = tx_mined_info.height
@ -1672,7 +1689,7 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
def get_request_URI(self, req: OnchainInvoice) -> str:
addr = req.get_address()
message = self.labels.get(addr, '')
message = self.get_label(addr)
amount = req.amount_sat
extra_query_params = {}
if req.time:

View file

@ -1051,7 +1051,7 @@ class WalletDB(JsonDB):
self.tx_fees.pop(txid, None)
@locked
def get_dict(self, name):
def get_dict(self, name) -> dict:
# Warning: interacts un-intuitively with 'put': certain parts
# of 'data' will have pointers saved as separate variables.
if name not in self.data: