diff --git a/electrum/plugin.py b/electrum/plugin.py index f8524e4a3..6f8ec0190 100644 --- a/electrum/plugin.py +++ b/electrum/plugin.py @@ -29,7 +29,10 @@ import time import threading import sys from typing import (NamedTuple, Any, Union, TYPE_CHECKING, Optional, Tuple, - Dict, Iterable, List) + Dict, Iterable, List, Sequence, Callable, TypeVar) +import concurrent +from concurrent import futures +from functools import wraps, partial from .i18n import _ from .util import (profiler, DaemonThread, UserCancelled, ThreadJob, UserFacingException) @@ -39,8 +42,9 @@ from .simple_config import SimpleConfig from .logging import get_logger, Logger if TYPE_CHECKING: - from .plugins.hw_wallet import HW_PluginBase, HardwareClientBase + from .plugins.hw_wallet import HW_PluginBase, HardwareClientBase, HardwareHandlerBase from .keystore import Hardware_KeyStore + from .wallet import Abstract_Wallet _logger = get_logger(__name__) @@ -60,7 +64,7 @@ class Plugins(DaemonThread): self.pkgpath = os.path.dirname(plugins.__file__) self.config = config self.hw_wallets = {} - self.plugins = {} + self.plugins = {} # type: Dict[str, BasePlugin] self.gui_name = gui_name self.descriptions = {} self.device_manager = DeviceMgr(config) @@ -105,7 +109,7 @@ class Plugins(DaemonThread): def count(self): return len(self.plugins) - def load_plugin(self, name): + def load_plugin(self, name) -> 'BasePlugin': if name in self.plugins: return self.plugins[name] full_name = f'electrum.plugins.{name}.{self.gui_name}' @@ -127,14 +131,14 @@ class Plugins(DaemonThread): def close_plugin(self, plugin): self.remove_jobs(plugin.thread_jobs()) - def enable(self, name): + def enable(self, name: str) -> 'BasePlugin': self.config.set_key('use_' + name, True, True) p = self.get(name) if p: return p return self.load_plugin(name) - def disable(self, name): + def disable(self, name: str) -> None: self.config.set_key('use_' + name, False, True) p = self.get(name) if not p: @@ -143,11 +147,11 @@ class Plugins(DaemonThread): p.close() self.logger.info(f"closed {name}") - def toggle(self, name): + def toggle(self, name: str) -> Optional['BasePlugin']: p = self.get(name) return self.disable(name) if p else self.enable(name) - def is_available(self, name, w): + def is_available(self, name: str, wallet: 'Abstract_Wallet') -> bool: d = self.descriptions.get(name) if not d: return False @@ -159,7 +163,7 @@ class Plugins(DaemonThread): self.logger.warning(f'Plugin {name} unavailable: {repr(e)}') return False requires = d.get('requires_wallet_type', []) - return not requires or w.wallet_type in requires + return not requires or wallet.wallet_type in requires def get_hardware_support(self): out = [] @@ -198,8 +202,8 @@ class Plugins(DaemonThread): self.logger.info(f"registering hardware {name}: {details}") register_keystore(details[1], dynamic_constructor) - def get_plugin(self, name): - if not name in self.plugins: + def get_plugin(self, name: str) -> 'BasePlugin': + if name not in self.plugins: self.load_plugin(name) return self.plugins[name] @@ -268,7 +272,7 @@ class BasePlugin(Logger): def on_close(self): pass - def requires_settings(self): + def requires_settings(self) -> bool: return False def thread_jobs(self): @@ -283,12 +287,16 @@ class BasePlugin(Logger): def can_user_disable(self): return True - def settings_dialog(self): - pass + def settings_widget(self, window): + raise NotImplementedError() + + def settings_dialog(self, window): + raise NotImplementedError() class DeviceUnpairableError(UserFacingException): pass class HardwarePluginLibraryUnavailable(Exception): pass +class CannotAutoSelectDevice(Exception): pass class Device(NamedTuple): @@ -305,6 +313,9 @@ class DeviceInfo(NamedTuple): label: Optional[str] = None initialized: Optional[bool] = None exception: Optional[Exception] = None + plugin_name: Optional[str] = None # manufacturer, e.g. "trezor" + soft_device_id: Optional[str] = None # if available, used to distinguish same-type hw devices + model_name: Optional[str] = None # e.g. "Ledger Nano S" class HardwarePluginToScan(NamedTuple): @@ -317,6 +328,46 @@ class HardwarePluginToScan(NamedTuple): PLACEHOLDER_HW_CLIENT_LABELS = {None, "", " "} +# hidapi is not thread-safe +# see https://github.com/signal11/hidapi/issues/205#issuecomment-527654560 +# https://github.com/libusb/hidapi/issues/45 +# https://github.com/signal11/hidapi/issues/45#issuecomment-4434598 +# https://github.com/signal11/hidapi/pull/414#issuecomment-445164238 +# It is not entirely clear to me, exactly what is safe and what isn't, when +# using multiple threads... +# Hence, we use a single thread for all device communications, including +# enumeration. Everything that uses hidapi, libusb, etc, MUST run on +# the following thread: +_hwd_comms_executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, + thread_name_prefix='hwd_comms_thread' +) + + +T = TypeVar('T') + + +def run_in_hwd_thread(func: Callable[[], T]) -> T: + if threading.current_thread().name.startswith("hwd_comms_thread"): + return func() + else: + fut = _hwd_comms_executor.submit(func) + return fut.result() + #except (concurrent.futures.CancelledError, concurrent.futures.TimeoutError) as e: + + +def runs_in_hwd_thread(func): + @wraps(func) + def wrapper(*args, **kwargs): + return run_in_hwd_thread(partial(func, *args, **kwargs)) + return wrapper + + +def assert_runs_in_hwd_thread(): + if not threading.current_thread().name.startswith("hwd_comms_thread"): + raise Exception("must only be called from HWD communication thread") + + class DeviceMgr(ThreadJob): '''Manages hardware clients. A client communicates over a hardware channel with the device. @@ -348,22 +399,22 @@ class DeviceMgr(ThreadJob): This plugin is thread-safe. Currently only devices supported by hidapi are implemented.''' - def __init__(self, config): + def __init__(self, config: SimpleConfig): ThreadJob.__init__(self) # Keyed by xpub. The value is the device id - # has been paired, and None otherwise. + # has been paired, and None otherwise. Needs self.lock. self.xpub_ids = {} # type: Dict[str, str] # A list of clients. The key is the client, the value is - # a (path, id_) pair. + # a (path, id_) pair. Needs self.lock. self.clients = {} # type: Dict[HardwareClientBase, Tuple[Union[str, bytes], str]] - # What we recognise. Each entry is a (vendor_id, product_id) - # pair. - self.recognised_hardware = set() + # What we recognise. (vendor_id, product_id) -> Plugin + self._recognised_hardware = {} # type: Dict[Tuple[int, int], HW_PluginBase] + self._recognised_vendor = {} # type: Dict[int, HW_PluginBase] # vendor_id -> Plugin # Custom enumerate functions for devices we don't know about. - self.enumerate_func = set() - # For synchronization + self._enumerate_func = set() # Needs self.lock. + self.lock = threading.RLock() - self.hid_lock = threading.RLock() + self.config = config def thread_jobs(self): @@ -379,16 +430,23 @@ class DeviceMgr(ThreadJob): for client in clients: client.timeout(cutoff) - def register_devices(self, device_pairs): + def register_devices(self, device_pairs, *, plugin: 'HW_PluginBase'): for pair in device_pairs: - self.recognised_hardware.add(pair) + self._recognised_hardware[pair] = plugin + + def register_vendor_ids(self, vendor_ids: Iterable[int], *, plugin: 'HW_PluginBase'): + for vendor_id in vendor_ids: + self._recognised_vendor[vendor_id] = plugin def register_enumerate_func(self, func): - self.enumerate_func.add(func) + with self.lock: + self._enumerate_func.add(func) - def create_client(self, device: 'Device', handler, plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']: + @runs_in_hwd_thread + def create_client(self, device: 'Device', handler: Optional['HardwareHandlerBase'], + plugin: 'HW_PluginBase') -> Optional['HardwareClientBase']: # Get from cache first - client = self.client_lookup(device.id_) + client = self._client_by_id(device.id_) if client: return client client = plugin.create_client(device, handler) @@ -414,7 +472,7 @@ class DeviceMgr(ThreadJob): if xpub not in self.xpub_ids: return _id = self.xpub_ids.pop(xpub) - self._close_client(_id) + self._close_client(_id) def unpair_id(self, id_): xpub = self.xpub_by_id(id_) @@ -424,8 +482,9 @@ class DeviceMgr(ThreadJob): self._close_client(id_) def _close_client(self, id_): - client = self.client_lookup(id_) - self.clients.pop(client, None) + with self.lock: + client = self._client_by_id(id_) + self.clients.pop(client, None) if client: client.close() @@ -433,45 +492,57 @@ class DeviceMgr(ThreadJob): with self.lock: self.xpub_ids[xpub] = id_ - def client_lookup(self, id_) -> Optional['HardwareClientBase']: + def _client_by_id(self, id_) -> Optional['HardwareClientBase']: with self.lock: for client, (path, client_id) in self.clients.items(): if client_id == id_: return client return None - def client_by_id(self, id_) -> Optional['HardwareClientBase']: + def client_by_id(self, id_, *, scan_now: bool = True) -> Optional['HardwareClientBase']: '''Returns a client for the device ID if one is registered. If a device is wiped or in bootloader mode pairing is impossible; in such cases we communicate by device ID and not wallet.''' - self.scan_devices() - return self.client_lookup(id_) + if scan_now: + self.scan_devices() + return self._client_by_id(id_) - def client_for_keystore(self, plugin: 'HW_PluginBase', handler, keystore: 'Hardware_KeyStore', - force_pair: bool) -> Optional['HardwareClientBase']: + @runs_in_hwd_thread + def client_for_keystore(self, plugin: 'HW_PluginBase', handler: Optional['HardwareHandlerBase'], + keystore: 'Hardware_KeyStore', + force_pair: bool, *, + devices: Sequence['Device'] = None, + allow_user_interaction: bool = True) -> Optional['HardwareClientBase']: self.logger.info("getting client for keystore") if handler is None: raise Exception(_("Handler not found for") + ' ' + plugin.name + '\n' + _("A library is probably missing.")) handler.update_status(False) - devices = self.scan_devices() + if devices is None: + devices = self.scan_devices() xpub = keystore.xpub derivation = keystore.get_derivation_prefix() assert derivation is not None client = self.client_by_xpub(plugin, xpub, handler, devices) if client is None and force_pair: - info = self.select_device(plugin, handler, keystore, devices) - client = self.force_pair_xpub(plugin, handler, info, xpub, derivation) + try: + info = self.select_device(plugin, handler, keystore, devices, + allow_user_interaction=allow_user_interaction) + except CannotAutoSelectDevice: + pass + else: + client = self.force_pair_xpub(plugin, handler, info, xpub, derivation) if client: handler.update_status(True) if client: + # note: if select_device was called, we might also update label etc here: keystore.opportunistically_fill_in_missing_info_from_device(client) self.logger.info("end client for keystore") return client - def client_by_xpub(self, plugin: 'HW_PluginBase', xpub, handler, - devices: Iterable['Device']) -> Optional['HardwareClientBase']: + def client_by_xpub(self, plugin: 'HW_PluginBase', xpub, handler: 'HardwareHandlerBase', + devices: Sequence['Device']) -> Optional['HardwareClientBase']: _id = self.xpub_id(xpub) - client = self.client_lookup(_id) + client = self._client_by_id(_id) if client: # An unpaired client might have another wallet's handler # from a prior scan. Replace to fix dialog parenting. @@ -482,12 +553,12 @@ class DeviceMgr(ThreadJob): if device.id_ == _id: return self.create_client(device, handler, plugin) - def force_pair_xpub(self, plugin: 'HW_PluginBase', handler, + def force_pair_xpub(self, plugin: 'HW_PluginBase', handler: 'HardwareHandlerBase', info: 'DeviceInfo', xpub, derivation) -> Optional['HardwareClientBase']: # The wallet has not been previously paired, so let the user # choose an unpaired device and compare its first address. xtype = bip32.xpub_type(xpub) - client = self.client_lookup(info.device.id_) + client = self._client_by_id(info.device.id_) if client and client.is_pairable(): # See comment above for same code client.handler = handler @@ -510,7 +581,8 @@ class DeviceMgr(ThreadJob): 'its seed (and passphrase, if any). Otherwise all bitcoins you ' 'receive will be unspendable.').format(plugin.device)) - def unpaired_device_infos(self, handler, plugin: 'HW_PluginBase', devices: List['Device'] = None, + def unpaired_device_infos(self, handler: Optional['HardwareHandlerBase'], plugin: 'HW_PluginBase', + devices: Sequence['Device'] = None, include_failing_clients=False) -> List['DeviceInfo']: '''Returns a list of DeviceInfo objects: one for each connected, unpaired device accepted by the plugin.''' @@ -522,31 +594,38 @@ class DeviceMgr(ThreadJob): devices = [dev for dev in devices if not self.xpub_by_id(dev.id_)] infos = [] for device in devices: - if device.product_key not in plugin.DEVICE_IDS: + if not plugin.can_recognize_device(device): continue try: client = self.create_client(device, handler, plugin) except Exception as e: self.logger.error(f'failed to create client for {plugin.name} at {device.path}: {repr(e)}') if include_failing_clients: - infos.append(DeviceInfo(device=device, exception=e)) + infos.append(DeviceInfo(device=device, exception=e, plugin_name=plugin.name)) continue if not client: continue infos.append(DeviceInfo(device=device, label=client.label(), - initialized=client.is_initialized())) + initialized=client.is_initialized(), + plugin_name=plugin.name, + soft_device_id=client.get_soft_device_id(), + model_name=client.device_model_name())) return infos - def select_device(self, plugin: 'HW_PluginBase', handler, - keystore: 'Hardware_KeyStore', devices: List['Device'] = None) -> 'DeviceInfo': - '''Ask the user to select a device to use if there is more than one, - and return the DeviceInfo for the device.''' + def select_device(self, plugin: 'HW_PluginBase', handler: 'HardwareHandlerBase', + keystore: 'Hardware_KeyStore', devices: Sequence['Device'] = None, + *, allow_user_interaction: bool = True) -> 'DeviceInfo': + """Select the device to use for keystore.""" + # ideally this should not be called from the GUI thread... + # assert handler.get_gui_thread() != threading.current_thread(), 'must not be called from GUI thread' while True: infos = self.unpaired_device_infos(handler, plugin, devices) if infos: break + if not allow_user_interaction: + raise CannotAutoSelectDevice() msg = _('Please insert your {}').format(plugin.device) if keystore.label: msg += ' ({})'.format(keystore.label) @@ -558,69 +637,79 @@ class DeviceMgr(ThreadJob): if not handler.yes_no_question(msg): raise UserCancelled() devices = None - if len(infos) == 1: - return infos[0] - # select device by label automatically; - # but only if not a placeholder label and only if there is no collision + + # select device automatically. (but only if we have reasonable expectation it is the correct one) + # method 1: select device by id + if keystore.soft_device_id: + for info in infos: + if info.soft_device_id == keystore.soft_device_id: + return info + # method 2: select device by label + # but only if not a placeholder label and only if there is no collision device_labels = [info.label for info in infos] if (keystore.label not in PLACEHOLDER_HW_CLIENT_LABELS and device_labels.count(keystore.label) == 1): for info in infos: if info.label == keystore.label: return info - # ask user to select device + # method 3: if there is only one device connected, and we don't have useful label/soft_device_id + # saved for keystore anyway, select it + if (len(infos) == 1 + and keystore.label in PLACEHOLDER_HW_CLIENT_LABELS + and keystore.soft_device_id is None): + return infos[0] + + if not allow_user_interaction: + raise CannotAutoSelectDevice() + # ask user to select device manually msg = _("Please select which {} device to use:").format(plugin.device) - descriptions = ["{label} ({init}, {transport})" - .format(label=info.label, + descriptions = ["{label} ({maybe_model}{init}, {transport})" + .format(label=info.label or _("An unnamed {}").format(info.plugin_name), init=(_("initialized") if info.initialized else _("wiped")), - transport=info.device.transport_ui_string) + transport=info.device.transport_ui_string, + maybe_model=f"{info.model_name}, " if info.model_name else "") for info in infos] c = handler.query_choice(msg, descriptions) if c is None: raise UserCancelled() info = infos[c] - # save new label - keystore.set_label(info.label) - if handler.win.wallet is not None: - handler.win.wallet.save_keystore() + # note: updated label/soft_device_id will be saved after pairing succeeds return info + @runs_in_hwd_thread def _scan_devices_with_hid(self) -> List['Device']: try: import hid except ImportError: return [] - with self.hid_lock: - hid_list = hid.enumerate(0, 0) - devices = [] - for d in hid_list: - product_key = (d['vendor_id'], d['product_id']) - if product_key in self.recognised_hardware: - # Older versions of hid don't provide interface_number - interface_number = d.get('interface_number', -1) - usage_page = d['usage_page'] - id_ = d['serial_number'] - if len(id_) == 0: - id_ = str(d['path']) - id_ += str(interface_number) + str(usage_page) - devices.append(Device(path=d['path'], - interface_number=interface_number, - id_=id_, - product_key=product_key, - usage_page=usage_page, - transport_ui_string='hid')) + for d in hid.enumerate(0, 0): + vendor_id = d['vendor_id'] + product_key = (vendor_id, d['product_id']) + plugin = None + if product_key in self._recognised_hardware: + plugin = self._recognised_hardware[product_key] + elif vendor_id in self._recognised_vendor: + plugin = self._recognised_vendor[vendor_id] + if plugin: + device = plugin.create_device_from_hid_enumeration(d, product_key=product_key) + if device: + devices.append(device) return devices - def scan_devices(self) -> List['Device']: + @runs_in_hwd_thread + @profiler + def scan_devices(self) -> Sequence['Device']: self.logger.info("scanning devices...") # First see what's connected that we know about devices = self._scan_devices_with_hid() # Let plugin handlers enumerate devices we don't know about - for f in self.enumerate_func: + with self.lock: + enumerate_funcs = list(self._enumerate_func) + for f in enumerate_funcs: try: new_devices = f() except BaseException as e: @@ -631,18 +720,20 @@ class DeviceMgr(ThreadJob): # find out what was disconnected pairs = [(dev.path, dev.id_) for dev in devices] - disconnected_ids = [] + disconnected_clients = [] with self.lock: connected = {} for client, pair in self.clients.items(): if pair in pairs and client.has_usable_connection_with_device(): connected[client] = pair else: - disconnected_ids.append(pair[1]) + disconnected_clients.append((client, pair[1])) self.clients = connected # Unpair disconnected devices - for id_ in disconnected_ids: + for client, id_ in disconnected_clients: self.unpair_id(id_) + if client.handler: + client.handler.update_status(False) return devices