network: tighten checks of server responses for type/sanity

This commit is contained in:
SomberNight 2020-10-16 19:30:42 +02:00
parent c70484455c
commit c5da22a9dd
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
5 changed files with 174 additions and 54 deletions

View file

@ -29,7 +29,7 @@ import sys
import traceback import traceback
import asyncio import asyncio
import socket import socket
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any, Sequence
from collections import defaultdict from collections import defaultdict
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
import itertools import itertools
@ -46,13 +46,14 @@ import certifi
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy, from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str, is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
is_real_number) is_int_or_float, is_non_negative_int_or_float)
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
from . import version from . import version
from . import blockchain from . import blockchain
from .blockchain import Blockchain, HEADER_SIZE from .blockchain import Blockchain, HEADER_SIZE
from . import bitcoin
from . import constants from . import constants
from .i18n import _ from .i18n import _
from .logging import Logger from .logging import Logger
@ -96,9 +97,14 @@ def assert_integer(val: Any) -> None:
raise RequestCorrupted(f'{val!r} should be an integer') raise RequestCorrupted(f'{val!r} should be an integer')
def assert_real_number(val: Any, *, as_str: bool = False) -> None: def assert_int_or_float(val: Any) -> None:
if not is_real_number(val, as_str=as_str): if not is_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be a number') raise RequestCorrupted(f'{val!r} should be int or float')
def assert_non_negative_int_or_float(val: Any) -> None:
if not is_non_negative_int_or_float(val):
raise RequestCorrupted(f'{val!r} should be a non-negative int or float')
def assert_hash256_str(val: Any) -> None: def assert_hash256_str(val: Any) -> None:
@ -656,14 +662,13 @@ class Interface(Logger):
async def request_fee_estimates(self): async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS from .simple_config import FEE_ETA_TARGETS
from .bitcoin import COIN
while True: while True:
async with TaskGroup() as group: async with TaskGroup() as group:
fee_tasks = [] fee_tasks = []
for i in FEE_ETA_TARGETS: for i in FEE_ETA_TARGETS:
fee_tasks.append((i, await group.spawn(self.session.send_request('blockchain.estimatefee', [i])))) fee_tasks.append((i, await group.spawn(self.get_estimatefee(i))))
for nblock_target, task in fee_tasks: for nblock_target, task in fee_tasks:
fee = int(task.result() * COIN) fee = task.result()
if fee < 0: continue if fee < 0: continue
self.fee_estimates_eta[nblock_target] = fee self.fee_estimates_eta[nblock_target] = fee
self.network.update_fee_estimates() self.network.update_fee_estimates()
@ -983,6 +988,61 @@ class Interface(Logger):
assert_hash256_str(res) assert_hash256_str(res)
return res return res
async def get_fee_histogram(self) -> Sequence[Tuple[Union[float, int], int]]:
# do request
res = await self.session.send_request('mempool.get_fee_histogram')
# check response
assert_list_or_tuple(res)
for fee, s in res:
assert_non_negative_int_or_float(fee)
assert_non_negative_integer(s)
return res
async def get_server_banner(self) -> str:
# do request
res = await self.session.send_request('server.banner')
# check response
if not isinstance(res, str):
raise RequestCorrupted(f'{res!r} should be a str')
return res
async def get_donation_address(self) -> str:
# do request
res = await self.session.send_request('server.donation_address')
# check response
if not res: # ignore empty string
return ''
if not bitcoin.is_address(res):
# note: do not hard-fail -- allow server to use future-type
# bitcoin address we do not recognize
self.logger.info(f"invalid donation address from server: {repr(res)}")
res = ''
return res
async def get_relay_fee(self) -> int:
"""Returns the min relay feerate in sat/kbyte."""
# do request
res = await self.session.send_request('blockchain.relayfee')
# check response
assert_non_negative_int_or_float(res)
relayfee = int(res * bitcoin.COIN)
relayfee = max(0, relayfee)
return relayfee
async def get_estimatefee(self, num_blocks: int) -> int:
"""Returns a feerate estimate for getting confirmed within
num_blocks blocks, in sat/kbyte.
"""
if not is_non_negative_integer(num_blocks):
raise Exception(f"{repr(num_blocks)} is not a num_blocks")
# do request
res = await self.session.send_request('blockchain.estimatefee', [num_blocks])
# check response
if res != -1:
assert_non_negative_int_or_float(res)
res = int(res * bitcoin.COIN)
return res
def _assert_header_does_not_check_against_any_chain(header: dict) -> None: def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header) chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)

View file

@ -418,20 +418,15 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
def is_connecting(self): def is_connecting(self):
return self.connection_status == 'connecting' return self.connection_status == 'connecting'
async def _request_server_info(self, interface): async def _request_server_info(self, interface: 'Interface'):
await interface.ready await interface.ready
session = interface.session session = interface.session
async def get_banner(): async def get_banner():
self.banner = await session.send_request('server.banner') self.banner = await interface.get_server_banner()
self.notify('banner') self.notify('banner')
async def get_donation_address(): async def get_donation_address():
addr = await session.send_request('server.donation_address') self.donation_address = await interface.get_donation_address()
if not bitcoin.is_address(addr):
if addr: # ignore empty string
self.logger.info(f"invalid donation address from server: {repr(addr)}")
addr = ''
self.donation_address = addr
async def get_server_peers(): async def get_server_peers():
server_peers = await session.send_request('server.peers.subscribe') server_peers = await session.send_request('server.peers.subscribe')
random.shuffle(server_peers) random.shuffle(server_peers)
@ -441,12 +436,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
self.server_peers = parse_servers(server_peers) self.server_peers = parse_servers(server_peers)
self.notify('servers') self.notify('servers')
async def get_relay_fee(): async def get_relay_fee():
relayfee = await session.send_request('blockchain.relayfee') self.relay_fee = await interface.get_relay_fee()
if relayfee is None:
self.relay_fee = None
else:
relayfee = int(relayfee * COIN)
self.relay_fee = max(0, relayfee)
async with TaskGroup() as group: async with TaskGroup() as group:
await group.spawn(get_banner) await group.spawn(get_banner)
@ -456,9 +446,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
await group.spawn(self._request_fee_estimates(interface)) await group.spawn(self._request_fee_estimates(interface))
async def _request_fee_estimates(self, interface): async def _request_fee_estimates(self, interface):
session = interface.session
self.config.requested_fee_estimates() self.config.requested_fee_estimates()
histogram = await session.send_request('mempool.get_fee_histogram') histogram = await interface.get_fee_histogram()
self.config.mempool_fees = histogram self.config.mempool_fees = histogram
self.logger.info(f'fee_histogram {histogram}') self.logger.info(f'fee_histogram {histogram}')
self.notify('fee_histogram') self.notify('fee_histogram')

View file

@ -5,7 +5,7 @@ import os
import stat import stat
import ssl import ssl
from decimal import Decimal from decimal import Decimal
from typing import Union, Optional, Dict from typing import Union, Optional, Dict, Sequence, Tuple
from numbers import Real from numbers import Real
from copy import deepcopy from copy import deepcopy
@ -65,7 +65,7 @@ class SimpleConfig(Logger):
# a thread-safe way. # a thread-safe way.
self.lock = threading.RLock() self.lock = threading.RLock()
self.mempool_fees = {} # type: Dict[Union[float, int], int] self.mempool_fees = [] # type: Sequence[Tuple[Union[float, int], int]]
self.fee_estimates = {} self.fee_estimates = {}
self.fee_estimates_last_updated = {} self.fee_estimates_last_updated = {}
self.last_time_fee_estimates_requested = 0 # zero ensures immediate fees self.last_time_fee_estimates_requested = 0 # zero ensures immediate fees

View file

@ -2,7 +2,9 @@ from decimal import Decimal
from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI, from electrum.util import (format_satoshis, format_fee_satoshis, parse_URI,
is_hash256_str, chunks, is_ip_address, list_enabled_bits, is_hash256_str, chunks, is_ip_address, list_enabled_bits,
format_satoshis_plain, is_private_netaddress) format_satoshis_plain, is_private_netaddress, is_hex_str,
is_integer, is_non_negative_integer, is_int_or_float,
is_non_negative_int_or_float)
from . import ElectrumTestCase from . import ElectrumTestCase
@ -121,6 +123,89 @@ class TestUtil(ElectrumTestCase):
self.assertFalse(is_hash256_str(None)) self.assertFalse(is_hash256_str(None))
self.assertFalse(is_hash256_str(7)) self.assertFalse(is_hash256_str(7))
def test_is_hex_str(self):
self.assertTrue(is_hex_str('09a4'))
self.assertTrue(is_hex_str('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertTrue(is_hex_str('00' * 33))
self.assertFalse(is_hex_str('000'))
self.assertFalse(is_hex_str('qweqwe'))
self.assertFalse(is_hex_str(None))
self.assertFalse(is_hex_str(7))
def test_is_integer(self):
self.assertTrue(is_integer(7))
self.assertTrue(is_integer(0))
self.assertTrue(is_integer(-1))
self.assertTrue(is_integer(-7))
self.assertFalse(is_integer(Decimal("2.0")))
self.assertFalse(is_integer(Decimal(2.0)))
self.assertFalse(is_integer(Decimal(2)))
self.assertFalse(is_integer(0.72))
self.assertFalse(is_integer(2.0))
self.assertFalse(is_integer(-2.0))
self.assertFalse(is_integer('09a4'))
self.assertFalse(is_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_integer('000'))
self.assertFalse(is_integer('qweqwe'))
self.assertFalse(is_integer(None))
def test_is_non_negative_integer(self):
self.assertTrue(is_non_negative_integer(7))
self.assertTrue(is_non_negative_integer(0))
self.assertFalse(is_non_negative_integer(Decimal("2.0")))
self.assertFalse(is_non_negative_integer(Decimal(2.0)))
self.assertFalse(is_non_negative_integer(Decimal(2)))
self.assertFalse(is_non_negative_integer(0.72))
self.assertFalse(is_non_negative_integer(2.0))
self.assertFalse(is_non_negative_integer(-2.0))
self.assertFalse(is_non_negative_integer(-1))
self.assertFalse(is_non_negative_integer(-7))
self.assertFalse(is_non_negative_integer('09a4'))
self.assertFalse(is_non_negative_integer('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_non_negative_integer('000'))
self.assertFalse(is_non_negative_integer('qweqwe'))
self.assertFalse(is_non_negative_integer(None))
def test_is_int_or_float(self):
self.assertTrue(is_int_or_float(7))
self.assertTrue(is_int_or_float(0))
self.assertTrue(is_int_or_float(-1))
self.assertTrue(is_int_or_float(-7))
self.assertTrue(is_int_or_float(0.72))
self.assertTrue(is_int_or_float(2.0))
self.assertTrue(is_int_or_float(-2.0))
self.assertFalse(is_int_or_float(Decimal("2.0")))
self.assertFalse(is_int_or_float(Decimal(2.0)))
self.assertFalse(is_int_or_float(Decimal(2)))
self.assertFalse(is_int_or_float('09a4'))
self.assertFalse(is_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_int_or_float('000'))
self.assertFalse(is_int_or_float('qweqwe'))
self.assertFalse(is_int_or_float(None))
def test_is_non_negative_int_or_float(self):
self.assertTrue(is_non_negative_int_or_float(7))
self.assertTrue(is_non_negative_int_or_float(0))
self.assertTrue(is_non_negative_int_or_float(0.0))
self.assertTrue(is_non_negative_int_or_float(0.72))
self.assertTrue(is_non_negative_int_or_float(2.0))
self.assertFalse(is_non_negative_int_or_float(-1))
self.assertFalse(is_non_negative_int_or_float(-7))
self.assertFalse(is_non_negative_int_or_float(-2.0))
self.assertFalse(is_non_negative_int_or_float(Decimal("2.0")))
self.assertFalse(is_non_negative_int_or_float(Decimal(2.0)))
self.assertFalse(is_non_negative_int_or_float(Decimal(2)))
self.assertFalse(is_non_negative_int_or_float('09a4'))
self.assertFalse(is_non_negative_int_or_float('2A5C3F4062E4F2FCCE7A1C7B4310CB647B327409F580F4ED72CB8FC0B1804DFA'))
self.assertFalse(is_non_negative_int_or_float('000'))
self.assertFalse(is_non_negative_int_or_float('qweqwe'))
self.assertFalse(is_non_negative_int_or_float(None))
def test_chunks(self): def test_chunks(self):
self.assertEqual([[1, 2], [3, 4], [5]], self.assertEqual([[1, 2], [3, 4], [5]],
list(chunks([1, 2, 3, 4, 5], 2))) list(chunks([1, 2, 3, 4, 5], 2)))

View file

@ -588,38 +588,24 @@ def is_hex_str(text: Any) -> bool:
return True return True
def is_non_negative_integer(val) -> bool: def is_integer(val: Any) -> bool:
try: return isinstance(val, int)
val = int(val)
if val >= 0:
return True def is_non_negative_integer(val: Any) -> bool:
except: if is_integer(val):
pass return val >= 0
return False return False
def is_integer(val) -> bool: def is_int_or_float(val: Any) -> bool:
try: return isinstance(val, (int, float))
int(val)
except:
return False
else:
return True
def is_real_number(val, *, as_str: bool = False) -> bool: def is_non_negative_int_or_float(val: Any) -> bool:
if as_str: # only accept str if is_int_or_float(val):
if not isinstance(val, str): return val >= 0
return False return False
else: # only accept int/float/etc.
if isinstance(val, str):
return False
try:
Decimal(val)
except:
return False
else:
return True
def chunks(items, size: int): def chunks(items, size: int):