diff --git a/electrum/__init__.py b/electrum/__init__.py index 48a60c15c..556fc329b 100644 --- a/electrum/__init__.py +++ b/electrum/__init__.py @@ -4,7 +4,7 @@ from .wallet import Wallet from .storage import WalletStorage from .coinchooser import COIN_CHOOSERS from .network import Network, pick_random_server -from .interface import Connection, Interface +from .interface import Interface from .simple_config import SimpleConfig, get_config, set_config from . import bitcoin from . import transaction diff --git a/electrum/interface.py b/electrum/interface.py index 357d913a2..678c726a6 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -28,343 +28,131 @@ import socket import ssl import sys import threading -import time import traceback +import aiorpcx +import asyncio import requests -from .util import print_error +from .util import PrintError ca_path = requests.certs.where() from . import util from . import x509 from . import pem +from .version import ELECTRUM_VERSION, PROTOCOL_VERSION +class Interface(PrintError): -def Connection(server, queue, config_path): - """Makes asynchronous connections to a remote Electrum server. - Returns the running thread that is making the connection. - - Once the thread has connected, it finishes, placing a tuple on the - queue of the form (server, socket), where socket is None if - connection failed. - """ - host, port, protocol = server.rsplit(':', 2) - if not protocol in 'st': - raise Exception('Unknown protocol: %s' % protocol) - c = TcpConnection(server, queue, config_path) - c.start() - return c - - -class TcpConnection(threading.Thread, util.PrintError): - verbosity_filter = 'i' - - def __init__(self, server, queue, config_path): - threading.Thread.__init__(self) + def __init__(self, server, config_path, connecting): + self.connecting = connecting + self.server = server + self.host, self.port, self.protocol = self.server.split(':') self.config_path = config_path - self.queue = queue - self.server = server - self.host, self.port, self.protocol = self.server.rsplit(':', 2) - self.host = str(self.host) - self.port = int(self.port) - self.use_ssl = (self.protocol == 's') - self.daemon = True + self.cert_path = os.path.join(self.config_path, 'certs', self.host) + self.fut = asyncio.get_event_loop().create_task(self.run()) def diagnostic_name(self): return self.host - def check_host_name(self, peercert, name): - """Simple certificate/host name checker. Returns True if the - certificate matches, False otherwise. Does not support - wildcards.""" - # Check that the peer has supplied a certificate. - # None/{} is not acceptable. - if not peercert: - return False - if 'subjectAltName' in peercert: - for typ, val in peercert["subjectAltName"]: - if typ == "DNS" and val == name: - return True - else: - # Only check the subject DN if there is no subject alternative - # name. - cn = None - for attr, val in peercert["subject"]: - # Use most-specific (last) commonName attribute. - if attr == "commonName": - cn = val - if cn is not None: - return cn == name - return False - - def get_simple_socket(self): + async def is_server_ca_signed(self, sslc): try: - l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - self.print_error("cannot resolve hostname") - return - e = None - for res in l: - try: - s = socket.socket(res[0], socket.SOCK_STREAM) - s.settimeout(10) - s.connect(res[4]) - s.settimeout(2) - s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - return s - except BaseException as _e: - e = _e - continue - else: - self.print_error("failed to connect", str(e)) - - @staticmethod - def get_ssl_context(cert_reqs, ca_certs): - context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs) - context.check_hostname = False - context.verify_mode = cert_reqs - - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_TLSv1 - - return context - - def get_socket(self): - if self.use_ssl: - cert_path = os.path.join(self.config_path, 'certs', self.host) - if not os.path.exists(cert_path): - is_new = True - s = self.get_simple_socket() - if s is None: - return - # try with CA first - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) - s = context.wrap_socket(s, do_handshake_on_connect=True) - except ssl.SSLError as e: - self.print_error(e) - except: - return - else: - try: - peer_cert = s.getpeercert() - except OSError: - return - if self.check_host_name(peer_cert, self.host): - self.print_error("SSL certificate signed by CA") - return s - # get server certificate. - # Do not use ssl.get_server_certificate because it does not work with proxy - s = self.get_simple_socket() - if s is None: - return - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) - s = context.wrap_socket(s) - except ssl.SSLError as e: - self.print_error("SSL error retrieving SSL certificate:", e) - return - except: - return - - try: - dercert = s.getpeercert(True) - except OSError: - return - s.close() - cert = ssl.DER_cert_to_PEM_cert(dercert) - # workaround android bug - cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) - temporary_path = cert_path + '.temp' - util.assert_datadir_available(self.config_path) - with open(temporary_path, "w", encoding='utf-8') as f: - f.write(cert) - f.flush() - os.fsync(f.fileno()) - else: - is_new = False - - s = self.get_simple_socket() - if s is None: - return - - if self.use_ssl: - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, - ca_certs=(temporary_path if is_new else cert_path)) - s = context.wrap_socket(s, do_handshake_on_connect=True) - except socket.timeout: - self.print_error('timeout') - return - except ssl.SSLError as e: - self.print_error("SSL error:", e) - if e.errno != 1: - return - if is_new: - rej = cert_path + '.rej' - if os.path.exists(rej): - os.unlink(rej) - os.rename(temporary_path, rej) - else: - util.assert_datadir_available(self.config_path) - with open(cert_path, encoding='utf-8') as f: - cert = f.read() - try: - b = pem.dePem(cert, 'CERTIFICATE') - x = x509.X509(b) - except: - traceback.print_exc(file=sys.stderr) - self.print_error("wrong certificate") - return - try: - x.check_date() - except: - self.print_error("certificate has expired:", cert_path) - os.unlink(cert_path) - return - self.print_error("wrong certificate") - if e.errno == 104: - return - return - except BaseException as e: - self.print_error(e) - traceback.print_exc(file=sys.stderr) - return - - if is_new: - self.print_error("saving certificate") - os.rename(temporary_path, cert_path) - - return s - - def run(self): - socket = self.get_socket() - if socket: - self.print_error("connected") - self.queue.put((self.server, socket)) - - -class Interface(util.PrintError): - """The Interface class handles a socket connected to a single remote - Electrum server. Its exposed API is: - - - Member functions close(), fileno(), get_responses(), has_timed_out(), - ping_required(), queue_request(), send_requests() - - Member variable server. - """ - - def __init__(self, server, socket): - self.server = server - self.host, _, _ = server.rsplit(':', 2) - self.socket = socket - - self.pipe = util.SocketPipe(socket) - self.pipe.set_timeout(0.0) # Don't wait for data - # Dump network messages. Set at runtime from the console. - self.debug = False - self.unsent_requests = [] - self.unanswered_requests = {} - self.last_send = time.time() - self.closed_remotely = False - - def diagnostic_name(self): - return self.host - - def fileno(self): - # Needed for select - return self.socket.fileno() - - def close(self): - if not self.closed_remotely: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - self.socket.close() - - def queue_request(self, *args): # method, params, _id - '''Queue a request, later to be send with send_requests when the - socket is available for writing. - ''' - self.request_time = time.time() - self.unsent_requests.append(args) - - def num_requests(self): - '''Keep unanswered requests below 100''' - n = 100 - len(self.unanswered_requests) - return min(n, len(self.unsent_requests)) - - def send_requests(self): - '''Sends queued requests. Returns False on failure.''' - self.last_send = time.time() - make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i} - n = self.num_requests() - wire_requests = self.unsent_requests[0:n] - try: - self.pipe.send_all([make_dict(*r) for r in wire_requests]) - except BaseException as e: - self.print_error("pipe send error:", e) + await self.open_session(sslc, do_sleep=False) + except ssl.SSLError as e: + assert e.reason == 'CERTIFICATE_VERIFY_FAILED' return False - self.unsent_requests = self.unsent_requests[n:] - for request in wire_requests: - if self.debug: - self.print_error("-->", request) - self.unanswered_requests[request[2]] = request return True - def ping_required(self): - '''Returns True if a ping should be sent.''' - return time.time() - self.last_send > 300 + @util.aiosafe + async def run(self): + if self.protocol != 's': + await self.open_session(None, execute_after_connect=lambda: self.connecting.remove(self.server)) + return + + ca_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + exists = os.path.exists(self.cert_path) + if exists: + with open(self.cert_path, 'r') as f: + contents = f.read() + if contents != '': # if not CA signed + try: + b = pem.dePem(contents, 'CERTIFICATE') + except SyntaxError: + exists = False + else: + x = x509.X509(b) + try: + x.check_date() + except x509.CertificateError: + self.print_error("certificate has expired:", self.cert_path) + os.unlink(self.cert_path) + exists = False + if not exists: + ca_signed = await self.is_server_ca_signed(ca_sslc) + if ca_signed: + with open(self.cert_path, 'w') as f: + # empty file means this is CA signed, not self-signed + f.write('') + else: + await self.save_certificate() + siz = os.stat(self.cert_path).st_size + if siz == 0: # if CA signed + sslc = ca_sslc + else: + sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path) + sslc.check_hostname = 0 + await self.open_session(sslc, execute_after_connect=lambda: self.connecting.remove(self.server)) + + async def save_certificate(self): + if not os.path.exists(self.cert_path): + # we may need to retry this a few times, in case the handshake hasn't completed + for _ in range(10): + dercert = await self.get_certificate() + if dercert: + self.print_error("succeeded in getting cert") + with open(self.cert_path, 'w') as f: + cert = ssl.DER_cert_to_PEM_cert(dercert) + # workaround android bug + cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) + f.write(cert) + # even though close flushes we can't fsync when closed. + # and we must flush before fsyncing, cause flush flushes to OS buffer + # fsync writes to OS buffer to disk + f.flush() + os.fsync(f.fileno()) + break + await asyncio.sleep(1) + assert False, "could not get certificate" + + async def get_certificate(self): + sslc = ssl.SSLContext() + try: + async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session: + return session.transport._ssl_protocol._sslpipe._sslobj.getpeercert(True) + except ValueError: + return None + + async def open_session(self, sslc, do_sleep=True, execute_after_connect=lambda: None): + async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session: + ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION]) + print(ver) + connect_hook_executed = False + while do_sleep: + if not connect_hook_executed: + connect_hook_executed = True + execute_after_connect() + await asyncio.wait_for(session.send_request('server.ping'), 5) + await asyncio.sleep(300) def has_timed_out(self): - '''Returns True if the interface has timed out.''' - if (self.unanswered_requests and time.time() - self.request_time > 10 - and self.pipe.idle_time() > 10): - self.print_error("timeout", len(self.unanswered_requests)) - return True + return self.fut.done() - return False - - def get_responses(self): - '''Call if there is data available on the socket. Returns a list of - (request, response) pairs. Notifications are singleton - unsolicited responses presumably as a result of prior - subscriptions, so request is None and there is no 'id' member. - Otherwise it is a response, which has an 'id' member and a - corresponding request. If the connection was closed remotely - or the remote server is misbehaving, a (None, None) will appear. - ''' - responses = [] - while True: - try: - response = self.pipe.get() - except util.timeout: - break - if not type(response) is dict: - responses.append((None, None)) - if response is None: - self.closed_remotely = True - self.print_error("connection closed remotely") - break - if self.debug: - self.print_error("<--", response) - wire_id = response.get('id', None) - if wire_id is None: # Notification - responses.append((None, response)) - else: - request = self.unanswered_requests.pop(wire_id, None) - if request: - responses.append((request, response)) - else: - self.print_error("unknown wire ID", wire_id) - responses.append((None, None)) # Signal - break - - return responses + def queue_request(self, method, params, msg_id): + pass + def close(self): + self.fut.cancel() def check_cert(host, cert): try: diff --git a/electrum/network.py b/electrum/network.py index 01febd145..9e5b33d83 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -47,38 +47,14 @@ from . import blockchain from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from .i18n import _ from .blockchain import InvalidHeader +from .interface import Interface -import aiorpcx, asyncio, ssl +import asyncio import concurrent.futures NODES_RETRY_INTERVAL = 60 SERVER_RETRY_INTERVAL = 10 -class Interface(PrintError): - @util.aiosafe - async def run(self): - self.host, self.port, self.protocol = self.server.split(':') - sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) if self.protocol == 's' else None - async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session: - ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION]) - print(ver) - while True: - print("sleeping") - await asyncio.sleep(1) - - def __init__(self, server): - self.exception = None - self.server = server - self.fut = asyncio.get_event_loop().create_task(self.run()) - - def has_timed_out(self): - return self.fut.done() - - def queue_request(self, method, params, msg_id): - pass - - def close(self): - self.fut.cancel() def parse_servers(result): """ parse servers list into dict format""" @@ -539,7 +515,7 @@ class Network(PrintError): self.close_interface(self.interface) assert self.interface is None assert not self.interfaces - self.connecting = set() + self.connecting.clear() # Get a new queue - no old pending connections thanks! self.socket_queue = queue.Queue() @@ -810,7 +786,7 @@ class Network(PrintError): def new_interface(self, server): # todo: get tip first, then decide which checkpoint to use. self.add_recent_server(server) - interface = Interface(server) + interface = Interface(server, self.config.path, self.connecting) interface.blockchain = None interface.tip_header = None interface.tip = 0 @@ -1368,9 +1344,12 @@ class Network(PrintError): for k, i in self.interfaces.items(): if i.has_timed_out(): remove.append(k) + changed = False for k in remove: self.connection_down(k) + changed = True for i in range(self.num_server - len(self.interfaces)): self.start_random_interface() - self.notify('updated') + changed = True + if changed: self.notify('updated') await asyncio.sleep(1)