mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-27 23:41:35 +00:00
network: replace "server" strings with ServerAddr objects
This commit is contained in:
parent
ef2ff11926
commit
cf1f2ba4dc
6 changed files with 172 additions and 103 deletions
|
@ -270,6 +270,8 @@ class AuthenticationCredentialsInvalid(AuthenticationError):
|
||||||
|
|
||||||
class Daemon(Logger):
|
class Daemon(Logger):
|
||||||
|
|
||||||
|
network: Optional[Network]
|
||||||
|
|
||||||
@profiler
|
@profiler
|
||||||
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
|
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
|
|
|
@ -453,7 +453,7 @@ def get_exchanges_by_ccy(history=True):
|
||||||
|
|
||||||
class FxThread(ThreadJob):
|
class FxThread(ThreadJob):
|
||||||
|
|
||||||
def __init__(self, config: SimpleConfig, network: Network):
|
def __init__(self, config: SimpleConfig, network: Optional[Network]):
|
||||||
ThreadJob.__init__(self)
|
ThreadJob.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.network = network
|
self.network = network
|
||||||
|
|
|
@ -36,7 +36,7 @@ from PyQt5.QtGui import QFontMetrics
|
||||||
|
|
||||||
from electrum.i18n import _
|
from electrum.i18n import _
|
||||||
from electrum import constants, blockchain, util
|
from electrum import constants, blockchain, util
|
||||||
from electrum.interface import serialize_server, deserialize_server
|
from electrum.interface import ServerAddr
|
||||||
from electrum.network import Network
|
from electrum.network import Network
|
||||||
from electrum.logging import get_logger
|
from electrum.logging import get_logger
|
||||||
|
|
||||||
|
@ -72,10 +72,13 @@ class NetworkDialog(QDialog):
|
||||||
|
|
||||||
|
|
||||||
class NodesListWidget(QTreeWidget):
|
class NodesListWidget(QTreeWidget):
|
||||||
|
SERVER_ADDR_ROLE = Qt.UserRole + 100
|
||||||
|
CHAIN_ID_ROLE = Qt.UserRole + 101
|
||||||
|
IS_SERVER_ROLE = Qt.UserRole + 102
|
||||||
|
|
||||||
def __init__(self, parent):
|
def __init__(self, parent):
|
||||||
QTreeWidget.__init__(self)
|
QTreeWidget.__init__(self)
|
||||||
self.parent = parent
|
self.parent = parent # type: NetworkChoiceLayout
|
||||||
self.setHeaderLabels([_('Connected node'), _('Height')])
|
self.setHeaderLabels([_('Connected node'), _('Height')])
|
||||||
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
||||||
self.customContextMenuRequested.connect(self.create_menu)
|
self.customContextMenuRequested.connect(self.create_menu)
|
||||||
|
@ -84,13 +87,13 @@ class NodesListWidget(QTreeWidget):
|
||||||
item = self.currentItem()
|
item = self.currentItem()
|
||||||
if not item:
|
if not item:
|
||||||
return
|
return
|
||||||
is_server = not bool(item.data(0, Qt.UserRole))
|
is_server = bool(item.data(0, self.IS_SERVER_ROLE))
|
||||||
menu = QMenu()
|
menu = QMenu()
|
||||||
if is_server:
|
if is_server:
|
||||||
server = item.data(1, Qt.UserRole)
|
server = item.data(0, self.SERVER_ADDR_ROLE) # type: ServerAddr
|
||||||
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
|
menu.addAction(_("Use as server"), lambda: self.parent.follow_server(server))
|
||||||
else:
|
else:
|
||||||
chain_id = item.data(1, Qt.UserRole)
|
chain_id = item.data(0, self.CHAIN_ID_ROLE)
|
||||||
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
|
menu.addAction(_("Follow this branch"), lambda: self.parent.follow_branch(chain_id))
|
||||||
menu.exec_(self.viewport().mapToGlobal(position))
|
menu.exec_(self.viewport().mapToGlobal(position))
|
||||||
|
|
||||||
|
@ -117,15 +120,15 @@ class NodesListWidget(QTreeWidget):
|
||||||
name = b.get_name()
|
name = b.get_name()
|
||||||
if n_chains > 1:
|
if n_chains > 1:
|
||||||
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
|
x = QTreeWidgetItem([name + '@%d'%b.get_max_forkpoint(), '%d'%b.height()])
|
||||||
x.setData(0, Qt.UserRole, 1)
|
x.setData(0, self.IS_SERVER_ROLE, 0)
|
||||||
x.setData(1, Qt.UserRole, b.get_id())
|
x.setData(0, self.CHAIN_ID_ROLE, b.get_id())
|
||||||
else:
|
else:
|
||||||
x = self
|
x = self
|
||||||
for i in interfaces:
|
for i in interfaces:
|
||||||
star = ' *' if i == network.interface else ''
|
star = ' *' if i == network.interface else ''
|
||||||
item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
|
item = QTreeWidgetItem([i.host + star, '%d'%i.tip])
|
||||||
item.setData(0, Qt.UserRole, 0)
|
item.setData(0, self.IS_SERVER_ROLE, 1)
|
||||||
item.setData(1, Qt.UserRole, i.server)
|
item.setData(0, self.SERVER_ADDR_ROLE, i.server)
|
||||||
x.addChild(item)
|
x.addChild(item)
|
||||||
if n_chains > 1:
|
if n_chains > 1:
|
||||||
self.addTopLevelItem(x)
|
self.addTopLevelItem(x)
|
||||||
|
@ -144,11 +147,11 @@ class ServerListWidget(QTreeWidget):
|
||||||
HOST = 0
|
HOST = 0
|
||||||
PORT = 1
|
PORT = 1
|
||||||
|
|
||||||
SERVER_STR_ROLE = Qt.UserRole + 100
|
SERVER_ADDR_ROLE = Qt.UserRole + 100
|
||||||
|
|
||||||
def __init__(self, parent):
|
def __init__(self, parent):
|
||||||
QTreeWidget.__init__(self)
|
QTreeWidget.__init__(self)
|
||||||
self.parent = parent
|
self.parent = parent # type: NetworkChoiceLayout
|
||||||
self.setHeaderLabels([_('Host'), _('Port')])
|
self.setHeaderLabels([_('Host'), _('Port')])
|
||||||
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
self.setContextMenuPolicy(Qt.CustomContextMenu)
|
||||||
self.customContextMenuRequested.connect(self.create_menu)
|
self.customContextMenuRequested.connect(self.create_menu)
|
||||||
|
@ -158,14 +161,13 @@ class ServerListWidget(QTreeWidget):
|
||||||
if not item:
|
if not item:
|
||||||
return
|
return
|
||||||
menu = QMenu()
|
menu = QMenu()
|
||||||
server = item.data(self.Columns.HOST, self.SERVER_STR_ROLE)
|
server = item.data(self.Columns.HOST, self.SERVER_ADDR_ROLE)
|
||||||
menu.addAction(_("Use as server"), lambda: self.set_server(server))
|
menu.addAction(_("Use as server"), lambda: self.set_server(server))
|
||||||
menu.exec_(self.viewport().mapToGlobal(position))
|
menu.exec_(self.viewport().mapToGlobal(position))
|
||||||
|
|
||||||
def set_server(self, s):
|
def set_server(self, server: ServerAddr):
|
||||||
host, port, protocol = deserialize_server(s)
|
self.parent.server_host.setText(server.host)
|
||||||
self.parent.server_host.setText(host)
|
self.parent.server_port.setText(str(server.port))
|
||||||
self.parent.server_port.setText(port)
|
|
||||||
self.parent.set_server()
|
self.parent.set_server()
|
||||||
|
|
||||||
def keyPressEvent(self, event):
|
def keyPressEvent(self, event):
|
||||||
|
@ -188,8 +190,8 @@ class ServerListWidget(QTreeWidget):
|
||||||
port = d.get(protocol)
|
port = d.get(protocol)
|
||||||
if port:
|
if port:
|
||||||
x = QTreeWidgetItem([_host, port])
|
x = QTreeWidgetItem([_host, port])
|
||||||
server = serialize_server(_host, port, protocol)
|
server = ServerAddr(_host, port, protocol=protocol)
|
||||||
x.setData(self.Columns.HOST, self.SERVER_STR_ROLE, server)
|
x.setData(self.Columns.HOST, self.SERVER_ADDR_ROLE, server)
|
||||||
self.addTopLevelItem(x)
|
self.addTopLevelItem(x)
|
||||||
|
|
||||||
h = self.header()
|
h = self.header()
|
||||||
|
@ -431,7 +433,7 @@ class NetworkChoiceLayout(object):
|
||||||
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
|
self.network.run_from_another_thread(self.network.follow_chain_given_id(chain_id))
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def follow_server(self, server):
|
def follow_server(self, server: ServerAddr):
|
||||||
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
|
self.network.run_from_another_thread(self.network.follow_chain_given_server(server))
|
||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import locale
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import getpass
|
import getpass
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import electrum
|
import electrum
|
||||||
from electrum import util
|
from electrum import util
|
||||||
|
@ -15,15 +16,21 @@ from electrum.transaction import PartialTxOutput
|
||||||
from electrum.wallet import Wallet
|
from electrum.wallet import Wallet
|
||||||
from electrum.storage import WalletStorage
|
from electrum.storage import WalletStorage
|
||||||
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
|
from electrum.network import NetworkParameters, TxBroadcastError, BestEffortRequestFailed
|
||||||
from electrum.interface import deserialize_server
|
from electrum.interface import ServerAddr
|
||||||
from electrum.logging import console_stderr_handler
|
from electrum.logging import console_stderr_handler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from electrum.daemon import Daemon
|
||||||
|
from electrum.simple_config import SimpleConfig
|
||||||
|
from electrum.plugin import Plugins
|
||||||
|
|
||||||
|
|
||||||
_ = lambda x:x # i18n
|
_ = lambda x:x # i18n
|
||||||
|
|
||||||
|
|
||||||
class ElectrumGui:
|
class ElectrumGui:
|
||||||
|
|
||||||
def __init__(self, config, daemon, plugins):
|
def __init__(self, config: 'SimpleConfig', daemon: 'Daemon', plugins: 'Plugins'):
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.network = daemon.network
|
self.network = daemon.network
|
||||||
|
@ -404,21 +411,24 @@ class ElectrumGui:
|
||||||
net_params = self.network.get_parameters()
|
net_params = self.network.get_parameters()
|
||||||
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
||||||
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
|
proxy_config, auto_connect = net_params.proxy, net_params.auto_connect
|
||||||
srv = 'auto-connect' if auto_connect else self.network.default_server
|
srv = 'auto-connect' if auto_connect else str(self.network.default_server)
|
||||||
out = self.run_dialog('Network', [
|
out = self.run_dialog('Network', [
|
||||||
{'label':'server', 'type':'str', 'value':srv},
|
{'label':'server', 'type':'str', 'value':srv},
|
||||||
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
|
{'label':'proxy', 'type':'str', 'value':self.config.get('proxy', '')},
|
||||||
], buttons = 1)
|
], buttons = 1)
|
||||||
if out:
|
if out:
|
||||||
if out.get('server'):
|
if out.get('server'):
|
||||||
server = out.get('server')
|
server_str = out.get('server')
|
||||||
auto_connect = server == 'auto-connect'
|
auto_connect = server_str == 'auto-connect'
|
||||||
if not auto_connect:
|
if not auto_connect:
|
||||||
try:
|
try:
|
||||||
host, port, protocol = deserialize_server(server)
|
server_addr = ServerAddr.from_str(server_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.show_message("Error:" + server + "\nIn doubt, type \"auto-connect\"")
|
self.show_message("Error:" + server_str + "\nIn doubt, type \"auto-connect\"")
|
||||||
return False
|
return False
|
||||||
|
host = server_addr.host
|
||||||
|
port = str(server_addr.port)
|
||||||
|
protocol = server_addr.protocol
|
||||||
if out.get('server') or out.get('proxy'):
|
if out.get('server') or out.get('proxy'):
|
||||||
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
|
proxy = electrum.network.deserialize_proxy(out.get('proxy')) if out.get('proxy') else proxy_config
|
||||||
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)
|
net_params = NetworkParameters(host, port, protocol, proxy, auto_connect)
|
||||||
|
|
|
@ -29,7 +29,7 @@ import sys
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import socket
|
||||||
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set
|
from typing import Tuple, Union, List, TYPE_CHECKING, Optional, Set, NamedTuple
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
|
from ipaddress import IPv4Network, IPv6Network, ip_address, IPv6Address
|
||||||
import itertools
|
import itertools
|
||||||
|
@ -198,22 +198,57 @@ class _RSClient(RSClient):
|
||||||
raise ConnectError(e) from e
|
raise ConnectError(e) from e
|
||||||
|
|
||||||
|
|
||||||
def deserialize_server(server_str: str) -> Tuple[str, str, str]:
|
class ServerAddr:
|
||||||
# host might be IPv6 address, hence do rsplit:
|
|
||||||
host, port, protocol = str(server_str).rsplit(':', 2)
|
|
||||||
if not host:
|
|
||||||
raise ValueError('host must not be empty')
|
|
||||||
if host[0] == '[' and host[-1] == ']': # IPv6
|
|
||||||
host = host[1:-1]
|
|
||||||
if protocol not in ('s', 't'):
|
|
||||||
raise ValueError('invalid network protocol: {}'.format(protocol))
|
|
||||||
net_addr = NetAddress(host, port) # this validates host and port
|
|
||||||
host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
|
||||||
return host, port, protocol
|
|
||||||
|
|
||||||
|
def __init__(self, host: str, port: Union[int, str], *, protocol: str = None):
|
||||||
|
assert isinstance(host, str), repr(host)
|
||||||
|
if protocol is None:
|
||||||
|
protocol = 's'
|
||||||
|
if not host:
|
||||||
|
raise ValueError('host must not be empty')
|
||||||
|
if host[0] == '[' and host[-1] == ']': # IPv6
|
||||||
|
host = host[1:-1]
|
||||||
|
try:
|
||||||
|
net_addr = NetAddress(host, port) # this validates host and port
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"cannot construct ServerAddr: invalid host or port (host={host}, port={port})") from e
|
||||||
|
if protocol not in ('s', 't'):
|
||||||
|
raise ValueError(f"invalid network protocol: {protocol}")
|
||||||
|
self.host = str(net_addr.host) # canonical form (if e.g. IPv6 address)
|
||||||
|
self.port = int(net_addr.port)
|
||||||
|
self.protocol = protocol
|
||||||
|
self._net_addr_str = str(net_addr)
|
||||||
|
|
||||||
def serialize_server(host: str, port: Union[str, int], protocol: str) -> str:
|
@classmethod
|
||||||
return str(':'.join([host, str(port), protocol]))
|
def from_str(cls, s: str) -> 'ServerAddr':
|
||||||
|
# host might be IPv6 address, hence do rsplit:
|
||||||
|
host, port, protocol = str(s).rsplit(':', 2)
|
||||||
|
return ServerAddr(host=host, port=port, protocol=protocol)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return '{}:{}'.format(self.net_addr_str(), self.protocol)
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'<ServerAddr host={self.host} port={self.port} protocol={self.protocol}>'
|
||||||
|
|
||||||
|
def net_addr_str(self) -> str:
|
||||||
|
return self._net_addr_str
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, ServerAddr):
|
||||||
|
return False
|
||||||
|
return (self.host == other.host
|
||||||
|
and self.port == other.port
|
||||||
|
and self.protocol == other.protocol)
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not (self == other)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.host, self.port, self.protocol))
|
||||||
|
|
||||||
|
|
||||||
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
|
def _get_cert_path_for_host(*, config: 'SimpleConfig', host: str) -> str:
|
||||||
|
@ -232,12 +267,10 @@ class Interface(Logger):
|
||||||
|
|
||||||
LOGGING_SHORTCUT = 'i'
|
LOGGING_SHORTCUT = 'i'
|
||||||
|
|
||||||
def __init__(self, network: 'Network', server: str, proxy: Optional[dict]):
|
def __init__(self, *, network: 'Network', server: ServerAddr, proxy: Optional[dict]):
|
||||||
self.ready = asyncio.Future()
|
self.ready = asyncio.Future()
|
||||||
self.got_disconnected = asyncio.Future()
|
self.got_disconnected = asyncio.Future()
|
||||||
self.server = server
|
self.server = server
|
||||||
self.host, self.port, self.protocol = deserialize_server(self.server)
|
|
||||||
self.port = int(self.port)
|
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
assert network.config.path
|
assert network.config.path
|
||||||
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
|
self.cert_path = _get_cert_path_for_host(config=network.config, host=self.host)
|
||||||
|
@ -259,8 +292,20 @@ class Interface(Logger):
|
||||||
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
|
self.network.taskgroup.spawn(self.run()), self.network.asyncio_loop)
|
||||||
self.taskgroup = SilentTaskGroup()
|
self.taskgroup = SilentTaskGroup()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def host(self):
|
||||||
|
return self.server.host
|
||||||
|
|
||||||
|
@property
|
||||||
|
def port(self):
|
||||||
|
return self.server.port
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol(self):
|
||||||
|
return self.server.protocol
|
||||||
|
|
||||||
def diagnostic_name(self):
|
def diagnostic_name(self):
|
||||||
return str(NetAddress(self.host, self.port))
|
return self.server.net_addr_str()
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"<Interface {self.diagnostic_name()}>"
|
return f"<Interface {self.diagnostic_name()}>"
|
||||||
|
|
|
@ -32,7 +32,7 @@ import socket
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable
|
from typing import NamedTuple, Optional, Sequence, List, Dict, Tuple, TYPE_CHECKING, Iterable, Set
|
||||||
import traceback
|
import traceback
|
||||||
import concurrent
|
import concurrent
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
|
@ -44,7 +44,7 @@ from aiohttp import ClientResponse
|
||||||
from . import util
|
from . import util
|
||||||
from .util import (log_exceptions, ignore_exceptions,
|
from .util import (log_exceptions, ignore_exceptions,
|
||||||
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
|
bfh, SilentTaskGroup, make_aiohttp_session, send_exception_to_crash_reporter,
|
||||||
is_hash256_str, is_non_negative_integer)
|
is_hash256_str, is_non_negative_integer, MyEncoder)
|
||||||
|
|
||||||
from .bitcoin import COIN
|
from .bitcoin import COIN
|
||||||
from . import constants
|
from . import constants
|
||||||
|
@ -53,9 +53,9 @@ from . import bitcoin
|
||||||
from . import dns_hacks
|
from . import dns_hacks
|
||||||
from .transaction import Transaction
|
from .transaction import Transaction
|
||||||
from .blockchain import Blockchain, HEADER_SIZE
|
from .blockchain import Blockchain, HEADER_SIZE
|
||||||
from .interface import (Interface, serialize_server, deserialize_server,
|
from .interface import (Interface,
|
||||||
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
|
RequestTimedOut, NetworkTimeout, BUCKET_NAME_OF_ONION_SERVERS,
|
||||||
NetworkException, RequestCorrupted)
|
NetworkException, RequestCorrupted, ServerAddr)
|
||||||
from .version import PROTOCOL_VERSION
|
from .version import PROTOCOL_VERSION
|
||||||
from .simple_config import SimpleConfig
|
from .simple_config import SimpleConfig
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
|
@ -117,18 +117,18 @@ def filter_noonion(servers):
|
||||||
return {k: v for k, v in servers.items() if not k.endswith('.onion')}
|
return {k: v for k, v in servers.items() if not k.endswith('.onion')}
|
||||||
|
|
||||||
|
|
||||||
def filter_protocol(hostmap, protocol='s'):
|
def filter_protocol(hostmap, protocol='s') -> Sequence[ServerAddr]:
|
||||||
'''Filters the hostmap for those implementing protocol.
|
"""Filters the hostmap for those implementing protocol."""
|
||||||
The result is a list in serialized form.'''
|
|
||||||
eligible = []
|
eligible = []
|
||||||
for host, portmap in hostmap.items():
|
for host, portmap in hostmap.items():
|
||||||
port = portmap.get(protocol)
|
port = portmap.get(protocol)
|
||||||
if port:
|
if port:
|
||||||
eligible.append(serialize_server(host, port, protocol))
|
eligible.append(ServerAddr(host, port, protocol=protocol))
|
||||||
return eligible
|
return eligible
|
||||||
|
|
||||||
|
|
||||||
def pick_random_server(hostmap=None, protocol='s', exclude_set=None):
|
def pick_random_server(hostmap=None, *, protocol='s',
|
||||||
|
exclude_set: Set[ServerAddr] = None) -> Optional[ServerAddr]:
|
||||||
if hostmap is None:
|
if hostmap is None:
|
||||||
hostmap = constants.net.DEFAULT_SERVERS
|
hostmap = constants.net.DEFAULT_SERVERS
|
||||||
if exclude_set is None:
|
if exclude_set is None:
|
||||||
|
@ -240,6 +240,14 @@ class Network(Logger):
|
||||||
|
|
||||||
LOGGING_SHORTCUT = 'n'
|
LOGGING_SHORTCUT = 'n'
|
||||||
|
|
||||||
|
taskgroup: Optional[TaskGroup]
|
||||||
|
interface: Optional[Interface]
|
||||||
|
interfaces: Dict[ServerAddr, Interface]
|
||||||
|
connecting: Set[ServerAddr]
|
||||||
|
server_queue: 'Optional[queue.Queue[ServerAddr]]'
|
||||||
|
disconnected_servers: Set[ServerAddr]
|
||||||
|
default_server: ServerAddr
|
||||||
|
|
||||||
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
|
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
|
||||||
global _INSTANCE
|
global _INSTANCE
|
||||||
assert _INSTANCE is None, "Network is a singleton!"
|
assert _INSTANCE is None, "Network is a singleton!"
|
||||||
|
@ -266,14 +274,15 @@ class Network(Logger):
|
||||||
# Sanitize default server
|
# Sanitize default server
|
||||||
if self.default_server:
|
if self.default_server:
|
||||||
try:
|
try:
|
||||||
deserialize_server(self.default_server)
|
self.default_server = ServerAddr.from_str(self.default_server)
|
||||||
except:
|
except:
|
||||||
self.logger.warning('failed to parse server-string; falling back to localhost.')
|
self.logger.warning('failed to parse server-string; falling back to localhost.')
|
||||||
self.default_server = "localhost:50002:s"
|
self.default_server = ServerAddr.from_str("localhost:50002:s")
|
||||||
if not self.default_server:
|
else:
|
||||||
self.default_server = pick_random_server()
|
self.default_server = pick_random_server()
|
||||||
|
assert isinstance(self.default_server, ServerAddr), f"invalid type for default_server: {self.default_server!r}"
|
||||||
|
|
||||||
self.taskgroup = None # type: TaskGroup
|
self.taskgroup = None
|
||||||
|
|
||||||
# locks
|
# locks
|
||||||
self.restart_lock = asyncio.Lock()
|
self.restart_lock = asyncio.Lock()
|
||||||
|
@ -295,10 +304,10 @@ class Network(Logger):
|
||||||
self.server_retry_time = time.time()
|
self.server_retry_time = time.time()
|
||||||
self.nodes_retry_time = time.time()
|
self.nodes_retry_time = time.time()
|
||||||
# the main server we are currently communicating with
|
# the main server we are currently communicating with
|
||||||
self.interface = None # type: Optional[Interface]
|
self.interface = None
|
||||||
self.default_server_changed_event = asyncio.Event()
|
self.default_server_changed_event = asyncio.Event()
|
||||||
# set of servers we have an ongoing connection with
|
# set of servers we have an ongoing connection with
|
||||||
self.interfaces = {} # type: Dict[str, Interface]
|
self.interfaces = {}
|
||||||
self.auto_connect = self.config.get('auto_connect', True)
|
self.auto_connect = self.config.get('auto_connect', True)
|
||||||
self.connecting = set()
|
self.connecting = set()
|
||||||
self.server_queue = None
|
self.server_queue = None
|
||||||
|
@ -347,14 +356,15 @@ class Network(Logger):
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
return func_wrapper
|
return func_wrapper
|
||||||
|
|
||||||
def _read_recent_servers(self):
|
def _read_recent_servers(self) -> List[ServerAddr]:
|
||||||
if not self.config.path:
|
if not self.config.path:
|
||||||
return []
|
return []
|
||||||
path = os.path.join(self.config.path, "recent_servers")
|
path = os.path.join(self.config.path, "recent_servers")
|
||||||
try:
|
try:
|
||||||
with open(path, "r", encoding='utf-8') as f:
|
with open(path, "r", encoding='utf-8') as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
return json.loads(data)
|
servers_list = json.loads(data)
|
||||||
|
return [ServerAddr.from_str(s) for s in servers_list]
|
||||||
except:
|
except:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -363,7 +373,7 @@ class Network(Logger):
|
||||||
if not self.config.path:
|
if not self.config.path:
|
||||||
return
|
return
|
||||||
path = os.path.join(self.config.path, "recent_servers")
|
path = os.path.join(self.config.path, "recent_servers")
|
||||||
s = json.dumps(self.recent_servers, indent=4, sort_keys=True)
|
s = json.dumps(self.recent_servers, indent=4, sort_keys=True, cls=MyEncoder)
|
||||||
try:
|
try:
|
||||||
with open(path, "w", encoding='utf-8') as f:
|
with open(path, "w", encoding='utf-8') as f:
|
||||||
f.write(s)
|
f.write(s)
|
||||||
|
@ -462,10 +472,10 @@ class Network(Logger):
|
||||||
util.trigger_callback(key, self.get_status_value(key))
|
util.trigger_callback(key, self.get_status_value(key))
|
||||||
|
|
||||||
def get_parameters(self) -> NetworkParameters:
|
def get_parameters(self) -> NetworkParameters:
|
||||||
host, port, protocol = deserialize_server(self.default_server)
|
server = self.default_server
|
||||||
return NetworkParameters(host=host,
|
return NetworkParameters(host=server.host,
|
||||||
port=port,
|
port=str(server.port),
|
||||||
protocol=protocol,
|
protocol=server.protocol,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy,
|
||||||
auto_connect=self.auto_connect,
|
auto_connect=self.auto_connect,
|
||||||
oneserver=self.oneserver)
|
oneserver=self.oneserver)
|
||||||
|
@ -474,7 +484,7 @@ class Network(Logger):
|
||||||
if self.is_connected():
|
if self.is_connected():
|
||||||
return self.donation_address
|
return self.donation_address
|
||||||
|
|
||||||
def get_interfaces(self) -> List[str]:
|
def get_interfaces(self) -> List[ServerAddr]:
|
||||||
"""The list of servers for the connected interfaces."""
|
"""The list of servers for the connected interfaces."""
|
||||||
with self.interfaces_lock:
|
with self.interfaces_lock:
|
||||||
return list(self.interfaces)
|
return list(self.interfaces)
|
||||||
|
@ -516,21 +526,18 @@ class Network(Logger):
|
||||||
# hardcoded servers
|
# hardcoded servers
|
||||||
out.update(constants.net.DEFAULT_SERVERS)
|
out.update(constants.net.DEFAULT_SERVERS)
|
||||||
# add recent servers
|
# add recent servers
|
||||||
for s in self.recent_servers:
|
for server in self.recent_servers:
|
||||||
try:
|
port = str(server.port)
|
||||||
host, port, protocol = deserialize_server(s)
|
if server.host in out:
|
||||||
except:
|
out[server.host].update({server.protocol: port})
|
||||||
continue
|
|
||||||
if host in out:
|
|
||||||
out[host].update({protocol: port})
|
|
||||||
else:
|
else:
|
||||||
out[host] = {protocol: port}
|
out[server.host] = {server.protocol: port}
|
||||||
# potentially filter out some
|
# potentially filter out some
|
||||||
if self.config.get('noonion'):
|
if self.config.get('noonion'):
|
||||||
out = filter_noonion(out)
|
out = filter_noonion(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _start_interface(self, server: str):
|
def _start_interface(self, server: ServerAddr):
|
||||||
if server not in self.interfaces and server not in self.connecting:
|
if server not in self.interfaces and server not in self.connecting:
|
||||||
if server == self.default_server:
|
if server == self.default_server:
|
||||||
self.logger.info(f"connecting to {server} as new interface")
|
self.logger.info(f"connecting to {server} as new interface")
|
||||||
|
@ -538,10 +545,10 @@ class Network(Logger):
|
||||||
self.connecting.add(server)
|
self.connecting.add(server)
|
||||||
self.server_queue.put(server)
|
self.server_queue.put(server)
|
||||||
|
|
||||||
def _start_random_interface(self):
|
def _start_random_interface(self) -> Optional[ServerAddr]:
|
||||||
with self.interfaces_lock:
|
with self.interfaces_lock:
|
||||||
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
|
exclude_set = self.disconnected_servers | set(self.interfaces) | self.connecting
|
||||||
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
|
server = pick_random_server(self.get_servers(), protocol=self.protocol, exclude_set=exclude_set)
|
||||||
if server:
|
if server:
|
||||||
self._start_interface(server)
|
self._start_interface(server)
|
||||||
return server
|
return server
|
||||||
|
@ -557,10 +564,9 @@ class Network(Logger):
|
||||||
proxy = net_params.proxy
|
proxy = net_params.proxy
|
||||||
proxy_str = serialize_proxy(proxy)
|
proxy_str = serialize_proxy(proxy)
|
||||||
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
host, port, protocol = net_params.host, net_params.port, net_params.protocol
|
||||||
server_str = serialize_server(host, port, protocol)
|
|
||||||
# sanitize parameters
|
# sanitize parameters
|
||||||
try:
|
try:
|
||||||
deserialize_server(serialize_server(host, port, protocol))
|
server = ServerAddr(host, port, protocol=protocol)
|
||||||
if proxy:
|
if proxy:
|
||||||
proxy_modes.index(proxy['mode']) + 1
|
proxy_modes.index(proxy['mode']) + 1
|
||||||
int(proxy['port'])
|
int(proxy['port'])
|
||||||
|
@ -569,9 +575,9 @@ class Network(Logger):
|
||||||
self.config.set_key('auto_connect', net_params.auto_connect, False)
|
self.config.set_key('auto_connect', net_params.auto_connect, False)
|
||||||
self.config.set_key('oneserver', net_params.oneserver, False)
|
self.config.set_key('oneserver', net_params.oneserver, False)
|
||||||
self.config.set_key('proxy', proxy_str, False)
|
self.config.set_key('proxy', proxy_str, False)
|
||||||
self.config.set_key('server', server_str, True)
|
self.config.set_key('server', str(server), True)
|
||||||
# abort if changes were not allowed by config
|
# abort if changes were not allowed by config
|
||||||
if self.config.get('server') != server_str \
|
if self.config.get('server') != str(server) \
|
||||||
or self.config.get('proxy') != proxy_str \
|
or self.config.get('proxy') != proxy_str \
|
||||||
or self.config.get('oneserver') != net_params.oneserver:
|
or self.config.get('oneserver') != net_params.oneserver:
|
||||||
return
|
return
|
||||||
|
@ -581,10 +587,10 @@ class Network(Logger):
|
||||||
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
|
if self.proxy != proxy or self.protocol != protocol or self.oneserver != net_params.oneserver:
|
||||||
# Restart the network defaulting to the given server
|
# Restart the network defaulting to the given server
|
||||||
await self._stop()
|
await self._stop()
|
||||||
self.default_server = server_str
|
self.default_server = server
|
||||||
await self._start()
|
await self._start()
|
||||||
elif self.default_server != server_str:
|
elif self.default_server != server:
|
||||||
await self.switch_to_interface(server_str)
|
await self.switch_to_interface(server)
|
||||||
else:
|
else:
|
||||||
await self.switch_lagging_interface()
|
await self.switch_lagging_interface()
|
||||||
|
|
||||||
|
@ -646,7 +652,7 @@ class Network(Logger):
|
||||||
# FIXME switch to best available?
|
# FIXME switch to best available?
|
||||||
self.logger.info("tried to switch to best chain but no interfaces are on it")
|
self.logger.info("tried to switch to best chain but no interfaces are on it")
|
||||||
|
|
||||||
async def switch_to_interface(self, server: str):
|
async def switch_to_interface(self, server: ServerAddr):
|
||||||
"""Switch to server as our main interface. If no connection exists,
|
"""Switch to server as our main interface. If no connection exists,
|
||||||
queue interface to be started. The actual switch will
|
queue interface to be started. The actual switch will
|
||||||
happen when the interface becomes ready.
|
happen when the interface becomes ready.
|
||||||
|
@ -722,8 +728,8 @@ class Network(Logger):
|
||||||
|
|
||||||
@ignore_exceptions # do not kill main_taskgroup
|
@ignore_exceptions # do not kill main_taskgroup
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
async def _run_new_interface(self, server):
|
async def _run_new_interface(self, server: ServerAddr):
|
||||||
interface = Interface(self, server, self.proxy)
|
interface = Interface(network=self, server=server, proxy=self.proxy)
|
||||||
# note: using longer timeouts here as DNS can sometimes be slow!
|
# note: using longer timeouts here as DNS can sometimes be slow!
|
||||||
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
|
timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic)
|
||||||
try:
|
try:
|
||||||
|
@ -1070,23 +1076,26 @@ class Network(Logger):
|
||||||
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
with self.interfaces_lock: interfaces = list(self.interfaces.values())
|
||||||
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
|
interfaces_on_selected_chain = list(filter(lambda iface: iface.blockchain == bc, interfaces))
|
||||||
if len(interfaces_on_selected_chain) == 0: return
|
if len(interfaces_on_selected_chain) == 0: return
|
||||||
chosen_iface = random.choice(interfaces_on_selected_chain)
|
chosen_iface = random.choice(interfaces_on_selected_chain) # type: Interface
|
||||||
# switch to server (and save to config)
|
# switch to server (and save to config)
|
||||||
net_params = self.get_parameters()
|
net_params = self.get_parameters()
|
||||||
host, port, protocol = deserialize_server(chosen_iface.server)
|
server = chosen_iface.server
|
||||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
net_params = net_params._replace(host=server.host,
|
||||||
|
port=str(server.port),
|
||||||
|
protocol=server.protocol)
|
||||||
await self.set_parameters(net_params)
|
await self.set_parameters(net_params)
|
||||||
|
|
||||||
async def follow_chain_given_server(self, server_str: str) -> None:
|
async def follow_chain_given_server(self, server: ServerAddr) -> None:
|
||||||
# note that server_str should correspond to a connected interface
|
# note that server_str should correspond to a connected interface
|
||||||
iface = self.interfaces.get(server_str)
|
iface = self.interfaces.get(server)
|
||||||
if iface is None:
|
if iface is None:
|
||||||
return
|
return
|
||||||
self._set_preferred_chain(iface.blockchain)
|
self._set_preferred_chain(iface.blockchain)
|
||||||
# switch to server (and save to config)
|
# switch to server (and save to config)
|
||||||
net_params = self.get_parameters()
|
net_params = self.get_parameters()
|
||||||
host, port, protocol = deserialize_server(server_str)
|
net_params = net_params._replace(host=server.host,
|
||||||
net_params = net_params._replace(host=host, port=port, protocol=protocol)
|
port=str(server.port),
|
||||||
|
protocol=server.protocol)
|
||||||
await self.set_parameters(net_params)
|
await self.set_parameters(net_params)
|
||||||
|
|
||||||
def get_local_height(self):
|
def get_local_height(self):
|
||||||
|
@ -1107,7 +1116,7 @@ class Network(Logger):
|
||||||
assert not self.connecting and not self.server_queue
|
assert not self.connecting and not self.server_queue
|
||||||
self.logger.info('starting network')
|
self.logger.info('starting network')
|
||||||
self.disconnected_servers = set([])
|
self.disconnected_servers = set([])
|
||||||
self.protocol = deserialize_server(self.default_server)[2]
|
self.protocol = self.default_server.protocol
|
||||||
self.server_queue = queue.Queue()
|
self.server_queue = queue.Queue()
|
||||||
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
|
self._set_proxy(deserialize_proxy(self.config.get('proxy')))
|
||||||
self._set_oneserver(self.config.get('oneserver', False))
|
self._set_oneserver(self.config.get('oneserver', False))
|
||||||
|
@ -1147,9 +1156,9 @@ class Network(Logger):
|
||||||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
|
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
|
||||||
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
|
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
|
||||||
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
|
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
|
||||||
self.taskgroup = None # type: TaskGroup
|
self.taskgroup = None
|
||||||
self.interface = None # type: Interface
|
self.interface = None
|
||||||
self.interfaces = {} # type: Dict[str, Interface]
|
self.interfaces = {}
|
||||||
self.connecting.clear()
|
self.connecting.clear()
|
||||||
self.server_queue = None
|
self.server_queue = None
|
||||||
if not full_shutdown:
|
if not full_shutdown:
|
||||||
|
@ -1268,8 +1277,8 @@ class Network(Logger):
|
||||||
|
|
||||||
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
|
async def send_multiple_requests(self, servers: List[str], method: str, params: Sequence):
|
||||||
responses = dict()
|
responses = dict()
|
||||||
async def get_response(server):
|
async def get_response(server: ServerAddr):
|
||||||
interface = Interface(self, server, self.proxy)
|
interface = Interface(network=self, server=server, proxy=self.proxy)
|
||||||
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
|
timeout = self.get_network_timeout_seconds(NetworkTimeout.Urgent)
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(interface.ready, timeout)
|
await asyncio.wait_for(interface.ready, timeout)
|
||||||
|
@ -1283,5 +1292,6 @@ class Network(Logger):
|
||||||
responses[interface.server] = res
|
responses[interface.server] = res
|
||||||
async with TaskGroup() as group:
|
async with TaskGroup() as group:
|
||||||
for server in servers:
|
for server in servers:
|
||||||
|
server = ServerAddr.from_str(server)
|
||||||
await group.spawn(get_response(server))
|
await group.spawn(get_response(server))
|
||||||
return responses
|
return responses
|
||||||
|
|
Loading…
Add table
Reference in a new issue