trezor: move the transport-related reimplemented parts into a separate module. disable the bridge transport.

The bridge transport uses requests.post, which uses socket.getaddrinfo under the hood, which on some OSes (MacOS, Windows) in CPython takes a lock. The enumerate method for the bridge transport can block for 10-30 seconds while waiting for this lock.
This commit is contained in:
SomberNight 2018-03-16 23:19:52 +01:00
parent c79de3ab3c
commit 680df7d6b6
2 changed files with 99 additions and 74 deletions

View file

@ -0,0 +1,95 @@
from electrum.util import PrintError
class TrezorTransport(PrintError):
@staticmethod
def all_transports():
"""Reimplemented trezorlib.transport.all_transports so that we can
enable/disable specific transports.
"""
try:
# only to detect trezorlib version
from trezorlib.transport import all_transports
except ImportError:
# old trezorlib. compat for trezorlib < 0.9.2
transports = []
#try:
# from trezorlib.transport_bridge import BridgeTransport
# transports.append(BridgeTransport)
#except BaseException:
# pass
try:
from trezorlib.transport_hid import HidTransport
transports.append(HidTransport)
except BaseException:
pass
try:
from trezorlib.transport_udp import UdpTransport
transports.append(UdpTransport)
except BaseException:
pass
try:
from trezorlib.transport_webusb import WebUsbTransport
transports.append(WebUsbTransport)
except BaseException:
pass
else:
# new trezorlib.
transports = []
#try:
# from trezorlib.transport.bridge import BridgeTransport
# transports.append(BridgeTransport)
#except BaseException:
# pass
try:
from trezorlib.transport.hid import HidTransport
transports.append(HidTransport)
except BaseException:
pass
try:
from trezorlib.transport.udp import UdpTransport
transports.append(UdpTransport)
except BaseException:
pass
try:
from trezorlib.transport.webusb import WebUsbTransport
transports.append(WebUsbTransport)
except BaseException:
pass
return transports
return transports
def enumerate_devices(self):
"""Just like trezorlib.transport.enumerate_devices,
but with exception catching, so that transports can fail separately.
"""
devices = []
for transport in self.all_transports():
try:
new_devices = transport.enumerate()
except BaseException as e:
self.print_error('enumerate failed for {}. error {}'
.format(transport.__name__, str(e)))
else:
devices.extend(new_devices)
return devices
def get_transport(self, path=None):
"""Reimplemented trezorlib.transport.get_transport,
(1) for old trezorlib
(2) to be able to disable specific transports
(3) to call our own enumerate_devices that catches exceptions
"""
if path is None:
try:
return self.enumerate_devices()[0]
except IndexError:
raise Exception("No TREZOR device found") from None
def match_prefix(a, b):
return a.startswith(b) or b.startswith(a)
transports = [t for t in self.all_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path)
raise Exception("Unknown path prefix '%s'" % path)

View file

@ -117,6 +117,7 @@ class TrezorPlugin(HW_PluginBase):
return return
from . import client from . import client
from . import transport
import trezorlib.ckd_public import trezorlib.ckd_public
import trezorlib.messages import trezorlib.messages
self.client_class = client.TrezorClient self.client_class = client.TrezorClient
@ -124,88 +125,17 @@ class TrezorPlugin(HW_PluginBase):
self.types = trezorlib.messages self.types = trezorlib.messages
self.DEVICE_IDS = ('TREZOR',) self.DEVICE_IDS = ('TREZOR',)
self.transport_handler = transport.TrezorTransport()
self.device_manager().register_enumerate_func(self.enumerate) self.device_manager().register_enumerate_func(self.enumerate)
@staticmethod
def _all_transports():
"""Reimplemented trezorlib.transport.all_transports for old trezorlib.
Remove this when we start to require trezorlib 0.9.2
"""
try:
from trezorlib.transport import all_transports
except ImportError:
# compat for trezorlib < 0.9.2
def all_transports():
transports = []
try:
from trezorlib.transport_bridge import BridgeTransport
transports.append(BridgeTransport)
except BaseException:
pass
try:
from trezorlib.transport_hid import HidTransport
transports.append(HidTransport)
except BaseException:
pass
try:
from trezorlib.transport_udp import UdpTransport
transports.append(UdpTransport)
except BaseException:
pass
try:
from trezorlib.transport_webusb import WebUsbTransport
transports.append(WebUsbTransport)
except BaseException:
pass
return transports
return all_transports()
def _enumerate_devices(self):
"""Just like trezorlib.transport.enumerate_devices,
but with exception catching, so that transports can fail separately.
"""
devices = []
for transport in self._all_transports():
try:
new_devices = transport.enumerate()
except BaseException as e:
self.print_error('enumerate failed for {}. error {}'
.format(transport.__name__, str(e)))
else:
devices.extend(new_devices)
return devices
def enumerate(self): def enumerate(self):
devices = self._enumerate_devices() devices = self.transport_handler.enumerate_devices()
return [Device(d.get_path(), -1, d.get_path(), 'TREZOR', 0) for d in devices] return [Device(d.get_path(), -1, d.get_path(), 'TREZOR', 0) for d in devices]
def _get_transport(self, path=None):
"""Reimplemented trezorlib.transport.get_transport for old trezorlib.
Remove this when we start to require trezorlib 0.9.2
"""
try:
from trezorlib.transport import get_transport
except ImportError:
# compat for trezorlib < 0.9.2
def get_transport(path=None, prefix_search=False):
if path is None:
try:
return self._enumerate_devices()[0]
except IndexError:
raise Exception("No TREZOR device found") from None
def match_prefix(a, b):
return a.startswith(b) or b.startswith(a)
transports = [t for t in self._all_transports() if match_prefix(path, t.PATH_PREFIX)]
if transports:
return transports[0].find_by_path(path)
raise Exception("Unknown path prefix '%s'" % path)
return get_transport(path)
def create_client(self, device, handler): def create_client(self, device, handler):
try: try:
self.print_error("connecting to device at", device.path) self.print_error("connecting to device at", device.path)
transport = self._get_transport(device.path) transport = self.transport_handler.get_transport(device.path)
except BaseException as e: except BaseException as e:
self.print_error("cannot connect at", device.path, str(e)) self.print_error("cannot connect at", device.path, str(e))
return None return None