mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 01:11:35 +00:00
interface: check server response for some methods
some basic sanity checks Previously if the server sent back a malformed response, it could partially corrupt a wallet file. (as sometimes the response would get persisted, and issues would only arise later when the values were used)
This commit is contained in:
parent
3393ff757e
commit
d19ff43266
4 changed files with 200 additions and 43 deletions
|
@ -29,7 +29,7 @@ import sys
|
|||
import traceback
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple, Any
|
||||
from collections import defaultdict
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address, IPv4Address
|
||||
import itertools
|
||||
|
@ -44,16 +44,19 @@ from aiorpcx.jsonrpc import JSONRPC, CodeMessageError
|
|||
from aiorpcx.rawsocket import RSClient
|
||||
import certifi
|
||||
|
||||
from .util import ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy
|
||||
from .util import (ignore_exceptions, log_exceptions, bfh, SilentTaskGroup, MySocksProxy,
|
||||
is_integer, is_non_negative_integer, is_hash256_str, is_hex_str,
|
||||
is_real_number)
|
||||
from . import util
|
||||
from . import x509
|
||||
from . import pem
|
||||
from . import version
|
||||
from . import blockchain
|
||||
from .blockchain import Blockchain
|
||||
from .blockchain import Blockchain, HEADER_SIZE
|
||||
from . import constants
|
||||
from .i18n import _
|
||||
from .logging import Logger
|
||||
from .transaction import Transaction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
|
@ -82,6 +85,45 @@ class NetworkTimeout:
|
|||
RELAXED = 20
|
||||
MOST_RELAXED = 60
|
||||
|
||||
|
||||
def assert_non_negative_integer(val: Any) -> None:
|
||||
if not is_non_negative_integer(val):
|
||||
raise RequestCorrupted(f'{val!r} should be a non-negative integer')
|
||||
|
||||
|
||||
def assert_integer(val: Any) -> None:
|
||||
if not is_integer(val):
|
||||
raise RequestCorrupted(f'{val!r} should be an integer')
|
||||
|
||||
|
||||
def assert_real_number(val: Any, *, as_str: bool = False) -> None:
|
||||
if not is_real_number(val, as_str=as_str):
|
||||
raise RequestCorrupted(f'{val!r} should be a number')
|
||||
|
||||
|
||||
def assert_hash256_str(val: Any) -> None:
|
||||
if not is_hash256_str(val):
|
||||
raise RequestCorrupted(f'{val!r} should be a hash256 str')
|
||||
|
||||
|
||||
def assert_hex_str(val: Any) -> None:
|
||||
if not is_hex_str(val):
|
||||
raise RequestCorrupted(f'{val!r} should be a hex str')
|
||||
|
||||
|
||||
def assert_dict_contains_field(d: Any, *, field_name: str) -> Any:
|
||||
if not isinstance(d, dict):
|
||||
raise RequestCorrupted(f'{d!r} should be a dict')
|
||||
if field_name not in d:
|
||||
raise RequestCorrupted(f'required field {field_name!r} missing from dict')
|
||||
return d[field_name]
|
||||
|
||||
|
||||
def assert_list_or_tuple(val: Any) -> None:
|
||||
if not isinstance(val, (list, tuple)):
|
||||
raise RequestCorrupted(f'{val!r} should be a list or tuple')
|
||||
|
||||
|
||||
class NotificationSession(RPCSession):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -187,7 +229,7 @@ class RequestTimedOut(GracefulDisconnect):
|
|||
return _("Network request timed out.")
|
||||
|
||||
|
||||
class RequestCorrupted(GracefulDisconnect): pass
|
||||
class RequestCorrupted(Exception): pass
|
||||
|
||||
class ErrorParsingSSLCert(Exception): pass
|
||||
class ErrorGettingSSLCertFromServer(Exception): pass
|
||||
|
@ -529,6 +571,8 @@ class Interface(Logger):
|
|||
return blockchain.deserialize_header(bytes.fromhex(res), height)
|
||||
|
||||
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
|
||||
if not is_non_negative_integer(height):
|
||||
raise Exception(f"{repr(height)} is not a block height")
|
||||
index = height // 2016
|
||||
if can_return_early and index in self._requested_chunks:
|
||||
return
|
||||
|
@ -542,6 +586,16 @@ class Interface(Logger):
|
|||
res = await self.session.send_request('blockchain.block.headers', [index * 2016, size])
|
||||
finally:
|
||||
self._requested_chunks.discard(index)
|
||||
assert_dict_contains_field(res, field_name='count')
|
||||
assert_dict_contains_field(res, field_name='hex')
|
||||
assert_dict_contains_field(res, field_name='max')
|
||||
assert_non_negative_integer(res['count'])
|
||||
assert_non_negative_integer(res['max'])
|
||||
assert_hex_str(res['hex'])
|
||||
if len(res['hex']) != HEADER_SIZE * 2 * res['count']:
|
||||
raise RequestCorrupted('inconsistent chunk hex and count')
|
||||
if res['count'] != size:
|
||||
raise RequestCorrupted(f"expected {size} headers but only got {res['count']}")
|
||||
conn = self.blockchain.connect_chunk(index, res['hex'])
|
||||
if not conn:
|
||||
return conn, 0
|
||||
|
@ -819,6 +873,108 @@ class Interface(Logger):
|
|||
self._ipaddr_bucket = do_bucket()
|
||||
return self._ipaddr_bucket
|
||||
|
||||
async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
|
||||
if not is_hash256_str(tx_hash):
|
||||
raise Exception(f"{repr(tx_hash)} is not a txid")
|
||||
if not is_non_negative_integer(tx_height):
|
||||
raise Exception(f"{repr(tx_height)} is not a block height")
|
||||
# do request
|
||||
res = await self.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
|
||||
# check response
|
||||
block_height = assert_dict_contains_field(res, field_name='block_height')
|
||||
merkle = assert_dict_contains_field(res, field_name='merkle')
|
||||
pos = assert_dict_contains_field(res, field_name='pos')
|
||||
# note: tx_height was just a hint to the server, don't enforce the response to match it
|
||||
assert_non_negative_integer(block_height)
|
||||
assert_non_negative_integer(pos)
|
||||
assert_list_or_tuple(merkle)
|
||||
for item in merkle:
|
||||
assert_hash256_str(item)
|
||||
return res
|
||||
|
||||
async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
|
||||
if not is_hash256_str(tx_hash):
|
||||
raise Exception(f"{repr(tx_hash)} is not a txid")
|
||||
raw = await self.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
|
||||
# validate response
|
||||
tx = Transaction(raw)
|
||||
try:
|
||||
tx.deserialize() # see if raises
|
||||
except Exception as e:
|
||||
raise RequestCorrupted(f"cannot deserialize received transaction (txid {tx_hash})") from e
|
||||
if tx.txid() != tx_hash:
|
||||
raise RequestCorrupted(f"received tx does not match expected txid {tx_hash} (got {tx.txid()})")
|
||||
return raw
|
||||
|
||||
async def get_history_for_scripthash(self, sh: str) -> List[dict]:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
# do request
|
||||
res = await self.session.send_request('blockchain.scripthash.get_history', [sh])
|
||||
# check response
|
||||
assert_list_or_tuple(res)
|
||||
for tx_item in res:
|
||||
assert_dict_contains_field(tx_item, field_name='height')
|
||||
assert_dict_contains_field(tx_item, field_name='tx_hash')
|
||||
assert_integer(tx_item['height'])
|
||||
assert_hash256_str(tx_item['tx_hash'])
|
||||
if tx_item['height'] in (-1, 0):
|
||||
assert_dict_contains_field(tx_item, field_name='fee')
|
||||
assert_non_negative_integer(tx_item['fee'])
|
||||
return res
|
||||
|
||||
async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
# do request
|
||||
res = await self.session.send_request('blockchain.scripthash.listunspent', [sh])
|
||||
# check response
|
||||
assert_list_or_tuple(res)
|
||||
for utxo_item in res:
|
||||
assert_dict_contains_field(utxo_item, field_name='tx_pos')
|
||||
assert_dict_contains_field(utxo_item, field_name='value')
|
||||
assert_dict_contains_field(utxo_item, field_name='tx_hash')
|
||||
assert_dict_contains_field(utxo_item, field_name='height')
|
||||
assert_non_negative_integer(utxo_item['tx_pos'])
|
||||
assert_non_negative_integer(utxo_item['value'])
|
||||
assert_non_negative_integer(utxo_item['height'])
|
||||
assert_hash256_str(utxo_item['tx_hash'])
|
||||
return res
|
||||
|
||||
async def get_balance_for_scripthash(self, sh: str) -> dict:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
# do request
|
||||
res = await self.session.send_request('blockchain.scripthash.get_balance', [sh])
|
||||
# check response
|
||||
assert_dict_contains_field(res, field_name='confirmed')
|
||||
assert_dict_contains_field(res, field_name='unconfirmed')
|
||||
assert_non_negative_integer(res['confirmed'])
|
||||
assert_non_negative_integer(res['unconfirmed'])
|
||||
return res
|
||||
|
||||
async def get_txid_from_txpos(self, tx_height: int, tx_pos: int, merkle: bool):
|
||||
if not is_non_negative_integer(tx_height):
|
||||
raise Exception(f"{repr(tx_height)} is not a block height")
|
||||
if not is_non_negative_integer(tx_pos):
|
||||
raise Exception(f"{repr(tx_pos)} should be non-negative integer")
|
||||
# do request
|
||||
res = await self.session.send_request(
|
||||
'blockchain.transaction.id_from_pos',
|
||||
[tx_height, tx_pos, merkle],
|
||||
)
|
||||
# check response
|
||||
if merkle:
|
||||
assert_dict_contains_field(res, field_name='tx_hash')
|
||||
assert_dict_contains_field(res, field_name='merkle')
|
||||
assert_hash256_str(res['tx_hash'])
|
||||
assert_list_or_tuple(res['merkle'])
|
||||
for node_hash in res['merkle']:
|
||||
assert_hash256_str(node_hash)
|
||||
else:
|
||||
assert_hash256_str(res)
|
||||
return res
|
||||
|
||||
|
||||
def _assert_header_does_not_check_against_any_chain(header: dict) -> None:
|
||||
chain_bad = blockchain.check_header(header) if 'mock' not in header else header['mock']['check'](header)
|
||||
|
|
|
@ -816,7 +816,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||
if success_fut.exception():
|
||||
try:
|
||||
raise success_fut.exception()
|
||||
except (RequestTimedOut, RequestCorrupted):
|
||||
except RequestTimedOut:
|
||||
await iface.close()
|
||||
await iface.got_disconnected
|
||||
continue # try again
|
||||
except RequestCorrupted as e:
|
||||
# TODO ban server?
|
||||
iface.logger.exception(f"RequestCorrupted: {e}")
|
||||
await iface.close()
|
||||
await iface.got_disconnected
|
||||
continue # try again
|
||||
|
@ -836,11 +842,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def get_merkle_for_transaction(self, tx_hash: str, tx_height: int) -> dict:
|
||||
if not is_hash256_str(tx_hash):
|
||||
raise Exception(f"{repr(tx_hash)} is not a txid")
|
||||
if not is_non_negative_integer(tx_height):
|
||||
raise Exception(f"{repr(tx_height)} is not a block height")
|
||||
return await self.interface.session.send_request('blockchain.transaction.get_merkle', [tx_hash, tx_height])
|
||||
return await self.interface.get_merkle_for_transaction(tx_hash=tx_hash, tx_height=tx_height)
|
||||
|
||||
@best_effort_reliable
|
||||
async def broadcast_transaction(self, tx: 'Transaction', *, timeout=None) -> None:
|
||||
|
@ -1012,54 +1014,32 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
|||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def request_chunk(self, height: int, tip=None, *, can_return_early=False):
|
||||
if not is_non_negative_integer(height):
|
||||
raise Exception(f"{repr(height)} is not a block height")
|
||||
return await self.interface.request_chunk(height, tip=tip, can_return_early=can_return_early)
|
||||
|
||||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def get_transaction(self, tx_hash: str, *, timeout=None) -> str:
|
||||
if not is_hash256_str(tx_hash):
|
||||
raise Exception(f"{repr(tx_hash)} is not a txid")
|
||||
iface = self.interface
|
||||
raw = await iface.session.send_request('blockchain.transaction.get', [tx_hash], timeout=timeout)
|
||||
# validate response
|
||||
tx = Transaction(raw)
|
||||
try:
|
||||
tx.deserialize() # see if raises
|
||||
except Exception as e:
|
||||
self.logger.warning(f"cannot deserialize received transaction (txid {tx_hash}). from {str(iface)}")
|
||||
raise RequestCorrupted() from e # TODO ban server?
|
||||
if tx.txid() != tx_hash:
|
||||
self.logger.warning(f"received tx does not match expected txid {tx_hash} (got {tx.txid()}). from {str(iface)}")
|
||||
raise RequestCorrupted() # TODO ban server?
|
||||
return raw
|
||||
return await self.interface.get_transaction(tx_hash=tx_hash, timeout=timeout)
|
||||
|
||||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def get_history_for_scripthash(self, sh: str) -> List[dict]:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
return await self.interface.session.send_request('blockchain.scripthash.get_history', [sh])
|
||||
return await self.interface.get_history_for_scripthash(sh)
|
||||
|
||||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def listunspent_for_scripthash(self, sh: str) -> List[dict]:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
return await self.interface.session.send_request('blockchain.scripthash.listunspent', [sh])
|
||||
return await self.interface.listunspent_for_scripthash(sh)
|
||||
|
||||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def get_balance_for_scripthash(self, sh: str) -> dict:
|
||||
if not is_hash256_str(sh):
|
||||
raise Exception(f"{repr(sh)} is not a scripthash")
|
||||
return await self.interface.session.send_request('blockchain.scripthash.get_balance', [sh])
|
||||
return await self.interface.get_balance_for_scripthash(sh)
|
||||
|
||||
@best_effort_reliable
|
||||
@catch_server_exceptions
|
||||
async def get_txid_from_txpos(self, tx_height, tx_pos, merkle):
|
||||
command = 'blockchain.transaction.id_from_pos'
|
||||
return await self.interface.session.send_request(command, [tx_height, tx_pos, merkle])
|
||||
return await self.interface.get_txid_from_txpos(tx_height, tx_pos, merkle)
|
||||
|
||||
def blockchain(self) -> Blockchain:
|
||||
interface = self.interface
|
||||
|
|
|
@ -168,15 +168,12 @@ class Synchronizer(SynchronizerBase):
|
|||
self.requested_histories.add((addr, status))
|
||||
h = address_to_scripthash(addr)
|
||||
self._requests_sent += 1
|
||||
result = await self.network.get_history_for_scripthash(h)
|
||||
result = await self.interface.get_history_for_scripthash(h)
|
||||
self._requests_answered += 1
|
||||
self.logger.info(f"receiving history {addr} {len(result)}")
|
||||
hashes = set(map(lambda item: item['tx_hash'], result))
|
||||
hist = list(map(lambda item: (item['tx_hash'], item['height']), result))
|
||||
# tx_fees
|
||||
for item in result:
|
||||
if item['height'] in (-1, 0) and 'fee' not in item:
|
||||
raise Exception("server response to get_history contains unconfirmed tx without fee")
|
||||
tx_fees = [(item['tx_hash'], item.get('fee')) for item in result]
|
||||
tx_fees = dict(filter(lambda x:x[1] is not None, tx_fees))
|
||||
# Check that txids are unique
|
||||
|
@ -214,7 +211,7 @@ class Synchronizer(SynchronizerBase):
|
|||
async def _get_transaction(self, tx_hash, *, allow_server_not_finding_tx=False):
|
||||
self._requests_sent += 1
|
||||
try:
|
||||
raw_tx = await self.network.get_transaction(tx_hash)
|
||||
raw_tx = await self.interface.get_transaction(tx_hash)
|
||||
except UntrustedServerReturnedError as e:
|
||||
# most likely, "No such mempool or blockchain transaction"
|
||||
if allow_server_not_finding_tx:
|
||||
|
|
|
@ -582,6 +582,30 @@ def is_non_negative_integer(val) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def is_integer(val) -> bool:
|
||||
try:
|
||||
int(val)
|
||||
except:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def is_real_number(val, *, as_str: bool = False) -> bool:
|
||||
if as_str: # only accept str
|
||||
if not isinstance(val, str):
|
||||
return False
|
||||
else: # only accept int/float/etc.
|
||||
if isinstance(val, str):
|
||||
return False
|
||||
try:
|
||||
Decimal(val)
|
||||
except:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def chunks(items, size: int):
|
||||
"""Break up items, an iterable, into chunks of length size."""
|
||||
if size < 1:
|
||||
|
|
Loading…
Add table
Reference in a new issue