mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-27 07:23:25 +00:00
network: replace "server" strings with ServerAddr objects
This commit is contained in:
parent
ef2ff11926
commit
cf1f2ba4dc
6 changed files with 172 additions and 103 deletions
|
@ -270,6 +270,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
|
|||
|
||||
class Daemon(Logger):
|
||||
|
||||
network: Optional[Network]
|
||||
|
||||
@profiler
|
||||
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
|
||||
Logger.__init__(self)
|
||||
|
|
|
@ -453,7 +453,7 @@ def get_exchanges_by_ccy(history=True):
|
|||
|
||||
class FxThread(ThreadJob):
|
||||
|
||||
def __init__(self, config: SimpleConfig, network: Network):
|
||||
def __init__(self, config: SimpleConfig, network: Optional[Network]):
|
||||
ThreadJob.__init__(self)
|
||||
self.config = config
|
||||
self.network = network
|
||||
|
|
|
@ -36,7 +36,7 @@ from PyQt5.QtGui import QFontMetrics
|
|||
|
||||
from electrum.i18n import _
|
||||
from electrum import constants, blockchain, util
|
||||
from electrum.interface import serialize_server, deserialize_server
|
||||
from electrum.interface import ServerAddr
|
||||
from electrum.network import Network
|
||||
from electrum.logging import get_logger
|
||||
|
||||
|
@ -72,10 +72,13 @@ class NetworkDialog(QDialog):
|
|||
|
||||
|
||||
class NodesListWidget(QTreeWidget):
|
||||
SERVER_ADDR_ROLE = Qt.UserRole + 100
|
||||
CHAIN_ID_ROLE = Qt.UserRole + 101
|
||||
IS_SERVER_ROLE = Qt.UserRole + 102
|
||||
|
||||
def __init__(self, parent):
|
||||
QTreeWidget.__init__(self)
|
||||
self.parent = parent
|
||||
self.parent = parent # type: NetworkChoiceLayout
|
||||
self.setHeaderLabels([_('Connected node'), _('Height')])
|
||||
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
||||
self.customContextMenuRequested.connect(self.create_menu)
|
||||
|
@ -84,13 +87,13 @@ class NodesListWidget(QTreeWidget):
|
|||
item = self.currentItem()
|
||||
if not item:
|
||||
return
|
||||
is_server = not bool(item.data(0, Qt.UserRole))
|
||||
is_server = bool(item.data(0, self.IS_SERVER_ROLE))
|
||||
menu = QMenu()
|
||||
if is_server:
|
||||
server = item.data(1, Qt.UserRole)
|
||||
server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
|
||||
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
|
||||
else:
|
||||
chain_id = item.data(1, Qt.UserRole)
|
||||
chain_id = item.data(0, self.CHAIN_ID_ROLE)
|
||||
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
|
||||
menu.exec_(self.viewport().mapToGlobal(position))
|
||||
|
||||
|
@ -117,15 +120,15 @@ class NodesListWidget(QTreeWidget):
|
|||
name = b.get_name()
|
||||
if n_chains > 1:
|
||||
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
|
||||
x.setData(0, Qt.UserRole, 1)
|
||||
x.setData(1, Qt.UserRole, b.get_id())
|
||||
x.setData(0, self.IS_SERVER_ROLE, 0)
|
||||
x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
|
||||
else:
|
||||
x = self
|
||||
for i in interfaces:
|
||||
star = ' *' if i == network.interface else ''
|
||||
item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
|
||||
item.setData(0, Qt.UserRole, 0)
|
||||
item.setData(1, Qt.UserRole, i.server)
|
||||
item.setData(0, self.IS_SERVER_ROLE, 1)
|
||||
item.setData(0, self.SERVER_ADDR_ROLE, i.server)
|
||||
x.addChild(item)
|
||||
if n_chains > 1:
|
||||
self.addTopLevelItem(x)
|
||||
|
@ -144,11 +147,11 @@ class ServerListWidget(QTreeWidget):
|
|||
HOST = 0
|
||||
PORT = 1
|
||||
|
||||
SERVER_STR_ROLE = Qt.UserRole + 100
|
||||
SERVER_ADDR_ROLE = Qt.UserRole + 100
|
||||
|
||||
def __init__(self, parent):
|
||||
QTreeWidget.__init__(self)
|
||||
self.parent = parent
|
||||
self.parent = parent # type: NetworkChoiceLayout
|
||||
self.setHeaderLabels([_('Host'), _('Port')])
|
||||
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
||||
self.customContextMenuRequested.connect(self.create_menu)
|
||||
|
@ -158,14 +161,13 @@ class ServerListWidget(QTreeWidget):
|
|||
if not item:
|
||||
return
|
||||
menu = QMenu()
|
||||
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE)
|
||||
server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
|
||||
menu.addAction(_("Use as server"), lambda: self.set_server(server))
|
||||
menu.exec_(self.viewport().mapToGlobal(position))
|
||||
|
||||
def set_server(self, s):
|
||||
host, port, protocol = deserialize_server(s)
|
||||
self.parent.server_host.setText(host)
|
||||
self.parent.server_port.setText(port)
|
||||
def set_server(self, server: ServerAddr):
|
||||
self.parent.server_host.setText(server.host)
|
||||
self.parent.server_port.setText(str(server.port))
|
||||
self.parent.set_server()
|
||||
|
||||
def keyPressEvent(self, event):
|
||||
|
@ -188,8 +190,8 @@ class ServerListWidget(QTreeWidget):
|
|||
port = d.get(protocol)
|
||||
if port:
|
||||
x = QTreeWidgetItem([_host, port])
|
||||
server = serialize_server(_host, port, protocol)
|
||||
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server)
|
||||
server = ServerAddr(_host, port, protocol=protocol)
|
||||
x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
|
||||
self.addTopLevelItem(x)
|
||||
|
||||
h = self.header()
|
||||
|
@ -431,7 +433,7 @@ class NetworkChoiceLayout(object):
|
|||
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
|
||||
self.update()
|
||||
|
||||
def follow_server(self, server):
|
||||
def follow_server(self, server: ServerAddr):
|
||||
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
|
||||
self.update()
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import locale
|
|||
from decimal import Decimal
|
||||
import getpass
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import electrum
|
||||
from electrum import util
|
||||
|
@ -15,15 +16,21 @@ from electrum.transaction import PartialTxOutput
|
|||
from electrum.wallet import Wallet
|
||||
from electrum.storage import WalletStorage
|
||||
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
|
||||
from electrum.interface import deserialize_server
|
||||
from electrum.interface import ServerAddr
|
||||
from electrum.logging import console_stderr_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from electrum.daemon import Daemon
|
||||
from electrum.simple_config import SimpleConfig
|
||||
from electrum.plugin import Plugins
|
||||
|
||||
|
||||
_ = lambda x:x # i18n
|
||||
|
||||
|
||||
class ElectrumGui:
|
||||
|
||||
def __init__(self, config, daemon, plugins):
|
||||
def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
|
||||
|
||||
self.config = config
|
||||
self.network = daemon.network
|
||||
|
@ -404,21 +411,24 @@ class ElectrumGui:
|
|||
net_params = self.network.get_parameters()
|
||||
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
||||
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
|
||||
srv = 'auto-connect' if auto_connect else self.network.default_server
|
||||
srv = 'auto-connect' if auto_connect else str(self.network.default_server)
|
||||
out = self.run_dialog('Network', [
|
||||
{'label':'server', 'type':'str', 'value':srv},
|
||||
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
|
||||
], buttons = 1)
|
||||
if out:
|
||||
if out.get('server'):
|
||||
server = out.get('server')
|
||||
auto_connect = server == 'auto-connect'
|
||||
server_str = out.get('server')
|
||||
auto_connect = server_str == 'auto-connect'
|
||||
if not auto_connect:
|
||||
try:
|
||||
host, port, protocol = deserialize_server(server)
|
||||
server_addr = ServerAddr.from_str(server_str)
|
||||
except Exception:
|
||||
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"")
|
||||
self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
|
||||
return False
|
||||
host = server_addr.host
|
||||
port = str(server_addr.port)
|
||||
protocol = server_addr.protocol
|
||||
if out.get('server') or out.get('proxy'):
|
||||
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
|
||||
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)
|
||||
|
|
|
@ -29,7 +29,7 @@ import sys
|
|||
import traceback
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
|
||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
|
||||
from collections import defaultdict
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
|
||||
import itertools
|
||||
|
@ -198,22 +198,57 @@ class _RSClient(RSClient):
|
|||
raise ConnectError(e) from e
|
||||
|
||||
|
||||
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
|
||||
# host might be IPv6 address, hence do rsplit:
|
||||
host, port, protocol = str(server_str).rsplit(':', 2)
|
||||
if not host:
|
||||
raise ValueError('host must not be empty')
|
||||
if host[0] == '[' and host[-1] == ']': # IPv6
|
||||
host = host[1:-1]
|
||||
if protocol not in ('s', 't'):
|
||||
raise ValueError('invalid network protocol: {}'.format(protocol))
|
||||
net_addr = NetAddress(host, port) # this validates host and port
|
||||
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
||||
return host, port, protocol
|
||||
class ServerAddr:
|
||||
|
||||
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
|
||||
assert isinstance(host, str), repr(host)
|
||||
if protocol is None:
|
||||
protocol = 's'
|
||||
if not host:
|
||||
raise ValueError('host must not be empty')
|
||||
if host[0] == '[' and host[-1] == ']': # IPv6
|
||||
host = host[1:-1]
|
||||
try:
|
||||
net_addr = NetAddress(host, port) # this validates host and port
|
||||
except Exception as e:
|
||||
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
|
||||
if protocol not in ('s', 't'):
|
||||
raise ValueError(f"invalid network protocol: {protocol}")
|
||||
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
||||
self.port = int(net_addr.port)
|
||||
self.protocol = protocol
|
||||
self._net_addr_str = str(net_addr)
|
||||
|
||||
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
|
||||
return str(':'.join([host, str(port), protocol]))
|
||||
@classmethod
|
||||
def from_str(cls, s: str) -> 'ServerAddr':
|
||||
# host might be IPv6 address, hence do rsplit:
|
||||
host, port, protocol = str(s).rsplit(':', 2)
|
||||
return ServerAddr(host=host, port=port, protocol=protocol)
|
||||
|
||||
def __str__(self):
|
||||
return '{}:{}'.format(self.net_addr_str(), self.protocol)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
|
||||
|
||||
def net_addr_str(self) -> str:
|
||||
return self._net_addr_str
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, ServerAddr):
|
||||
return False
|
||||
return (self.host == other.host
|
||||
and self.port == other.port
|
||||
and self.protocol == other.protocol)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.host, self.port, self.protocol))
|
||||
|
||||
|
||||
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
|
||||
|
@ -232,12 +267,10 @@ class Interface(Logger):
|
|||
|
||||
LOGGING_SHORTCUT = 'i'
|
||||
|
||||
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
|
||||
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
|
||||
self.ready = asyncio.Future()
|
||||
self.got_disconnected = asyncio.Future()
|
||||
self.server = server
|
||||
self.host, self.port, self.protocol = deserialize_server(self.server)
|
||||
self.port = int(self.port)
|
||||
Logger.__init__(self)
|
||||
assert network.config.path
|
||||
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
|
||||
|
@ -259,8 +292,20 @@ class Interface(Logger):
|
|||
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
|
||||
self.taskgroup = SilentTaskGroup()
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self.server.host
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self.server.port
|
||||
|
||||
@property
|
||||
def protocol(self):
|
||||
return self.server.protocol
|
||||
|
||||
def diagnostic_name(self):
|
||||
return str(NetAddress(self.host, self.port))
|
||||
return self.server.net_addr_str()
|
||||
|
||||
def __str__(self):
|
||||
return f"<Interface {self.diagnostic_name()}>"
|
||||
|
|
|
@ -32,7 +32,7 @@ import socket
|
|||
import json
|
||||
import sys
|
||||
import asyncio
|
||||
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable
|
||||
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
|
||||
import traceback
|
||||
import concurrent
|
||||
from concurrent import futures
|
||||
|
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
|
|||
from . import util
|
||||
from .util import (log_exceptions, ignore_exceptions,
|
||||
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
|
||||
is_hash256_str, is_non_negative_integer)
|
||||
is_hash256_str, is_non_negative_integer, MyEncoder)
|
||||
|
||||
from .bitcoin import COIN
|
||||
from . import constants
|
||||
|
@ -53,9 +53,9 @@ from . import bitcoin
|
|||
from . import dns_hacks
|
||||
from .transaction import Transaction
|
||||
from .blockchain import Blockchain, HEADER_SIZE
|
||||
from .interface import (Interface, serialize_server, deserialize_server,
|
||||
from .interface import (Interface,
|
||||
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
|
||||
NetworkException, RequestCorrupted)
|
||||
NetworkException, RequestCorrupted, ServerAddr)
|
||||
from .version import PROTOCOL_VERSION
|
||||
from .simple_config import SimpleConfig
|
||||
from .i18n import _
|
||||
|
@ -117,18 +117,18 @@ def filter_noonion(servers):
|
|||
return {k: v for k, v in servers.items() if not k.endswith('.onion')}
|
||||
|
||||
|
||||
def filter_protocol(hostmap, protocol='s'):
|
||||
'''Filters the hostmap for those implementing protocol.
|
||||
The result is a list in serialized form.'''
|
||||
def filter_protocol(hostmap, protocol='s') -> Sequence[ServerAddr]:
|
||||
"""Filters the hostmap for those implementing protocol."""
|
||||
eligible = []
|
||||
for host, portmap in hostmap.items():
|
||||
port = portmap.get(protocol)
|
||||
if port:
|
||||
eligible.append(serialize_server(host, port, protocol))
|
||||
eligible.append(ServerAddr(host, port, protocol=protocol))
|
||||
return eligible
|
||||
|
||||
|
||||
def pick_random_server(hostmap=None, protocol='s', exclude_set=None):
|
||||
def pick_random_server(hostmap=None, *, protocol='s',
|
||||
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
|
||||
if hostmap is None:
|
||||
hostmap = constants.net.DEFAULT_SERVERS
|
||||
if exclude_set is None:
|
||||
|
@ -240,6 +240,14 @@ class Network(Logger):
|
|||
|
||||
LOGGING_SHORTCUT = 'n'
|
||||
|
||||
taskgroup: Optional[TaskGroup]
|
||||
interface: Optional[Interface]
|
||||
interfaces: Dict[ServerAddr, Interface]
|
||||
connecting: Set[ServerAddr]
|
||||
server_queue: 'Optional[queue.Queue[ServerAddr]]'
|
||||
disconnected_servers: Set[ServerAddr]
|
||||
default_server: ServerAddr
|
||||
|
||||
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
|
||||
global _INSTANCE
|
||||
assert _INSTANCE is None, "Network is a singleton!"
|
||||
|
@ -266,14 +274,15 @@ class Network(Logger):
|
|||
# Sanitize default server
|
||||
if self.default_server:
|
||||
try:
|
||||
deserialize_server(self.default_server)
|
||||
self.default_server = ServerAddr.from_str(self.default_server)
|
||||
except:
|
||||
self.logger.warning('failed to parse server-string; falling back to localhost.')
|
||||
self.default_server = "localhost:50002:s"
|
||||
if not self.default_server:
|
||||
self.default_server = ServerAddr.from_str("localhost:50002:s")
|
||||
else:
|
||||
self.default_server = pick_random_server()
|
||||
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
|
||||
|
||||
self.taskgroup = None # type: TaskGroup
|
||||
self.taskgroup = None
|
||||
|
||||
# locks
|
||||
self.restart_lock = asyncio.Lock()
|
||||
|
@ -295,10 +304,10 @@ class Network(Logger):
|
|||
self.server_retry_time = time.time()
|
||||
self.nodes_retry_time = time.time()
|
||||
# the main server we are currently communicating with
|
||||
self.interface = None # type: Optional[Interface]
|
||||
self.interface = None
|
||||
self.default_server_changed_event = asyncio.Event()
|
||||
# set of servers we have an ongoing connection with
|
||||
self.interfaces = {} # type: Dict[str, Interface]
|
||||
self.interfaces = {}
|
||||
self.auto_connect = self.config.get('auto_connect', True)
|
||||
self.connecting = set()
|
||||
self.server_queue = None
|
||||
|
@ -347,14 +356,15 @@ class Network(Logger):
|
|||
return func(self, *args, **kwargs)
|
||||
return func_wrapper
|
||||
|
||||
def _read_recent_servers(self):
|
||||
def _read_recent_servers(self) -> List[ServerAddr]:
|
||||
if not self.config.path:
|
||||
return []
|
||||
path = os.path.join(self.config.path, "recent_servers")
|
||||
try:
|
||||
with open(path, "r", encoding='utf-8') as f:
|
||||
data = f.read()
|
||||
return json.loads(data)
|
||||
servers_list = json.loads(data)
|
||||
return [ServerAddr.from_str(s) for s in servers_list]
|
||||
except:
|
||||
return []
|
||||
|
||||
|
@ -363,7 +373,7 @@ class Network(Logger):
|
|||
if not self.config.path:
|
||||
return
|
||||
path = os.path.join(self.config.path, "recent_servers")
|
||||
s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
|
||||
s = json.dumps(self.recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
|
||||
try:
|
||||
with open(path, "w", encoding='utf-8') as f:
|
||||
f.write(s)
|
||||
|
@ -462,10 +472,10 @@ class Network(Logger):
|
|||
util.trigger_callback(key, self.get_status_value(key))
|
||||
|
||||
def get_parameters(self) -> NetworkParameters:
|
||||
host, port, protocol = deserialize_server(self.default_server)
|
||||
return NetworkParameters(host=host,
|
||||
port=port,
|
||||
protocol=protocol,
|
||||
server = self.default_server
|
||||
return NetworkParameters(host=server.host,
|
||||
port=str(server.port),
|
||||
protocol=server.protocol,
|
||||
proxy=self.proxy,
|
||||
auto_connect=self.auto_connect,
|
||||
oneserver=self.oneserver)
|
||||
|
@ -474,7 +484,7 @@ class Network(Logger):
|
|||
if self.is_connected():
|
||||
return self.donation_address
|
||||
|
||||
def get_interfaces(self) -> List[str]:
|
||||
def get_interfaces(self) -> List[ServerAddr]:
|
||||
"""The list of servers for the connected interfaces."""
|
||||
with self.interfaces_lock:
|
||||
return list(self.interfaces)
|
||||
|
@ -516,21 +526,18 @@ class Network(Logger):
|
|||
# hardcoded servers
|
||||
out.update(constants.net.DEFAULT_SERVERS)
|
||||
# add recent servers
|
||||
for s in self.recent_servers:
|
||||
try:
|
||||
host, port, protocol = deserialize_server(s)
|
||||
except:
|
||||
continue
|
||||
if host in out:
|
||||
out[host].update({protocol: port})
|
||||
for server in self.recent_servers:
|
||||
port = str(server.port)
|
||||
if server.host in out:
|
||||
out[server.host].update({server.protocol: port})
|
||||
else:
|
||||
out[host] = {protocol: port}
|
||||
out[server.host] = {server.protocol: port}
|
||||
# potentially filter out some
|
||||
if self.config.get('noonion'):
|
||||
out = filter_noonion(out)
|
||||
return out
|
||||
|
||||
def _start_interface(self, server: str):
|
||||
def _start_interface(self, server: ServerAddr):
|
||||
if server not in self.interfaces and server not in self.connecting:
|
||||
if server == self.default_server:
|
||||
self.logger.info(f"connecting to {server} as new interface")
|
||||
|
@ -538,10 +545,10 @@ class Network(Logger):
|
|||
self.connecting.add(server)
|
||||
self.server_queue.put(server)
|
||||
|
||||
def _start_random_interface(self):
|
||||
def _start_random_interface(self) -> Optional[ServerAddr]:
|
||||
with self.interfaces_lock:
|
||||
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
|
||||
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
|
||||
server = pick_random_server(self.get_servers(), protocol=self.protocol, exclude_set=exclude_set)
|
||||
if server:
|
||||
self._start_interface(server)
|
||||
return server
|
||||
|
@ -557,10 +564,9 @@ class Network(Logger):
|
|||
proxy = net_params.proxy
|
||||
proxy_str = serialize_proxy(proxy)
|
||||
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
||||
server_str = serialize_server(host, port, protocol)
|
||||
# sanitize parameters
|
||||
try:
|
||||
deserialize_server(serialize_server(host, port, protocol))
|
||||
server = ServerAddr(host, port, protocol=protocol)
|
||||
if proxy:
|
||||
proxy_modes.index(proxy['mode']) + 1
|
||||
int(proxy['port'])
|
||||
|
@ -569,9 +575,9 @@ class Network(Logger):
|
|||
self.config.set_key('auto_connect', net_params.auto_connect, False)
|
||||
self.config.set_key('oneserver', net_params.oneserver, False)
|
||||
self.config.set_key('proxy', proxy_str, False)
|
||||
self.config.set_key('server', server_str, True)
|
||||
self.config.set_key('server', str(server), True)
|
||||
# abort if changes were not allowed by config
|
||||
if self.config.get('server') != server_str \
|
||||
if self.config.get('server') != str(server) \
|
||||
or self.config.get('proxy') != proxy_str \
|
||||
or self.config.get('oneserver') != net_params.oneserver:
|
||||
return
|
||||
|
@ -581,10 +587,10 @@ class Network(Logger):
|
|||
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
|
||||
# Restart the network defaulting to the given server
|
||||
await self._stop()
|
||||
self.default_server = server_str
|
||||
self.default_server = server
|
||||
await self._start()
|
||||
elif self.default_server != server_str:
|
||||
await self.switch_to_interface(server_str)
|
||||
elif self.default_server != server:
|
||||
await self.switch_to_interface(server)
|
||||
else:
|
||||
await self.switch_lagging_interface()
|
||||
|
||||
|
@ -646,7 +652,7 @@ class Network(Logger):
|
|||
# FIXME switch to best available?
|
||||
self.logger.info("tried to switch to best chain but no interfaces are on it")
|
||||
|
||||
async def switch_to_interface(self, server: str):
|
||||
async def switch_to_interface(self, server: ServerAddr):
|
||||
"""Switch to server as our main interface. If no connection exists,
|
||||
queue interface to be started. The actual switch will
|
||||
happen when the interface becomes ready.
|
||||
|
@ -722,8 +728,8 @@ class Network(Logger):
|
|||
|
||||
@ignore_exceptions # do not kill main_taskgroup
|
||||
@log_exceptions
|
||||
async def _run_new_interface(self, server):
|
||||
interface = Interface(self, server, self.proxy)
|
||||
async def _run_new_interface(self, server: ServerAddr):
|
||||
interface = Interface(network=self, server=server, proxy=self.proxy)
|
||||
# note: using longer timeouts here as DNS can sometimes be slow!
|
||||
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
|
||||
try:
|
||||
|
@ -1070,23 +1076,26 @@ class Network(Logger):
|
|||
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
||||
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
|
||||
if len(interfaces_on_selected_chain) == 0: return
|
||||
chosen_iface = random.choice(interfaces_on_selected_chain)
|
||||
chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
|
||||
# switch to server (and save to config)
|
||||
net_params = self.get_parameters()
|
||||
host, port, protocol = deserialize_server(chosen_iface.server)
|
||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
||||
server = chosen_iface.server
|
||||
net_params = net_params._replace(host=server.host,
|
||||
port=str(server.port),
|
||||
protocol=server.protocol)
|
||||
await self.set_parameters(net_params)
|
||||
|
||||
async def follow_chain_given_server(self, server_str: str) -> None:
|
||||
async def follow_chain_given_server(self, server: ServerAddr) -> None:
|
||||
# note that server_str should correspond to a connected interface
|
||||
iface = self.interfaces.get(server_str)
|
||||
iface = self.interfaces.get(server)
|
||||
if iface is None:
|
||||
return
|
||||
self._set_preferred_chain(iface.blockchain)
|
||||
# switch to server (and save to config)
|
||||
net_params = self.get_parameters()
|
||||
host, port, protocol = deserialize_server(server_str)
|
||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
||||
net_params = net_params._replace(host=server.host,
|
||||
port=str(server.port),
|
||||
protocol=server.protocol)
|
||||
await self.set_parameters(net_params)
|
||||
|
||||
def get_local_height(self):
|
||||
|
@ -1107,7 +1116,7 @@ class Network(Logger):
|
|||
assert not self.connecting and not self.server_queue
|
||||
self.logger.info('starting network')
|
||||
self.disconnected_servers = set([])
|
||||
self.protocol = deserialize_server(self.default_server)[2]
|
||||
self.protocol = self.default_server.protocol
|
||||
self.server_queue = queue.Queue()
|
||||
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
|
||||
self._set_oneserver(self.config.get('oneserver', False))
|
||||
|
@ -1147,9 +1156,9 @@ class Network(Logger):
|
|||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
|
||||
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
|
||||
self.taskgroup = None # type: TaskGroup
|
||||
self.interface = None # type: Interface
|
||||
self.interfaces = {} # type: Dict[str, Interface]
|
||||
self.taskgroup = None
|
||||
self.interface = None
|
||||
self.interfaces = {}
|
||||
self.connecting.clear()
|
||||
self.server_queue = None
|
||||
if not full_shutdown:
|
||||
|
@ -1268,8 +1277,8 @@ class Network(Logger):
|
|||
|
||||
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
|
||||
responses = dict()
|
||||
async def get_response(server):
|
||||
interface = Interface(self, server, self.proxy)
|
||||
async def get_response(server: ServerAddr):
|
||||
interface = Interface(network=self, server=server, proxy=self.proxy)
|
||||
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
|
||||
try:
|
||||
await asyncio.wait_for(interface.ready, timeout)
|
||||
|
@ -1283,5 +1292,6 @@ class Network(Logger):
|
|||
responses[interface.server] = res
|
||||
async with TaskGroup() as group:
|
||||
for server in servers:
|
||||
server = ServerAddr.from_str(server)
|
||||
await group.spawn(get_response(server))
|
||||
return responses
|
||||
|
|
Loading…
Add table
Reference in a new issue