LBRY-Vault/electrum/tests/test_lnbase.py
2018-11-12 18:15:48 +01:00

240 lines
7.9 KiB
Python

import unittest
import asyncio
import tempfile
from decimal import Decimal
import os
from contextlib import contextmanager
from collections import defaultdict
from electrum.network import Network
from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u
from electrum.lnbase import Peer, decode_msg, gen_msg
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
from electrum.lnutil import PaymentFailure
from electrum.lnrouter import ChannelDB, LNPathFinder
from electrum.lnworker import LNWorker
from .test_lnchan import create_test_channels
def keypair():
    """Return a ``Keypair`` built from a freshly generated random private key."""
    secret = ECPrivkey.generate_random_key().get_secret_bytes()
    return Keypair(pubkey=privkey_to_pubkey(secret), privkey=secret)
@contextmanager
def noop_lock():
    """Context manager that does nothing on entry or exit.

    Used wherever the mocks must expose something usable in a ``with``
    statement in place of a lock. The ``as`` target is ``None``.
    """
    yield None
class MockNetwork:
    """In-memory stand-in for ``Network`` used by the mock LN worker.

    It carries a callback registry, a ``SimpleConfig`` rooted in a fresh
    temporary directory, a ``ChannelDB`` with its ``LNPathFinder``, and an
    optional queue that receives every broadcast transaction.
    """

    # Callback plumbing is borrowed unchanged from the real Network class;
    # it operates on ``self.callbacks`` and ``self.callback_lock``.
    register_callback = Network.register_callback
    unregister_callback = Network.unregister_callback
    trigger_callback = Network.trigger_callback

    def __init__(self, tx_queue):
        """Set up the mock.

        :param tx_queue: queue that collects broadcast transactions, or
            ``None`` to drop them.
        """
        self.callbacks = defaultdict(list)
        self.lnwatcher = None

        scratch_dir = tempfile.mkdtemp(prefix="electrum-lnbase-test-")

        def _user_dir():
            # Point the config at the per-instance temporary directory.
            return scratch_dir

        self.config = simple_config.SimpleConfig({}, read_user_dir_function=_user_dir)
        self.asyncio_loop = asyncio.get_event_loop()
        # ChannelDB receives this mock as its network, so it is created
        # only after the attributes above are in place.
        self.channel_db = ChannelDB(self)
        self.interface = None
        self.path_finder = LNPathFinder(self.channel_db)
        self.tx_queue = tx_queue

    @property
    def callback_lock(self):
        """Return a fresh no-op context manager in place of a real lock."""
        return noop_lock()

    def get_local_height(self):
        """Report a fixed blockchain height of 0."""
        return 0

    async def broadcast_transaction(self, tx):
        """Put *tx* on ``tx_queue`` when one was supplied; otherwise do nothing."""
        if not self.tx_queue:
            return
        await self.tx_queue.put(tx)
class MockStorage:
    """Storage stub whose operations are all no-ops; nothing is persisted."""

    def put(self, key, value):
        """Accept and discard *value* for *key*."""
        return None

    def get(self, key, default=None):
        """Always return ``None``; note that *default* is ignored."""
        return None

    def write(self):
        """Pretend to flush to disk; does nothing."""
        return None
class MockWallet:
    """Wallet stub that exposes only a ``storage`` attribute."""

    # Class-level attribute: a single MockStorage shared by every instance.
    storage = MockStorage()
class MockLNWorker:
    """Single-channel, single-peer stand-in for ``LNWorker``.

    The ``peer`` attribute is not set here; callers assign it after
    constructing the ``Peer`` that wraps this worker.
    """

    # Real LNWorker logic reused as-is on top of the mock's state.
    get_invoice = LNWorker.get_invoice
    _create_route_from_invoice = LNWorker._create_route_from_invoice
    _check_invoice = staticmethod(LNWorker._check_invoice)
    _pay_to_route = LNWorker._pay_to_route
    force_close_channel = LNWorker.force_close_channel

    def __init__(self, remote_keypair, local_keypair, chan, tx_queue):
        """Set up the worker.

        :param remote_keypair: keypair whose pubkey keys the ``peers`` mapping.
        :param local_keypair: stored as this worker's ``node_keypair``.
        :param chan: the only channel this worker knows about.
        :param tx_queue: handed to ``MockNetwork`` to collect broadcast
            transactions (may be ``None``).
        """
        self.chan = chan
        self.remote_keypair = remote_keypair
        self.node_keypair = local_keypair
        self.network = MockNetwork(tx_queue)
        self.channels = {chan.channel_id: chan}
        self.invoices = {}
        self.paying = {}
        self.wallet = MockWallet()

    @property
    def lock(self):
        """Return a fresh no-op context manager in place of a real lock."""
        return noop_lock()

    @property
    def peers(self):
        """Map the remote pubkey to the externally assigned ``peer``."""
        return {self.remote_keypair.pubkey: self.peer}

    def channels_for_peer(self, pubkey):
        """Return all channels regardless of *pubkey*."""
        return self.channels

    def save_channel(self, chan):
        """Skip persistence; only report that the save was ignored."""
        print("Ignoring channel save")

    def on_channels_updated(self):
        """No-op notification hook."""
        return None
class MockTransport:
    """Base transport for tests: incoming messages are fed through a queue."""

    def __init__(self):
        # Messages placed on this queue are what ``read_messages`` yields.
        self.queue = asyncio.Queue()

    async def read_messages(self):
        """Yield queued messages in order, forever, waiting when the queue is empty."""
        while True:
            message = await self.queue.get()
            yield message
class NoFeaturesTransport(MockTransport):
    """
    Replies to an outgoing init message with an init that signals no features.
    Used for testing that we require DATA_LOSS_PROTECT.
    """

    def send_bytes(self, data):
        """Decode *data*, print it, and queue a feature-less ``init`` reply to an ``init``."""
        decoded = decode_msg(data)
        print(decoded)
        if decoded[0] != 'init':
            # Every other message type is swallowed without a reply.
            return
        reply = gen_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")
        self.queue.put_nowait(reply)
class PutIntoOthersQueueTransport(MockTransport):
    """Transport whose outgoing bytes land on another transport's queue."""

    def __init__(self):
        super().__init__()
        # Must be assigned to the counterpart transport before sending.
        self.other_mock_transport = None

    def send_bytes(self, data):
        """Deliver *data* directly to the counterpart's incoming queue."""
        peer_inbox = self.other_mock_transport.queue
        peer_inbox.put_nowait(data)
def transport_pair():
    """Return two transports wired back to back: each one sends into the other's queue."""
    first, second = PutIntoOthersQueueTransport(), PutIntoOthersQueueTransport()
    first.other_mock_transport, second.other_mock_transport = second, first
    return first, second
class TestPeer(unittest.TestCase):
    """Tests that drive ``Peer`` through mock workers, networks and transports."""

    def setUp(self):
        """Create a fresh pair of test channels for every test."""
        self.alice_channel, self.bob_channel = create_test_channels()

    def test_require_data_loss_protect(self):
        """The main loop must close the connection when the remote init signals no features."""
        mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
        mock_transport = NoFeaturesTransport()
        p1 = Peer(mock_lnworker, LNPeerAddr("bogus", 1337, b"\x00" * 33),
                  request_initial_sync=False, transport=mock_transport)
        mock_lnworker.peer = p1
        with self.assertRaises(LightningPeerConnectionClosed):
            # wait_for caps the main loop at one second.
            run(asyncio.wait_for(p1._main_loop(), 1))

    def prepare_peers(self):
        """Build two peers connected through paired transports.

        :return: ``(p1, p2, w1, w2, q1, q2)`` -- both peers, their mock
            workers, and the queues that collect each side's broadcast
            transactions.
        """
        k1, k2 = keypair(), keypair()
        t1, t2 = transport_pair()
        q1, q2 = asyncio.Queue(), asyncio.Queue()
        w1 = MockLNWorker(k1, k2, self.alice_channel, tx_queue=q1)
        w2 = MockLNWorker(k2, k1, self.bob_channel, tx_queue=q2)
        p1 = Peer(w1, LNPeerAddr("bogus1", 1337, k1.pubkey),
                  request_initial_sync=False, transport=t1)
        p2 = Peer(w2, LNPeerAddr("bogus2", 1337, k2.pubkey),
                  request_initial_sync=False, transport=t2)
        w1.peer = p1
        w2.peer = p2
        # mark_open won't work if state is already OPEN.
        # so set it to OPENING
        self.alice_channel.set_state("OPENING")
        self.bob_channel.set_state("OPENING")
        # this populates the channel graph:
        p1.mark_open(self.alice_channel)
        p2.mark_open(self.bob_channel)
        return p1, p2, w1, w2, q1, q2

    @staticmethod
    def prepare_invoice(w2):
        """Create an invoice signed by the receiver and register it there.

        :param w2: the receiving worker; its node private key signs the
            invoice and its ``invoices`` dict gets the preimage entry.
        :return: the encoded payment request.
        """
        amount_btc = 100000 / Decimal(COIN)
        payment_preimage = os.urandom(32)
        RHASH = sha256(payment_preimage)
        addr = LnAddr(
            RHASH,
            amount_btc,
            tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE),
                  ('d', 'coffee')
                  ])
        pay_req = lnencode(addr, w2.node_keypair.privkey)
        w2.invoices[bh2u(RHASH)] = (bh2u(payment_preimage), pay_req)
        return pay_req

    @staticmethod
    def prepare_ln_message_future(w2):
        """Return a future resolved by the receiver's next ``ln_message`` callback.

        :param w2: the receiving worker whose network emits the event.
        """
        fut = asyncio.Future()

        def evt_set(event, _lnworker, msg):
            fut.set_result(msg)

        w2.network.register_callback(evt_set, ['ln_message'])
        return fut

    def test_payment(self):
        """Pay an invoice across the two peers and expect 'Payment received'."""
        p1, p2, w1, w2, _q1, _q2 = self.prepare_peers()
        pay_req = self.prepare_invoice(w2)
        fut = self.prepare_ln_message_future(w2)

        async def pay():
            # Only the coroutine is needed; the other two values are unused.
            _addr, _peer, coro = LNWorker._pay(w1, pay_req)
            await coro
            print("HTLC ADDED")
            self.assertEqual(await fut, 'Payment received')
            # The main loops never return on their own, so cancel the
            # gather (bound below, before this coroutine runs) to finish.
            gath.cancel()

        gath = asyncio.gather(pay(), p1._main_loop(), p2._main_loop())
        with self.assertRaises(asyncio.CancelledError):
            run(gath)

    def test_channel_usage_after_closing(self):
        """After a force close, routing fails and a stale route cannot be paid."""
        p1, p2, w1, w2, q1, q2 = self.prepare_peers()
        pay_req = self.prepare_invoice(w2)
        addr = w1._check_invoice(pay_req)
        route = w1._create_route_from_invoice(decoded_invoice=addr)
        run(w1.force_close_channel(self.alice_channel.channel_id))
        # check if a tx (commitment transaction) was broadcasted:
        # (assertEqual rather than a bare assert, so the check survives
        # ``python -O`` and reports the actual queue size on failure)
        self.assertEqual(q1.qsize(), 1)
        with self.assertRaises(PaymentFailure) as e:
            w1._create_route_from_invoice(decoded_invoice=addr)
        # Checked after the block: inside it, this line would be unreachable.
        self.assertEqual(str(e.exception), 'No path found')
        # Result is unused; the lookup raises KeyError if the first hop of
        # the old route is not a known peer.
        peer = w1.peers[route[0].node_id]
        # AssertionError is ok since we shouldn't use old routes, and the
        # route finding should fail when channel is closed
        with self.assertRaises(AssertionError):
            run(asyncio.gather(w1._pay_to_route(route, addr), p1._main_loop(), p2._main_loop()))
def run(coro):
    """Run *coro* to completion on the current asyncio event loop.

    :param coro: a coroutine or future to drive.
    :return: the result of *coro*. Any exception it raises, including
        ``asyncio.CancelledError``, propagates to the caller.
    """
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(coro)