mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
LNPeerAddr: validate arguments
no longer subclassing NamedTuple (as it is difficult to do validation then...)
This commit is contained in:
parent
edba59ef54
commit
13d6997355
4 changed files with 53 additions and 20 deletions
|
@ -281,13 +281,19 @@ class ChannelDB(SqlDB):
|
||||||
return None
|
return None
|
||||||
addr = sorted(list(r), key=lambda x: x[2])[0]
|
addr = sorted(list(r), key=lambda x: x[2])[0]
|
||||||
host, port, timestamp = addr
|
host, port, timestamp = addr
|
||||||
|
try:
|
||||||
return LNPeerAddr(host, port, node_id)
|
return LNPeerAddr(host, port, node_id)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_recent_peers(self):
|
def get_recent_peers(self):
|
||||||
assert self.data_loaded.is_set(), "channelDB load_data did not finish yet!"
|
assert self.data_loaded.is_set(), "channelDB load_data did not finish yet!"
|
||||||
r = [self.get_last_good_address(x) for x in self._addresses.keys()]
|
# FIXME this does not reliably return "recent" peers...
|
||||||
r = r[-self.NUM_MAX_RECENT_PEERS:]
|
# Also, the list() cast over the whole dict (thousands of elements),
|
||||||
return r
|
# is really inefficient.
|
||||||
|
r = [self.get_last_good_address(node_id)
|
||||||
|
for node_id in list(self._addresses.keys())[-self.NUM_MAX_RECENT_PEERS:]]
|
||||||
|
return list(reversed(r))
|
||||||
|
|
||||||
def add_channel_announcement(self, msg_payloads, trusted=True):
|
def add_channel_announcement(self, msg_payloads, trusted=True):
|
||||||
if type(msg_payloads) is dict:
|
if type(msg_payloads) is dict:
|
||||||
|
|
|
@ -8,11 +8,12 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import StreamReader, StreamWriter
|
from asyncio import StreamReader, StreamWriter
|
||||||
|
|
||||||
from Cryptodome.Cipher import ChaCha20_Poly1305
|
from Cryptodome.Cipher import ChaCha20_Poly1305
|
||||||
|
|
||||||
from .crypto import sha256, hmac_oneshot
|
from .crypto import sha256, hmac_oneshot
|
||||||
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
|
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
|
||||||
HandshakeFailed)
|
HandshakeFailed, LNPeerAddr)
|
||||||
from . import ecc
|
from . import ecc
|
||||||
from .util import bh2u
|
from .util import bh2u
|
||||||
|
|
||||||
|
@ -86,7 +87,13 @@ def create_ephemeral_key() -> (bytes, bytes):
|
||||||
privkey = ecc.ECPrivkey.generate_random_key()
|
privkey = ecc.ECPrivkey.generate_random_key()
|
||||||
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
|
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
|
||||||
|
|
||||||
|
|
||||||
class LNTransportBase:
|
class LNTransportBase:
|
||||||
|
reader: StreamReader
|
||||||
|
writer: StreamWriter
|
||||||
|
|
||||||
|
def name(self) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def send_bytes(self, msg: bytes) -> None:
|
def send_bytes(self, msg: bytes) -> None:
|
||||||
l = len(msg).to_bytes(2, 'big')
|
l = len(msg).to_bytes(2, 'big')
|
||||||
|
@ -207,21 +214,18 @@ class LNResponderTransport(LNTransportBase):
|
||||||
|
|
||||||
class LNTransport(LNTransportBase):
|
class LNTransport(LNTransportBase):
|
||||||
|
|
||||||
def __init__(self, privkey: bytes, peer_addr):
|
def __init__(self, privkey: bytes, peer_addr: LNPeerAddr):
|
||||||
LNTransportBase.__init__(self)
|
LNTransportBase.__init__(self)
|
||||||
assert type(privkey) is bytes and len(privkey) == 32
|
assert type(privkey) is bytes and len(privkey) == 32
|
||||||
self.privkey = privkey
|
self.privkey = privkey
|
||||||
self.remote_pubkey = peer_addr.pubkey
|
|
||||||
self.host = peer_addr.host
|
|
||||||
self.port = peer_addr.port
|
|
||||||
self.peer_addr = peer_addr
|
self.peer_addr = peer_addr
|
||||||
|
|
||||||
def name(self):
|
def name(self):
|
||||||
return str(self.host) + ':' + str(self.port)
|
return self.peer_addr.net_addr_str()
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
self.reader, self.writer = await asyncio.open_connection(self.host, self.port)
|
self.reader, self.writer = await asyncio.open_connection(self.peer_addr.host, self.peer_addr.port)
|
||||||
hs = HandshakeState(self.remote_pubkey)
|
hs = HandshakeState(self.peer_addr.pubkey)
|
||||||
# Get a new ephemeral key
|
# Get a new ephemeral key
|
||||||
epriv, epub = create_ephemeral_key()
|
epriv, epub = create_ephemeral_key()
|
||||||
|
|
||||||
|
@ -230,7 +234,8 @@ class LNTransport(LNTransportBase):
|
||||||
self.writer.write(msg)
|
self.writer.write(msg)
|
||||||
rspns = await self.reader.read(2**10)
|
rspns = await self.reader.read(2**10)
|
||||||
if len(rspns) != 50:
|
if len(rspns) != 50:
|
||||||
raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, are you sure this is the right pubkey? {bh2u(self.remote_pubkey)}")
|
raise HandshakeFailed(f"Lightning handshake act 1 response has bad length, "
|
||||||
|
f"are you sure this is the right pubkey? {self.peer_addr}")
|
||||||
hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
|
hver, alice_epub, tag = rspns[0], rspns[1:34], rspns[34:]
|
||||||
if bytes([hver]) != hs.handshake_version:
|
if bytes([hver]) != hs.handshake_version:
|
||||||
raise HandshakeFailed("unexpected handshake version: {}".format(hver))
|
raise HandshakeFailed("unexpected handshake version: {}".format(hver))
|
||||||
|
|
|
@ -658,14 +658,31 @@ class LnGlobalFeatures(IntFlag):
|
||||||
LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures)
|
LN_GLOBAL_FEATURES_KNOWN_SET = set(LnGlobalFeatures)
|
||||||
|
|
||||||
|
|
||||||
class LNPeerAddr(NamedTuple):
|
class LNPeerAddr:
|
||||||
host: str
|
|
||||||
port: int
|
def __init__(self, host: str, port: int, pubkey: bytes):
|
||||||
pubkey: bytes
|
assert isinstance(host, str), repr(host)
|
||||||
|
assert isinstance(port, int), repr(port)
|
||||||
|
assert isinstance(pubkey, bytes), repr(pubkey)
|
||||||
|
try:
|
||||||
|
net_addr = NetAddress(host, port) # this validates host and port
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"cannot construct LNPeerAddr: invalid host or port (host={host}, port={port})") from e
|
||||||
|
# note: not validating pubkey as it would be too expensive:
|
||||||
|
# if not ECPubkey.is_pubkey_bytes(pubkey): raise ValueError()
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.pubkey = pubkey
|
||||||
|
self._net_addr_str = str(net_addr)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
host_and_port = str(NetAddress(self.host, self.port))
|
return '{}@{}'.format(self.pubkey.hex(), self.net_addr_str())
|
||||||
return '{}@{}'.format(self.pubkey.hex(), host_and_port)
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'<LNPeerAddr host={self.host} port={self.port} pubkey={self.pubkey.hex()}>'
|
||||||
|
|
||||||
|
def net_addr_str(self) -> str:
|
||||||
|
return self._net_addr_str
|
||||||
|
|
||||||
|
|
||||||
def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes:
|
def get_compressed_pubkey_from_bech32(bech32_pubkey: str) -> bytes:
|
||||||
|
|
|
@ -221,7 +221,10 @@ class LNWorker(Logger):
|
||||||
if not addrs:
|
if not addrs:
|
||||||
continue
|
continue
|
||||||
host, port, timestamp = self.choose_preferred_address(list(addrs))
|
host, port, timestamp = self.choose_preferred_address(list(addrs))
|
||||||
|
try:
|
||||||
peer = LNPeerAddr(host, port, node_id)
|
peer = LNPeerAddr(host, port, node_id)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
if peer in self._last_tried_peer:
|
if peer in self._last_tried_peer:
|
||||||
continue
|
continue
|
||||||
#self.logger.info('taking random ln peer from our channel db')
|
#self.logger.info('taking random ln peer from our channel db')
|
||||||
|
@ -1265,6 +1268,8 @@ class LNWallet(LNWorker):
|
||||||
self.network.trigger_callback('channels_updated', self.wallet)
|
self.network.trigger_callback('channels_updated', self.wallet)
|
||||||
self.network.trigger_callback('wallet_updated', self.wallet)
|
self.network.trigger_callback('wallet_updated', self.wallet)
|
||||||
|
|
||||||
|
@ignore_exceptions
|
||||||
|
@log_exceptions
|
||||||
async def reestablish_peer_for_given_channel(self, chan):
|
async def reestablish_peer_for_given_channel(self, chan):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# try last good address first
|
# try last good address first
|
||||||
|
|
Loading…
Add table
Reference in a new issue