aiorpcx: pin certificates

This commit is contained in:
Janus 2018-08-16 18:16:25 +02:00 committed by SomberNight
parent 8080a713b2
commit 89a01a6463
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
3 changed files with 109 additions and 342 deletions

View file

@ -4,7 +4,7 @@ from .wallet import Wallet
from .storage import WalletStorage from .storage import WalletStorage
from .coinchooser import COIN_CHOOSERS from .coinchooser import COIN_CHOOSERS
from .network import Network, pick_random_server from .network import Network, pick_random_server
from .interface import Connection, Interface from .interface import Interface
from .simple_config import SimpleConfig, get_config, set_config from .simple_config import SimpleConfig, get_config, set_config
from . import bitcoin from . import bitcoin
from . import transaction from . import transaction

View file

@ -28,343 +28,131 @@ import socket
import ssl import ssl
import sys import sys
import threading import threading
import time
import traceback import traceback
import aiorpcx
import asyncio
import requests import requests
from .util import print_error from .util import PrintError
ca_path = requests.certs.where() ca_path = requests.certs.where()
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
class Interface(PrintError):
def Connection(server, queue, config_path): def __init__(self, server, config_path, connecting):
"""Makes asynchronous connections to a remote Electrum server. self.connecting = connecting
Returns the running thread that is making the connection. self.server = server
self.host, self.port, self.protocol = self.server.split(':')
Once the thread has connected, it finishes, placing a tuple on the
queue of the form (server, socket), where socket is None if
connection failed.
"""
host, port, protocol = server.rsplit(':', 2)
if not protocol in 'st':
raise Exception('Unknown protocol: %s' % protocol)
c = TcpConnection(server, queue, config_path)
c.start()
return c
class TcpConnection(threading.Thread, util.PrintError):
verbosity_filter = 'i'
def __init__(self, server, queue, config_path):
threading.Thread.__init__(self)
self.config_path = config_path self.config_path = config_path
self.queue = queue self.cert_path = os.path.join(self.config_path, 'certs', self.host)
self.server = server self.fut = asyncio.get_event_loop().create_task(self.run())
self.host, self.port, self.protocol = self.server.rsplit(':', 2)
self.host = str(self.host)
self.port = int(self.port)
self.use_ssl = (self.protocol == 's')
self.daemon = True
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
def check_host_name(self, peercert, name): async def is_server_ca_signed(self, sslc):
"""Simple certificate/host name checker. Returns True if the
certificate matches, False otherwise. Does not support
wildcards."""
# Check that the peer has supplied a certificate.
# None/{} is not acceptable.
if not peercert:
return False
if 'subjectAltName' in peercert:
for typ, val in peercert["subjectAltName"]:
if typ == "DNS" and val == name:
return True
else:
# Only check the subject DN if there is no subject alternative
# name.
cn = None
for attr, val in peercert["subject"]:
# Use most-specific (last) commonName attribute.
if attr == "commonName":
cn = val
if cn is not None:
return cn == name
return False
def get_simple_socket(self):
try: try:
l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) await self.open_session(sslc, do_sleep=False)
except socket.gaierror: except ssl.SSLError as e:
self.print_error("cannot resolve hostname") assert e.reason == 'CERTIFICATE_VERIFY_FAILED'
return
e = None
for res in l:
try:
s = socket.socket(res[0], socket.SOCK_STREAM)
s.settimeout(10)
s.connect(res[4])
s.settimeout(2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
except BaseException as _e:
e = _e
continue
else:
self.print_error("failed to connect", str(e))
@staticmethod
def get_ssl_context(cert_reqs, ca_certs):
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs)
context.check_hostname = False
context.verify_mode = cert_reqs
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
return context
def get_socket(self):
if self.use_ssl:
cert_path = os.path.join(self.config_path, 'certs', self.host)
if not os.path.exists(cert_path):
is_new = True
s = self.get_simple_socket()
if s is None:
return
# try with CA first
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path)
s = context.wrap_socket(s, do_handshake_on_connect=True)
except ssl.SSLError as e:
self.print_error(e)
except:
return
else:
try:
peer_cert = s.getpeercert()
except OSError:
return
if self.check_host_name(peer_cert, self.host):
self.print_error("SSL certificate signed by CA")
return s
# get server certificate.
# Do not use ssl.get_server_certificate because it does not work with proxy
s = self.get_simple_socket()
if s is None:
return
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None)
s = context.wrap_socket(s)
except ssl.SSLError as e:
self.print_error("SSL error retrieving SSL certificate:", e)
return
except:
return
try:
dercert = s.getpeercert(True)
except OSError:
return
s.close()
cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
temporary_path = cert_path + '.temp'
util.assert_datadir_available(self.config_path)
with open(temporary_path, "w", encoding='utf-8') as f:
f.write(cert)
f.flush()
os.fsync(f.fileno())
else:
is_new = False
s = self.get_simple_socket()
if s is None:
return
if self.use_ssl:
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED,
ca_certs=(temporary_path if is_new else cert_path))
s = context.wrap_socket(s, do_handshake_on_connect=True)
except socket.timeout:
self.print_error('timeout')
return
except ssl.SSLError as e:
self.print_error("SSL error:", e)
if e.errno != 1:
return
if is_new:
rej = cert_path + '.rej'
if os.path.exists(rej):
os.unlink(rej)
os.rename(temporary_path, rej)
else:
util.assert_datadir_available(self.config_path)
with open(cert_path, encoding='utf-8') as f:
cert = f.read()
try:
b = pem.dePem(cert, 'CERTIFICATE')
x = x509.X509(b)
except:
traceback.print_exc(file=sys.stderr)
self.print_error("wrong certificate")
return
try:
x.check_date()
except:
self.print_error("certificate has expired:", cert_path)
os.unlink(cert_path)
return
self.print_error("wrong certificate")
if e.errno == 104:
return
return
except BaseException as e:
self.print_error(e)
traceback.print_exc(file=sys.stderr)
return
if is_new:
self.print_error("saving certificate")
os.rename(temporary_path, cert_path)
return s
def run(self):
socket = self.get_socket()
if socket:
self.print_error("connected")
self.queue.put((self.server, socket))
class Interface(util.PrintError):
"""The Interface class handles a socket connected to a single remote
Electrum server. Its exposed API is:
- Member functions close(), fileno(), get_responses(), has_timed_out(),
ping_required(), queue_request(), send_requests()
- Member variable server.
"""
def __init__(self, server, socket):
self.server = server
self.host, _, _ = server.rsplit(':', 2)
self.socket = socket
self.pipe = util.SocketPipe(socket)
self.pipe.set_timeout(0.0) # Don't wait for data
# Dump network messages. Set at runtime from the console.
self.debug = False
self.unsent_requests = []
self.unanswered_requests = {}
self.last_send = time.time()
self.closed_remotely = False
def diagnostic_name(self):
return self.host
def fileno(self):
# Needed for select
return self.socket.fileno()
def close(self):
if not self.closed_remotely:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
def queue_request(self, *args): # method, params, _id
'''Queue a request, later to be send with send_requests when the
socket is available for writing.
'''
self.request_time = time.time()
self.unsent_requests.append(args)
def num_requests(self):
'''Keep unanswered requests below 100'''
n = 100 - len(self.unanswered_requests)
return min(n, len(self.unsent_requests))
def send_requests(self):
'''Sends queued requests. Returns False on failure.'''
self.last_send = time.time()
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
wire_requests = self.unsent_requests[0:n]
try:
self.pipe.send_all([make_dict(*r) for r in wire_requests])
except BaseException as e:
self.print_error("pipe send error:", e)
return False return False
self.unsent_requests = self.unsent_requests[n:]
for request in wire_requests:
if self.debug:
self.print_error("-->", request)
self.unanswered_requests[request[2]] = request
return True return True
def ping_required(self): @util.aiosafe
'''Returns True if a ping should be sent.''' async def run(self):
return time.time() - self.last_send > 300 if self.protocol != 's':
await self.open_session(None, execute_after_connect=lambda: self.connecting.remove(self.server))
return
ca_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
exists = os.path.exists(self.cert_path)
if exists:
with open(self.cert_path, 'r') as f:
contents = f.read()
if contents != '': # if not CA signed
try:
b = pem.dePem(contents, 'CERTIFICATE')
except SyntaxError:
exists = False
else:
x = x509.X509(b)
try:
x.check_date()
except x509.CertificateError:
self.print_error("certificate has expired:", self.cert_path)
os.unlink(self.cert_path)
exists = False
if not exists:
ca_signed = await self.is_server_ca_signed(ca_sslc)
if ca_signed:
with open(self.cert_path, 'w') as f:
# empty file means this is CA signed, not self-signed
f.write('')
else:
await self.save_certificate()
siz = os.stat(self.cert_path).st_size
if siz == 0: # if CA signed
sslc = ca_sslc
else:
sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
sslc.check_hostname = 0
await self.open_session(sslc, execute_after_connect=lambda: self.connecting.remove(self.server))
async def save_certificate(self):
if not os.path.exists(self.cert_path):
# we may need to retry this a few times, in case the handshake hasn't completed
for _ in range(10):
dercert = await self.get_certificate()
if dercert:
self.print_error("succeeded in getting cert")
with open(self.cert_path, 'w') as f:
cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
f.write(cert)
# even though close flushes we can't fsync when closed.
# and we must flush before fsyncing, cause flush flushes to OS buffer
# fsync writes to OS buffer to disk
f.flush()
os.fsync(f.fileno())
break
await asyncio.sleep(1)
assert False, "could not get certificate"
async def get_certificate(self):
sslc = ssl.SSLContext()
try:
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
return session.transport._ssl_protocol._sslpipe._sslobj.getpeercert(True)
except ValueError:
return None
async def open_session(self, sslc, do_sleep=True, execute_after_connect=lambda: None):
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
connect_hook_executed = False
while do_sleep:
if not connect_hook_executed:
connect_hook_executed = True
execute_after_connect()
await asyncio.wait_for(session.send_request('server.ping'), 5)
await asyncio.sleep(300)
def has_timed_out(self): def has_timed_out(self):
'''Returns True if the interface has timed out.''' return self.fut.done()
if (self.unanswered_requests and time.time() - self.request_time > 10
and self.pipe.idle_time() > 10):
self.print_error("timeout", len(self.unanswered_requests))
return True
return False def queue_request(self, method, params, msg_id):
pass
def get_responses(self):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
subscriptions, so request is None and there is no 'id' member.
Otherwise it is a response, which has an 'id' member and a
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
responses = []
while True:
try:
response = self.pipe.get()
except util.timeout:
break
if not type(response) is dict:
responses.append((None, None))
if response is None:
self.closed_remotely = True
self.print_error("connection closed remotely")
break
if self.debug:
self.print_error("<--", response)
wire_id = response.get('id', None)
if wire_id is None: # Notification
responses.append((None, response))
else:
request = self.unanswered_requests.pop(wire_id, None)
if request:
responses.append((request, response))
else:
self.print_error("unknown wire ID", wire_id)
responses.append((None, None)) # Signal
break
return responses
def close(self):
self.fut.cancel()
def check_cert(host, cert): def check_cert(host, cert):
try: try:

View file

@ -47,38 +47,14 @@ from . import blockchain
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
from .i18n import _ from .i18n import _
from .blockchain import InvalidHeader from .blockchain import InvalidHeader
from .interface import Interface
import aiorpcx, asyncio, ssl import asyncio
import concurrent.futures import concurrent.futures
NODES_RETRY_INTERVAL = 60 NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10 SERVER_RETRY_INTERVAL = 10
class Interface(PrintError):
@util.aiosafe
async def run(self):
self.host, self.port, self.protocol = self.server.split(':')
sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) if self.protocol == 's' else None
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
while True:
print("sleeping")
await asyncio.sleep(1)
def __init__(self, server):
self.exception = None
self.server = server
self.fut = asyncio.get_event_loop().create_task(self.run())
def has_timed_out(self):
return self.fut.done()
def queue_request(self, method, params, msg_id):
pass
def close(self):
self.fut.cancel()
def parse_servers(result): def parse_servers(result):
""" parse servers list into dict format""" """ parse servers list into dict format"""
@ -539,7 +515,7 @@ class Network(PrintError):
self.close_interface(self.interface) self.close_interface(self.interface)
assert self.interface is None assert self.interface is None
assert not self.interfaces assert not self.interfaces
self.connecting = set() self.connecting.clear()
# Get a new queue - no old pending connections thanks! # Get a new queue - no old pending connections thanks!
self.socket_queue = queue.Queue() self.socket_queue = queue.Queue()
@ -810,7 +786,7 @@ class Network(PrintError):
def new_interface(self, server): def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server) interface = Interface(server, self.config.path, self.connecting)
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
interface.tip = 0 interface.tip = 0
@ -1368,9 +1344,12 @@ class Network(PrintError):
for k, i in self.interfaces.items(): for k, i in self.interfaces.items():
if i.has_timed_out(): if i.has_timed_out():
remove.append(k) remove.append(k)
changed = False
for k in remove: for k in remove:
self.connection_down(k) self.connection_down(k)
changed = True
for i in range(self.num_server - len(self.interfaces)): for i in range(self.num_server - len(self.interfaces)):
self.start_random_interface() self.start_random_interface()
self.notify('updated') changed = True
if changed: self.notify('updated')
await asyncio.sleep(1) await asyncio.sleep(1)