aiorpcx: pin certificates

This commit is contained in:
Janus 2018-08-16 18:16:25 +02:00 committed by SomberNight
parent 8080a713b2
commit 89a01a6463
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
3 changed files with 109 additions and 342 deletions

View file

@ -4,7 +4,7 @@ from .wallet import Wallet
from .storage import WalletStorage from .storage import WalletStorage
from .coinchooser import COIN_CHOOSERS from .coinchooser import COIN_CHOOSERS
from .network import Network, pick_random_server from .network import Network, pick_random_server
from .interface import Connection, Interface from .interface import Interface
from .simple_config import SimpleConfig, get_config, set_config from .simple_config import SimpleConfig, get_config, set_config
from . import bitcoin from . import bitcoin
from . import transaction from . import transaction

View file

@ -28,343 +28,131 @@ import socket
import ssl import ssl
import sys import sys
import threading import threading
import time
import traceback import traceback
import aiorpcx
import asyncio
import requests import requests
from .util import print_error from .util import PrintError
ca_path = requests.certs.where() ca_path = requests.certs.where()
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
class Interface(PrintError):
def Connection(server, queue, config_path): def __init__(self, server, config_path, connecting):
"""Makes asynchronous connections to a remote Electrum server. self.connecting = connecting
Returns the running thread that is making the connection.
Once the thread has connected, it finishes, placing a tuple on the
queue of the form (server, socket), where socket is None if
connection failed.
"""
host, port, protocol = server.rsplit(':', 2)
if not protocol in 'st':
raise Exception('Unknown protocol: %s' % protocol)
c = TcpConnection(server, queue, config_path)
c.start()
return c
class TcpConnection(threading.Thread, util.PrintError):
verbosity_filter = 'i'
def __init__(self, server, queue, config_path):
threading.Thread.__init__(self)
self.config_path = config_path
self.queue = queue
self.server = server self.server = server
self.host, self.port, self.protocol = self.server.rsplit(':', 2) self.host, self.port, self.protocol = self.server.split(':')
self.host = str(self.host) self.config_path = config_path
self.port = int(self.port) self.cert_path = os.path.join(self.config_path, 'certs', self.host)
self.use_ssl = (self.protocol == 's') self.fut = asyncio.get_event_loop().create_task(self.run())
self.daemon = True
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
def check_host_name(self, peercert, name): async def is_server_ca_signed(self, sslc):
"""Simple certificate/host name checker. Returns True if the try:
certificate matches, False otherwise. Does not support await self.open_session(sslc, do_sleep=False)
wildcards.""" except ssl.SSLError as e:
# Check that the peer has supplied a certificate. assert e.reason == 'CERTIFICATE_VERIFY_FAILED'
# None/{} is not acceptable.
if not peercert:
return False return False
if 'subjectAltName' in peercert:
for typ, val in peercert["subjectAltName"]:
if typ == "DNS" and val == name:
return True return True
@util.aiosafe
async def run(self):
if self.protocol != 's':
await self.open_session(None, execute_after_connect=lambda: self.connecting.remove(self.server))
return
ca_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
exists = os.path.exists(self.cert_path)
if exists:
with open(self.cert_path, 'r') as f:
contents = f.read()
if contents != '': # if not CA signed
try:
b = pem.dePem(contents, 'CERTIFICATE')
except SyntaxError:
exists = False
else: else:
# Only check the subject DN if there is no subject alternative x = x509.X509(b)
# name.
cn = None
for attr, val in peercert["subject"]:
# Use most-specific (last) commonName attribute.
if attr == "commonName":
cn = val
if cn is not None:
return cn == name
return False
def get_simple_socket(self):
try: try:
l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) x.check_date()
except socket.gaierror: except x509.CertificateError:
self.print_error("cannot resolve hostname") self.print_error("certificate has expired:", self.cert_path)
return os.unlink(self.cert_path)
e = None exists = False
for res in l: if not exists:
try: ca_signed = await self.is_server_ca_signed(ca_sslc)
s = socket.socket(res[0], socket.SOCK_STREAM) if ca_signed:
s.settimeout(10) with open(self.cert_path, 'w') as f:
s.connect(res[4]) # empty file means this is CA signed, not self-signed
s.settimeout(2) f.write('')
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
except BaseException as _e:
e = _e
continue
else: else:
self.print_error("failed to connect", str(e)) await self.save_certificate()
siz = os.stat(self.cert_path).st_size
@staticmethod if siz == 0: # if CA signed
def get_ssl_context(cert_reqs, ca_certs): sslc = ca_sslc
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs)
context.check_hostname = False
context.verify_mode = cert_reqs
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
return context
def get_socket(self):
if self.use_ssl:
cert_path = os.path.join(self.config_path, 'certs', self.host)
if not os.path.exists(cert_path):
is_new = True
s = self.get_simple_socket()
if s is None:
return
# try with CA first
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path)
s = context.wrap_socket(s, do_handshake_on_connect=True)
except ssl.SSLError as e:
self.print_error(e)
except:
return
else: else:
try: sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
peer_cert = s.getpeercert() sslc.check_hostname = 0
except OSError: await self.open_session(sslc, execute_after_connect=lambda: self.connecting.remove(self.server))
return
if self.check_host_name(peer_cert, self.host):
self.print_error("SSL certificate signed by CA")
return s
# get server certificate.
# Do not use ssl.get_server_certificate because it does not work with proxy
s = self.get_simple_socket()
if s is None:
return
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None)
s = context.wrap_socket(s)
except ssl.SSLError as e:
self.print_error("SSL error retrieving SSL certificate:", e)
return
except:
return
try: async def save_certificate(self):
dercert = s.getpeercert(True) if not os.path.exists(self.cert_path):
except OSError: # we may need to retry this a few times, in case the handshake hasn't completed
return for _ in range(10):
s.close() dercert = await self.get_certificate()
if dercert:
self.print_error("succeeded in getting cert")
with open(self.cert_path, 'w') as f:
cert = ssl.DER_cert_to_PEM_cert(dercert) cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug # workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
temporary_path = cert_path + '.temp'
util.assert_datadir_available(self.config_path)
with open(temporary_path, "w", encoding='utf-8') as f:
f.write(cert) f.write(cert)
# even though close flushes we can't fsync when closed.
# and we must flush before fsyncing, cause flush flushes to OS buffer
# fsync writes to OS buffer to disk
f.flush() f.flush()
os.fsync(f.fileno()) os.fsync(f.fileno())
else: break
is_new = False await asyncio.sleep(1)
assert False, "could not get certificate"
s = self.get_simple_socket() async def get_certificate(self):
if s is None: sslc = ssl.SSLContext()
return
if self.use_ssl:
try: try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ca_certs=(temporary_path if is_new else cert_path)) return session.transport._ssl_protocol._sslpipe._sslobj.getpeercert(True)
s = context.wrap_socket(s, do_handshake_on_connect=True) except ValueError:
except socket.timeout: return None
self.print_error('timeout')
return
except ssl.SSLError as e:
self.print_error("SSL error:", e)
if e.errno != 1:
return
if is_new:
rej = cert_path + '.rej'
if os.path.exists(rej):
os.unlink(rej)
os.rename(temporary_path, rej)
else:
util.assert_datadir_available(self.config_path)
with open(cert_path, encoding='utf-8') as f:
cert = f.read()
try:
b = pem.dePem(cert, 'CERTIFICATE')
x = x509.X509(b)
except:
traceback.print_exc(file=sys.stderr)
self.print_error("wrong certificate")
return
try:
x.check_date()
except:
self.print_error("certificate has expired:", cert_path)
os.unlink(cert_path)
return
self.print_error("wrong certificate")
if e.errno == 104:
return
return
except BaseException as e:
self.print_error(e)
traceback.print_exc(file=sys.stderr)
return
if is_new: async def open_session(self, sslc, do_sleep=True, execute_after_connect=lambda: None):
self.print_error("saving certificate") async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
os.rename(temporary_path, cert_path) ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
return s connect_hook_executed = False
while do_sleep:
def run(self): if not connect_hook_executed:
socket = self.get_socket() connect_hook_executed = True
if socket: execute_after_connect()
self.print_error("connected") await asyncio.wait_for(session.send_request('server.ping'), 5)
self.queue.put((self.server, socket)) await asyncio.sleep(300)
class Interface(util.PrintError):
"""The Interface class handles a socket connected to a single remote
Electrum server. Its exposed API is:
- Member functions close(), fileno(), get_responses(), has_timed_out(),
ping_required(), queue_request(), send_requests()
- Member variable server.
"""
def __init__(self, server, socket):
self.server = server
self.host, _, _ = server.rsplit(':', 2)
self.socket = socket
self.pipe = util.SocketPipe(socket)
self.pipe.set_timeout(0.0) # Don't wait for data
# Dump network messages. Set at runtime from the console.
self.debug = False
self.unsent_requests = []
self.unanswered_requests = {}
self.last_send = time.time()
self.closed_remotely = False
def diagnostic_name(self):
return self.host
def fileno(self):
# Needed for select
return self.socket.fileno()
def close(self):
if not self.closed_remotely:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
def queue_request(self, *args): # method, params, _id
'''Queue a request, later to be send with send_requests when the
socket is available for writing.
'''
self.request_time = time.time()
self.unsent_requests.append(args)
def num_requests(self):
'''Keep unanswered requests below 100'''
n = 100 - len(self.unanswered_requests)
return min(n, len(self.unsent_requests))
def send_requests(self):
'''Sends queued requests. Returns False on failure.'''
self.last_send = time.time()
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
wire_requests = self.unsent_requests[0:n]
try:
self.pipe.send_all([make_dict(*r) for r in wire_requests])
except BaseException as e:
self.print_error("pipe send error:", e)
return False
self.unsent_requests = self.unsent_requests[n:]
for request in wire_requests:
if self.debug:
self.print_error("-->", request)
self.unanswered_requests[request[2]] = request
return True
def ping_required(self):
'''Returns True if a ping should be sent.'''
return time.time() - self.last_send > 300
def has_timed_out(self): def has_timed_out(self):
'''Returns True if the interface has timed out.''' return self.fut.done()
if (self.unanswered_requests and time.time() - self.request_time > 10
and self.pipe.idle_time() > 10):
self.print_error("timeout", len(self.unanswered_requests))
return True
return False def queue_request(self, method, params, msg_id):
pass
def get_responses(self):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
subscriptions, so request is None and there is no 'id' member.
Otherwise it is a response, which has an 'id' member and a
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
responses = []
while True:
try:
response = self.pipe.get()
except util.timeout:
break
if not type(response) is dict:
responses.append((None, None))
if response is None:
self.closed_remotely = True
self.print_error("connection closed remotely")
break
if self.debug:
self.print_error("<--", response)
wire_id = response.get('id', None)
if wire_id is None: # Notification
responses.append((None, response))
else:
request = self.unanswered_requests.pop(wire_id, None)
if request:
responses.append((request, response))
else:
self.print_error("unknown wire ID", wire_id)
responses.append((None, None)) # Signal
break
return responses
def close(self):
self.fut.cancel()
def check_cert(host, cert): def check_cert(host, cert):
try: try:

View file

@ -47,38 +47,14 @@ from . import blockchain
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
from .i18n import _ from .i18n import _
from .blockchain import InvalidHeader from .blockchain import InvalidHeader
from .interface import Interface
import aiorpcx, asyncio, ssl import asyncio
import concurrent.futures import concurrent.futures
NODES_RETRY_INTERVAL = 60 NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10 SERVER_RETRY_INTERVAL = 10
class Interface(PrintError):
@util.aiosafe
async def run(self):
self.host, self.port, self.protocol = self.server.split(':')
sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) if self.protocol == 's' else None
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
while True:
print("sleeping")
await asyncio.sleep(1)
def __init__(self, server):
self.exception = None
self.server = server
self.fut = asyncio.get_event_loop().create_task(self.run())
def has_timed_out(self):
return self.fut.done()
def queue_request(self, method, params, msg_id):
pass
def close(self):
self.fut.cancel()
def parse_servers(result): def parse_servers(result):
""" parse servers list into dict format""" """ parse servers list into dict format"""
@ -539,7 +515,7 @@ class Network(PrintError):
self.close_interface(self.interface) self.close_interface(self.interface)
assert self.interface is None assert self.interface is None
assert not self.interfaces assert not self.interfaces
self.connecting = set() self.connecting.clear()
# Get a new queue - no old pending connections thanks! # Get a new queue - no old pending connections thanks!
self.socket_queue = queue.Queue() self.socket_queue = queue.Queue()
@ -810,7 +786,7 @@ class Network(PrintError):
def new_interface(self, server): def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server) interface = Interface(server, self.config.path, self.connecting)
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
interface.tip = 0 interface.tip = 0
@ -1368,9 +1344,12 @@ class Network(PrintError):
for k, i in self.interfaces.items(): for k, i in self.interfaces.items():
if i.has_timed_out(): if i.has_timed_out():
remove.append(k) remove.append(k)
changed = False
for k in remove: for k in remove:
self.connection_down(k) self.connection_down(k)
changed = True
for i in range(self.num_server - len(self.interfaces)): for i in range(self.num_server - len(self.interfaces)):
self.start_random_interface() self.start_random_interface()
self.notify('updated') changed = True
if changed: self.notify('updated')
await asyncio.sleep(1) await asyncio.sleep(1)