mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-15 06:59:50 +00:00
test_lnbase: add test that pays to another local electrum
This commit is contained in:
parent
16a9aa322f
commit
751942442e
3 changed files with 146 additions and 26 deletions
|
@ -350,6 +350,11 @@ class Peer(PrintError):
|
||||||
@log_exceptions
|
@log_exceptions
|
||||||
@handle_disconnect
|
@handle_disconnect
|
||||||
async def main_loop(self):
|
async def main_loop(self):
|
||||||
|
"""
|
||||||
|
This is used in LNWorker and is necessary so that we don't kill the main
|
||||||
|
task group. It is not merged with _main_loop, so that we can test if the
|
||||||
|
correct exceptions are getting thrown using _main_loop.
|
||||||
|
"""
|
||||||
await self._main_loop()
|
await self._main_loop()
|
||||||
|
|
||||||
async def _main_loop(self):
|
async def _main_loop(self):
|
||||||
|
|
|
@ -32,7 +32,6 @@ from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
|
||||||
generate_keypair, LnKeyFamily, LOCAL, REMOTE,
|
generate_keypair, LnKeyFamily, LOCAL, REMOTE,
|
||||||
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
|
UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE,
|
||||||
NUM_MAX_EDGES_IN_PAYMENT_PATH)
|
NUM_MAX_EDGES_IN_PAYMENT_PATH)
|
||||||
from .lnaddr import lndecode
|
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
from .lnrouter import RouteEdge, is_route_sane_to_use
|
from .lnrouter import RouteEdge, is_route_sane_to_use
|
||||||
|
|
||||||
|
@ -258,6 +257,15 @@ class LNWorker(PrintError):
|
||||||
return bh2u(chan.node_id)
|
return bh2u(chan.node_id)
|
||||||
|
|
||||||
def pay(self, invoice, amount_sat=None):
|
def pay(self, invoice, amount_sat=None):
|
||||||
|
"""
|
||||||
|
This is not merged with _pay so that we can run the test with
|
||||||
|
one thread only.
|
||||||
|
"""
|
||||||
|
addr, peer, coro = self._pay(invoice, amount_sat)
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||||
|
return addr, peer, fut
|
||||||
|
|
||||||
|
def _pay(self, invoice, amount_sat=None):
|
||||||
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
addr = lndecode(invoice, expected_hrp=constants.net.SEGWIT_HRP)
|
||||||
payment_hash = addr.paymenthash
|
payment_hash = addr.paymenthash
|
||||||
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
|
amount_sat = (addr.amount * COIN) if addr.amount else amount_sat
|
||||||
|
@ -279,7 +287,7 @@ class LNWorker(PrintError):
|
||||||
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
raise Exception("PathFinder returned path with short_channel_id {} that is not in channel list".format(bh2u(short_channel_id)))
|
||||||
peer = self.peers[node_id]
|
peer = self.peers[node_id]
|
||||||
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
|
coro = peer.pay(route, chan, amount_msat, payment_hash, addr.get_min_final_cltv_expiry())
|
||||||
return addr, peer, asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
return addr, peer, coro
|
||||||
|
|
||||||
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
|
def _create_route_from_invoice(self, decoded_invoice, amount_msat) -> List[RouteEdge]:
|
||||||
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
invoice_pubkey = decoded_invoice.pubkey.serialize()
|
||||||
|
|
|
@ -1,16 +1,40 @@
|
||||||
|
import unittest
|
||||||
|
import asyncio
|
||||||
|
import tempfile
|
||||||
|
from decimal import Decimal
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from electrum.network import Network
|
||||||
|
from electrum.ecc import ECPrivkey
|
||||||
|
from electrum import simple_config, lnutil
|
||||||
|
from electrum.lnaddr import lnencode, LnAddr, lndecode
|
||||||
|
from electrum.bitcoin import COIN, sha256
|
||||||
|
from electrum.util import bh2u
|
||||||
|
|
||||||
from electrum.lnbase import Peer, decode_msg, gen_msg
|
from electrum.lnbase import Peer, decode_msg, gen_msg
|
||||||
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
|
||||||
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
|
||||||
from electrum.ecc import ECPrivkey
|
from electrum.lnrouter import ChannelDB, LNPathFinder
|
||||||
from electrum.lnrouter import ChannelDB
|
from electrum.lnworker import LNWorker
|
||||||
import unittest
|
|
||||||
import asyncio
|
|
||||||
from electrum import simple_config
|
|
||||||
import tempfile
|
|
||||||
from .test_lnchan import create_test_channels
|
from .test_lnchan import create_test_channels
|
||||||
|
|
||||||
|
def keypair():
|
||||||
|
priv = ECPrivkey.generate_random_key().get_secret_bytes()
|
||||||
|
k1 = Keypair(
|
||||||
|
pubkey=privkey_to_pubkey(priv),
|
||||||
|
privkey=priv)
|
||||||
|
return k1
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def noop_lock():
|
||||||
|
yield
|
||||||
|
|
||||||
class MockNetwork:
|
class MockNetwork:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.callbacks = defaultdict(list)
|
||||||
self.lnwatcher = None
|
self.lnwatcher = None
|
||||||
user_config = {}
|
user_config = {}
|
||||||
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
|
user_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")
|
||||||
|
@ -18,49 +42,132 @@ class MockNetwork:
|
||||||
self.asyncio_loop = asyncio.get_event_loop()
|
self.asyncio_loop = asyncio.get_event_loop()
|
||||||
self.channel_db = ChannelDB(self)
|
self.channel_db = ChannelDB(self)
|
||||||
self.interface = None
|
self.interface = None
|
||||||
def register_callback(self, cb, trigger_names):
|
self.path_finder = LNPathFinder(self.channel_db)
|
||||||
print("callback registered", repr(trigger_names))
|
|
||||||
def trigger_callback(self, trigger_name, obj):
|
@property
|
||||||
print("callback triggered", repr(trigger_name))
|
def callback_lock(self):
|
||||||
|
return noop_lock()
|
||||||
|
|
||||||
|
register_callback = Network.register_callback
|
||||||
|
unregister_callback = Network.unregister_callback
|
||||||
|
trigger_callback = Network.trigger_callback
|
||||||
|
|
||||||
|
def get_local_height(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
class MockLNWorker:
|
class MockLNWorker:
|
||||||
def __init__(self, remote_peer_pubkey, chan):
|
def __init__(self, remote_keypair, local_keypair, chan):
|
||||||
self.chan = chan
|
self.chan = chan
|
||||||
self.remote_peer_pubkey = remote_peer_pubkey
|
self.remote_keypair = remote_keypair
|
||||||
priv = ECPrivkey.generate_random_key().get_secret_bytes()
|
self.node_keypair = local_keypair
|
||||||
self.node_keypair = Keypair(
|
|
||||||
pubkey=privkey_to_pubkey(priv),
|
|
||||||
privkey=priv)
|
|
||||||
self.network = MockNetwork()
|
self.network = MockNetwork()
|
||||||
|
self.channels = {self.chan.channel_id: self.chan}
|
||||||
|
self.invoices = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lock(self):
|
||||||
|
return noop_lock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def peers(self):
|
def peers(self):
|
||||||
return {self.remote_peer_pubkey: self.peer}
|
return {self.remote_keypair.pubkey: self.peer}
|
||||||
|
|
||||||
def channels_for_peer(self, pubkey):
|
def channels_for_peer(self, pubkey):
|
||||||
return {self.chan.channel_id: self.chan}
|
return self.channels
|
||||||
|
|
||||||
|
def save_channel(self, chan):
|
||||||
|
pass
|
||||||
|
|
||||||
|
get_invoice = LNWorker.get_invoice
|
||||||
|
_create_route_from_invoice = LNWorker._create_route_from_invoice
|
||||||
|
|
||||||
class MockTransport:
|
class MockTransport:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
|
|
||||||
async def read_messages(self):
|
async def read_messages(self):
|
||||||
while True:
|
while True:
|
||||||
yield await self.queue.get()
|
yield await self.queue.get()
|
||||||
|
|
||||||
class BadFeaturesTransport(MockTransport):
|
class NoFeaturesTransport(MockTransport):
|
||||||
|
"""
|
||||||
|
This answers the init message with a init that doesn't signal any features.
|
||||||
|
Used for testing that we require DATA_LOSS_PROTECT.
|
||||||
|
"""
|
||||||
def send_bytes(self, data):
|
def send_bytes(self, data):
|
||||||
decoded = decode_msg(data)
|
decoded = decode_msg(data)
|
||||||
print(decoded)
|
print(decoded)
|
||||||
if decoded[0] == 'init':
|
if decoded[0] == 'init':
|
||||||
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
|
self.queue.put_nowait(gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00"))
|
||||||
|
|
||||||
|
class PutIntoOthersQueueTransport(MockTransport):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.other_mock_transport = None
|
||||||
|
|
||||||
|
def send_bytes(self, data):
|
||||||
|
self.other_mock_transport.queue.put_nowait(data)
|
||||||
|
|
||||||
|
def transport_pair():
|
||||||
|
t1 = PutIntoOthersQueueTransport()
|
||||||
|
t2 = PutIntoOthersQueueTransport()
|
||||||
|
t1.other_mock_transport = t2
|
||||||
|
t2.other_mock_transport = t1
|
||||||
|
return t1, t2
|
||||||
|
|
||||||
class TestPeer(unittest.TestCase):
|
class TestPeer(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.alice_channel, self.bob_channel = create_test_channels()
|
self.alice_channel, self.bob_channel = create_test_channels()
|
||||||
def test_bad_feature_flags(self):
|
|
||||||
# we should require DATA_LOSS_PROTECT
|
def test_require_data_loss_protect(self):
|
||||||
mock_lnworker = MockLNWorker(b"\x00" * 32, self.alice_channel)
|
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel)
|
||||||
mock_transport = BadFeaturesTransport()
|
mock_transport = NoFeaturesTransport()
|
||||||
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 32), request_initial_sync=False, transport=mock_transport)
|
p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33), request_initial_sync=False, transport=mock_transport)
|
||||||
mock_lnworker.peer = p1
|
mock_lnworker.peer = p1
|
||||||
with self.assertRaises(LightningPeerConnectionClosed):
|
with self.assertRaises(LightningPeerConnectionClosed):
|
||||||
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
|
asyncio.get_event_loop().run_until_complete(asyncio.wait_for(p1._main_loop(), 1))
|
||||||
|
|
||||||
|
def test_payment(self):
|
||||||
|
k1, k2 = keypair(), keypair()
|
||||||
|
t1, t2 = transport_pair()
|
||||||
|
w1 = MockLNWorker(k1, k2, self.alice_channel)
|
||||||
|
w2 = MockLNWorker(k2, k1, self.bob_channel)
|
||||||
|
p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
|
||||||
|
request_initial_sync=False, transport=t1)
|
||||||
|
p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
|
||||||
|
request_initial_sync=False, transport=t2)
|
||||||
|
w1.peer = p1
|
||||||
|
w2.peer = p2
|
||||||
|
# mark_open won't work if state is already OPEN.
|
||||||
|
# so set it to OPENING
|
||||||
|
self.alice_channel.set_state("OPENING")
|
||||||
|
self.bob_channel.set_state("OPENING")
|
||||||
|
# this populates the channel graph:
|
||||||
|
p1.mark_open(self.alice_channel)
|
||||||
|
p2.mark_open(self.bob_channel)
|
||||||
|
amount_btc = 100000/Decimal(COIN)
|
||||||
|
payment_preimage = os.urandom(32)
|
||||||
|
RHASH = sha256(payment_preimage)
|
||||||
|
addr = LnAddr(
|
||||||
|
RHASH,
|
||||||
|
amount_btc,
|
||||||
|
tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
|
||||||
|
('d', 'coffee')
|
||||||
|
])
|
||||||
|
pay_req = lnencode(addr, w2.node_keypair.privkey)
|
||||||
|
w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
|
||||||
|
l = asyncio.get_event_loop()
|
||||||
|
async def pay():
|
||||||
|
fut = asyncio.Future()
|
||||||
|
def evt_set(event, _lnworker, msg):
|
||||||
|
fut.set_result(msg)
|
||||||
|
w2.network.register_callback(evt_set, ['ln_message'])
|
||||||
|
|
||||||
|
addr, peer, coro = LNWorker._pay(w1, pay_req)
|
||||||
|
await coro
|
||||||
|
print("HTLC ADDED")
|
||||||
|
self.assertEqual(await fut, 'Payment received')
|
||||||
|
gath.cancel()
|
||||||
|
gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
|
||||||
|
with self.assertRaises(asyncio.CancelledError):
|
||||||
|
l.run_until_complete(gath)
|
||||||
|
|
Loading…
Add table
Reference in a new issue