mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-31 17:31:36 +00:00
lnbase: mark initialized later, add tests, etc
- consistent node_id sorting - require OPTION_DATA_LOSS_PROTECT and test it
This commit is contained in:
parent
3866df6650
commit
3efd8dedb6
2 changed files with 96 additions and 20 deletions
|
@ -201,6 +201,7 @@ class Peer(PrintError):
|
|||
self.peer_addr = peer_addr
|
||||
self.lnworker = lnworker
|
||||
self.privkey = lnworker.node_keypair.privkey
|
||||
self.node_ids = [peer_addr.pubkey, privkey_to_pubkey(self.privkey)]
|
||||
self.network = lnworker.network
|
||||
self.lnwatcher = lnworker.network.lnwatcher
|
||||
self.channel_db = lnworker.network.channel_db
|
||||
|
@ -218,7 +219,7 @@ class Peer(PrintError):
|
|||
self.localfeatures = LnLocalFeatures(0)
|
||||
if request_initial_sync:
|
||||
self.localfeatures |= LnLocalFeatures.INITIAL_ROUTING_SYNC
|
||||
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT
|
||||
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_REQ
|
||||
self.attempted_route = {}
|
||||
self.orphan_channel_updates = OrderedDict()
|
||||
|
||||
|
@ -234,7 +235,6 @@ class Peer(PrintError):
|
|||
await transport.handshake()
|
||||
self.transport = transport
|
||||
self.send_message("init", gflen=0, lflen=1, localfeatures=self.localfeatures)
|
||||
self.initialized.set_result(True)
|
||||
|
||||
@property
|
||||
def channels(self) -> Dict[bytes, Channel]:
|
||||
|
@ -310,6 +310,7 @@ class Peer(PrintError):
|
|||
raise LightningPeerConnectionClosed("remote does not have even flag {}"
|
||||
.format(str(LnLocalFeatures(1 << flag))))
|
||||
self.localfeatures ^= 1 << flag # disable flag
|
||||
self.initialized.set_result(True)
|
||||
|
||||
def on_channel_update(self, payload):
|
||||
try:
|
||||
|
@ -349,6 +350,13 @@ class Peer(PrintError):
|
|||
@log_exceptions
|
||||
@handle_disconnect
|
||||
async def main_loop(self):
|
||||
"""
|
||||
This is used from the GUI. It is not merged with the other function,
|
||||
so that we can test if the correct exceptions are getting thrown.
|
||||
"""
|
||||
await self._main_loop()
|
||||
|
||||
async def _main_loop(self):
|
||||
try:
|
||||
await asyncio.wait_for(self.initialize(), 10)
|
||||
except (OSError, asyncio.TimeoutError, HandshakeFailed) as e:
|
||||
|
@ -757,16 +765,17 @@ class Peer(PrintError):
|
|||
if not ecc.verify_signature(self.peer_addr.pubkey, remote_node_sig, h):
|
||||
raise Exception("node_sig invalid in announcement_signatures")
|
||||
|
||||
node_sigs = [local_node_sig, remote_node_sig]
|
||||
bitcoin_sigs = [local_bitcoin_sig, remote_bitcoin_sig]
|
||||
node_ids = [privkey_to_pubkey(self.privkey), self.peer_addr.pubkey]
|
||||
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
|
||||
node_sigs = [remote_node_sig, local_node_sig]
|
||||
bitcoin_sigs = [remote_bitcoin_sig, local_bitcoin_sig]
|
||||
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey, chan.config[LOCAL].multisig_key.pubkey]
|
||||
|
||||
if node_ids[0] > node_ids[1]:
|
||||
if self.node_ids[0] > self.node_ids[1]:
|
||||
node_sigs.reverse()
|
||||
bitcoin_sigs.reverse()
|
||||
node_ids.reverse()
|
||||
node_ids = list(reversed(self.node_ids))
|
||||
bitcoin_keys.reverse()
|
||||
else:
|
||||
node_ids = self.node_ids
|
||||
|
||||
self.send_message("channel_announcement",
|
||||
node_signatures_1=node_sigs[0],
|
||||
|
@ -793,14 +802,13 @@ class Peer(PrintError):
|
|||
chan.set_state("OPEN")
|
||||
self.network.trigger_callback('channel', chan)
|
||||
# add channel to database
|
||||
pubkey_ours = self.lnworker.node_keypair.pubkey
|
||||
pubkey_theirs = self.peer_addr.pubkey
|
||||
node_ids = [pubkey_theirs, pubkey_ours]
|
||||
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
|
||||
sorted_node_ids = list(sorted(node_ids))
|
||||
if sorted_node_ids != node_ids:
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
if sorted_node_ids != self.node_ids:
|
||||
node_ids = sorted_node_ids
|
||||
bitcoin_keys.reverse()
|
||||
else:
|
||||
node_ids = self.node_ids
|
||||
# note: we inject a channel announcement, and a channel update (for outgoing direction)
|
||||
# This is atm needed for
|
||||
# - finding routes
|
||||
|
@ -813,7 +821,10 @@ class Peer(PrintError):
|
|||
'bitcoin_key_1': bitcoin_keys[0], 'bitcoin_key_2': bitcoin_keys[1]},
|
||||
trusted=True)
|
||||
# only inject outgoing direction:
|
||||
channel_flags = b'\x00' if node_ids[0] == pubkey_ours else b'\x01'
|
||||
if node_ids[0] == privkey_to_pubkey(self.privkey):
|
||||
channel_flags = b'\x00'
|
||||
else:
|
||||
channel_flags = b'\x01'
|
||||
now = int(time.time()).to_bytes(4, byteorder="big")
|
||||
self.channel_db.on_channel_update({"short_channel_id": chan.short_channel_id, 'channel_flags': channel_flags, 'cltv_expiry_delta': b'\x90',
|
||||
'htlc_minimum_msat': b'\x03\xe8', 'fee_base_msat': b'\x03\xe8', 'fee_proportional_millionths': b'\x01',
|
||||
|
@ -832,16 +843,15 @@ class Peer(PrintError):
|
|||
|
||||
def send_announcement_signatures(self, chan):
|
||||
|
||||
bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey,
|
||||
chan.config[REMOTE].multisig_key.pubkey]
|
||||
bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
|
||||
chan.config[LOCAL].multisig_key.pubkey]
|
||||
|
||||
node_ids = [privkey_to_pubkey(self.privkey),
|
||||
self.peer_addr.pubkey]
|
||||
|
||||
sorted_node_ids = list(sorted(node_ids))
|
||||
sorted_node_ids = list(sorted(self.node_ids))
|
||||
if sorted_node_ids != node_ids:
|
||||
node_ids = sorted_node_ids
|
||||
bitcoin_keys.reverse()
|
||||
else:
|
||||
node_ids = self.node_ids
|
||||
|
||||
chan_ann = gen_msg("channel_announcement",
|
||||
len=0,
|
||||
|
|
66
electrum/tests/test_lnbase.py
Normal file
66
electrum/tests/test_lnbase.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
from electrum.lnbase import Peer, decode_msg, gen_msg
|
||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||
from electrum.ecc import ECPrivkey
|
||||
from electrum.lnrouter import ChannelDB
|
||||
import unittest
|
||||
import asyncio
|
||||
from electrum import simple_config
|
||||
import tempfile
|
||||
from .test_lnchan import create_test_channels
|
||||
|
||||
class MockNetwork:
|
||||
def __init__(self):
|
||||
self.lnwatcher = None
|
||||
user_config = {}
|
||||
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
|
||||
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
|
||||
self.asyncio_loop = asyncio.get_event_loop()
|
||||
self.channel_db = ChannelDB(self)
|
||||
self.interface = None
|
||||
def register_callback(self, cb, trigger_names):
|
||||
print("callback registered", repr(trigger_names))
|
||||
def trigger_callback(self, trigger_name, obj):
|
||||
print("callback triggered", repr(trigger_name))
|
||||
|
||||
class MockLNWorker:
|
||||
def __init__(self, remote_peer_pubkey, chan):
|
||||
self.chan = chan
|
||||
self.remote_peer_pubkey = remote_peer_pubkey
|
||||
priv = ECPrivkey.generate_random_key().get_secret_bytes()
|
||||
self.node_keypair = Keypair(
|
||||
pubkey=privkey_to_pubkey(priv),
|
||||
privkey=priv)
|
||||
self.network = MockNetwork()
|
||||
@property
|
||||
def peers(self):
|
||||
return {self.remote_peer_pubkey: self.peer}
|
||||
def channels_for_peer(self, pubkey):
|
||||
return {self.chan.channel_id: self.chan}
|
||||
|
||||
class MockTransport:
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
async def read_messages(self):
|
||||
while True:
|
||||
yield await self.queue.get()
|
||||
|
||||
class BadFeaturesTransport(MockTransport):
|
||||
def send_bytes(self, data):
|
||||
decoded = decode_msg(data)
|
||||
print(decoded)
|
||||
if decoded[0] == 'init':
|
||||
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
|
||||
|
||||
class TestPeer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.alice_channel, self.bob_channel = create_test_channels()
|
||||
def test_bad_feature_flags(self):
|
||||
# we should require DATA_LOSS_PROTECT
|
||||
mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
|
||||
mock_transport = BadFeaturesTransport()
|
||||
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
|
||||
mock_lnworker.peer = p1
|
||||
with self.assertRaises(LightningPeerConnectionClosed):
|
||||
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
|
||||
|
Loading…
Add table
Reference in a new issue