mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
move connection string decoding to lnworker, fix test_lnutil
This commit is contained in:
parent
086acc43a2
commit
4b2fa6a828
5 changed files with 133 additions and 71 deletions
|
@ -746,9 +746,9 @@ class Commands:
|
|||
|
||||
# lightning network commands
|
||||
@command('wpn')
|
||||
def open_channel(self, node_id, amount, channel_push=0, password=None):
|
||||
f = self.wallet.lnworker.open_channel(bytes.fromhex(node_id), satoshis(amount), satoshis(channel_push), password)
|
||||
return f.result()
|
||||
def open_channel(self, connection_string, amount, channel_push=0, password=None):
|
||||
f = self.wallet.lnworker.open_channel(connection_string, satoshis(amount), satoshis(channel_push), password)
|
||||
return f.result(5)
|
||||
|
||||
@command('wn')
|
||||
def reestablish_channel(self):
|
||||
|
|
|
@ -5,8 +5,7 @@ from PyQt5.QtWidgets import *
|
|||
from electrum.util import inv_dict, bh2u, bfh
|
||||
from electrum.i18n import _
|
||||
from electrum.lnhtlc import HTLCStateMachine
|
||||
from electrum.lnaddr import lndecode
|
||||
from electrum.lnutil import LOCAL, REMOTE
|
||||
from electrum.lnutil import LOCAL, REMOTE, ConnStringFormatError
|
||||
|
||||
from .util import MyTreeWidget, SortableTreeWidgetItem, WindowModalDialog, Buttons, OkButton, CancelButton
|
||||
from .amountedit import BTCAmountEdit
|
||||
|
@ -108,55 +107,12 @@ class ChannelsList(MyTreeWidget):
|
|||
return
|
||||
local_amt = local_amt_inp.get_amount()
|
||||
push_amt = push_amt_inp.get_amount()
|
||||
connect_contents = str(remote_nodeid.text())
|
||||
nodeid_hex, rest = self.parse_connect_contents(connect_contents)
|
||||
try:
|
||||
node_id = bfh(nodeid_hex)
|
||||
assert len(node_id) == 33
|
||||
except:
|
||||
self.parent.show_error(_('Invalid node ID, must be 33 bytes and hexadecimal'))
|
||||
return
|
||||
connect_contents = str(remote_nodeid.text()).strip()
|
||||
|
||||
peer = lnworker.peers.get(node_id)
|
||||
if not peer:
|
||||
all_nodes = self.parent.network.channel_db.nodes
|
||||
node_info = all_nodes.get(node_id, None)
|
||||
if rest is not None:
|
||||
try:
|
||||
host, port = rest.split(":")
|
||||
except ValueError:
|
||||
self.parent.show_error(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
|
||||
return
|
||||
elif node_info:
|
||||
host, port = node_info.addresses[0]
|
||||
else:
|
||||
self.parent.show_error(_('Unknown node:') + ' ' + nodeid_hex)
|
||||
return
|
||||
try:
|
||||
int(port)
|
||||
except:
|
||||
self.parent.show_error(_('Port number must be decimal'))
|
||||
return
|
||||
lnworker.add_peer(host, port, node_id)
|
||||
|
||||
self.main_window.protect(self.open_channel, (node_id, local_amt, push_amt))
|
||||
|
||||
@classmethod
|
||||
def parse_connect_contents(cls, connect_contents: str):
|
||||
rest = None
|
||||
try:
|
||||
# connection string?
|
||||
nodeid_hex, rest = connect_contents.split("@")
|
||||
except ValueError:
|
||||
try:
|
||||
# invoice?
|
||||
invoice = lndecode(connect_contents)
|
||||
nodeid_bytes = invoice.pubkey.serialize()
|
||||
nodeid_hex = bh2u(nodeid_bytes)
|
||||
except:
|
||||
# node id as hex?
|
||||
nodeid_hex = connect_contents
|
||||
return nodeid_hex, rest
|
||||
self.main_window.protect(self.open_channel, (connect_contents, local_amt, push_amt))
|
||||
except ConnStringFormatError as e:
|
||||
self.parent.show_error(str(e))
|
||||
|
||||
def open_channel(self, *args, **kwargs):
|
||||
self.parent.wallet.lnworker.open_channel(*args, **kwargs)
|
||||
|
|
|
@ -2,6 +2,7 @@ from enum import IntFlag
|
|||
import json
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple, List, Tuple
|
||||
import re
|
||||
|
||||
from .util import bfh, bh2u, inv_dict
|
||||
from .crypto import sha256
|
||||
|
@ -11,6 +12,7 @@ from . import ecc, bitcoin, crypto, transaction
|
|||
from .transaction import opcodes, TxOutput
|
||||
from .bitcoin import push_script
|
||||
from . import segwit_addr
|
||||
from .i18n import _
|
||||
|
||||
HTLC_TIMEOUT_WEIGHT = 663
|
||||
HTLC_SUCCESS_WEIGHT = 703
|
||||
|
@ -478,3 +480,48 @@ def make_closing_tx(local_funding_pubkey: bytes, remote_funding_pubkey: bytes,
|
|||
c_input['sequence'] = 0xFFFF_FFFF
|
||||
tx = Transaction.from_io([c_input], outputs, locktime=0, version=2)
|
||||
return tx
|
||||
|
||||
class ConnStringFormatError(Exception):
|
||||
pass
|
||||
|
||||
def split_host_port(host_port: str) -> Tuple[str, str]: # port returned as string
|
||||
ipv6 = re.compile(r'\[(?P<host>[:0-9]+)\](?P<port>:\d+)?$')
|
||||
other = re.compile(r'(?P<host>[^:]+)(?P<port>:\d+)?$')
|
||||
m = ipv6.match(host_port)
|
||||
if not m:
|
||||
m = other.match(host_port)
|
||||
if not m:
|
||||
raise ConnStringFormatError(_('Connection strings must be in <node_pubkey>@<host>:<port> format'))
|
||||
host = m.group('host')
|
||||
if m.group('port'):
|
||||
port = m.group('port')[1:]
|
||||
else:
|
||||
port = '9735'
|
||||
try:
|
||||
int(port)
|
||||
except ValueError:
|
||||
raise ConnStringFormatError(_('Port number must be decimal'))
|
||||
return host, port
|
||||
|
||||
def extract_nodeid(connect_contents: str) -> Tuple[bytes, str]:
|
||||
rest = None
|
||||
try:
|
||||
# connection string?
|
||||
nodeid_hex, rest = connect_contents.split("@", 1)
|
||||
except ValueError:
|
||||
try:
|
||||
# invoice?
|
||||
invoice = lndecode(connect_contents)
|
||||
nodeid_bytes = invoice.pubkey.serialize()
|
||||
nodeid_hex = bh2u(nodeid_bytes)
|
||||
except:
|
||||
# node id as hex?
|
||||
nodeid_hex = connect_contents
|
||||
if rest == '':
|
||||
raise ConnStringFormatError(_('At least a hostname must be supplied after the at symbol.'))
|
||||
try:
|
||||
node_id = bfh(nodeid_hex)
|
||||
assert len(node_id) == 33
|
||||
except:
|
||||
raise ConnStringFormatError(_('Invalid node ID, must be 33 bytes and hexadecimal'))
|
||||
return node_id, rest
|
||||
|
|
|
@ -3,9 +3,10 @@ import os
|
|||
from decimal import Decimal
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Tuple, List
|
||||
import threading
|
||||
from functools import partial
|
||||
import socket
|
||||
|
||||
import dns.resolver
|
||||
import dns.exception
|
||||
|
@ -17,8 +18,10 @@ from .lnbase import Peer, privkey_to_pubkey, aiosafe
|
|||
from .lnaddr import lnencode, LnAddr, lndecode
|
||||
from .ecc import der_sig_from_sig_string
|
||||
from .lnhtlc import HTLCStateMachine
|
||||
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr, get_compressed_pubkey_from_bech32,
|
||||
PaymentFailure)
|
||||
from .lnutil import (Outpoint, calc_short_channel_id, LNPeerAddr,
|
||||
get_compressed_pubkey_from_bech32, extract_nodeid,
|
||||
PaymentFailure, split_host_port, ConnStringFormatError)
|
||||
from electrum.lnaddr import lndecode
|
||||
from .i18n import _
|
||||
|
||||
|
||||
|
@ -30,7 +33,6 @@ FALLBACK_NODE_LIST = (
|
|||
LNPeerAddr('ecdsa.net', 9735, bfh('038370f0e7a03eded3e1d41dc081084a87f0afa1c5b22090b4f3abb391eb15d8ff')),
|
||||
)
|
||||
|
||||
|
||||
class LNWorker(PrintError):
|
||||
|
||||
def __init__(self, wallet, network):
|
||||
|
@ -89,6 +91,7 @@ class LNWorker(PrintError):
|
|||
asyncio.run_coroutine_threadsafe(self.network.main_taskgroup.spawn(peer.main_loop()), self.network.asyncio_loop)
|
||||
self.peers[node_id] = peer
|
||||
self.network.trigger_callback('ln_status')
|
||||
return peer
|
||||
|
||||
def save_channel(self, openchannel):
|
||||
assert type(openchannel) is HTLCStateMachine
|
||||
|
@ -154,8 +157,10 @@ class LNWorker(PrintError):
|
|||
conf = self.wallet.get_tx_height(chan.funding_outpoint.txid).conf
|
||||
peer.on_network_update(chan, conf)
|
||||
|
||||
async def _open_channel_coroutine(self, node_id, local_amount_sat, push_sat, password):
|
||||
peer = self.peers[node_id]
|
||||
async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
|
||||
# peer might just have been connected to
|
||||
await asyncio.wait_for(peer.initialized, 5)
|
||||
|
||||
openingchannel = await peer.channel_establishment_flow(self.wallet, self.config, password,
|
||||
funding_sat=local_amount_sat + push_sat,
|
||||
push_msat=push_sat * 1000,
|
||||
|
@ -171,8 +176,34 @@ class LNWorker(PrintError):
|
|||
def on_channels_updated(self):
|
||||
self.network.trigger_callback('channels')
|
||||
|
||||
def open_channel(self, node_id, local_amt_sat, push_amt_sat, pw):
|
||||
coro = self._open_channel_coroutine(node_id, local_amt_sat, push_amt_sat, None if pw == "" else pw)
|
||||
@staticmethod
|
||||
def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
|
||||
for host, port in addr_list:
|
||||
if is_ip_address(host):
|
||||
return host, port
|
||||
# TODO maybe filter out onion if not on tor?
|
||||
self.print_error('Chose random address from ' + str(node_info.addresses))
|
||||
return random.choice(node_info.addresses)
|
||||
|
||||
def open_channel(self, connect_contents, local_amt_sat, push_amt_sat, pw):
|
||||
node_id, rest = extract_nodeid(connect_contents)
|
||||
|
||||
peer = self.peers.get(node_id)
|
||||
if not peer:
|
||||
all_nodes = self.network.channel_db.nodes
|
||||
node_info = all_nodes.get(node_id, None)
|
||||
if rest is not None:
|
||||
host, port = split_host_port(rest)
|
||||
elif node_info and len(node_info.addresses) > 0:
|
||||
host, port = self.choose_preferred_address(node_info.addresses)
|
||||
else:
|
||||
raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
|
||||
try:
|
||||
socket.getaddrinfo(host, int(port))
|
||||
except socket.gaierror:
|
||||
raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
|
||||
peer = self.add_peer(host, port, node_id)
|
||||
coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, None if pw == "" else pw)
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
|
||||
|
||||
def pay(self, invoice, amount_sat=None):
|
||||
|
@ -262,7 +293,7 @@ class LNWorker(PrintError):
|
|||
if node is None: continue
|
||||
addresses = node.addresses
|
||||
if not addresses: continue
|
||||
host, port = random.choice(addresses)
|
||||
host, port = self.choose_preferred_address(addresses)
|
||||
peer = LNPeerAddr(host, port, node_id)
|
||||
if peer.pubkey in self.peers: continue
|
||||
if peer in self._last_tried_peer: continue
|
||||
|
|
|
@ -5,7 +5,9 @@ from electrum.lnutil import (RevocationStore, get_per_commitment_secret_from_see
|
|||
make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output,
|
||||
make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, derive_privkey,
|
||||
derive_pubkey, make_htlc_tx, extract_ctn_from_tx, UnableToDeriveSecret,
|
||||
get_compressed_pubkey_from_bech32)
|
||||
get_compressed_pubkey_from_bech32, split_host_port, ConnStringFormatError,
|
||||
ScriptHtlc, extract_nodeid)
|
||||
from electrum import lnhtlc
|
||||
from electrum.util import bh2u, bfh
|
||||
from electrum.transaction import Transaction
|
||||
|
||||
|
@ -488,13 +490,14 @@ class TestLNUtil(unittest.TestCase):
|
|||
remote_signature = "304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b70606"
|
||||
output_commit_tx = "02000000000101bef67e4e2fb9ddeeb3461973cd4c62abb35050b1add772995b820b584a488489000000000038b02b8007e80300000000000022002052bfef0479d7b293c27e0f1eb294bea154c63a3294ef092c19af51409bce0e2ad007000000000000220020403d394747cae42e98ff01734ad5c08f82ba123d3d9a620abda88989651e2ab5d007000000000000220020748eba944fedc8827f6b06bc44678f93c0f9e6078b35c6331ed31e75f8ce0c2db80b000000000000220020c20b5d1f8584fd90443e7b7b720136174fa4b9333c261d04dbbd012635c0f419a00f0000000000002200208c48d15160397c9731df9bc3b236656efb6665fbfe92b4a6878e88a499f741c4c0c62d0000000000160014ccf1af2f2aabee14bb40fa3851ab2301de843110e0a06a00000000002200204adb4e2f00643db396dd120d4e7dc17625f5f2c11a40d857accc862d6b7dd80e04004730440220275b0c325a5e9355650dc30c0eccfbc7efb23987c24b556b9dfdd40effca18d202206caceb2c067836c51f296740c7ae807ffcbfbf1dd3a0d56b6de9a5b247985f060147304402204fd4928835db1ccdfc40f5c78ce9bd65249b16348df81f0c44328dcdefc97d630220194d3869c38bc732dd87d13d2958015e2fc16829e74cd4377f84d215c0b7060601475221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae3e195220"
|
||||
|
||||
htlc_msat = {}
|
||||
htlc_msat[0] = 1000 * 1000
|
||||
htlc_msat[2] = 2000 * 1000
|
||||
htlc_msat[1] = 2000 * 1000
|
||||
htlc_msat[3] = 3000 * 1000
|
||||
htlc_msat[4] = 4000 * 1000
|
||||
htlcs = [(htlc[x], htlc_msat[x]) for x in range(5)]
|
||||
htlc_obj = {}
|
||||
for num, msat in [(0, 1000 * 1000),
|
||||
(2, 2000 * 1000),
|
||||
(1, 2000 * 1000),
|
||||
(3, 3000 * 1000),
|
||||
(4, 4000 * 1000)]:
|
||||
htlc_obj[num] = lnhtlc.UpdateAddHtlc(amount_msat=msat, payment_hash=bitcoin.sha256(htlc_payment_preimage[num]), cltv_expiry=None, htlc_id=None)
|
||||
htlcs = [ScriptHtlc(htlc[x], htlc_obj[x]) for x in range(5)]
|
||||
|
||||
our_commit_tx = make_commitment(
|
||||
commitment_number,
|
||||
|
@ -531,7 +534,7 @@ class TestLNUtil(unittest.TestCase):
|
|||
|
||||
for i in range(5):
|
||||
self.assertEqual(output_htlc_tx[i][1], self.htlc_tx(htlc[i], htlc_output_index[i],
|
||||
htlc_msat[i],
|
||||
htlcs[i].htlc.amount_msat,
|
||||
htlc_payment_preimage[i],
|
||||
signature_for_output_remote_htlc[i],
|
||||
output_htlc_tx[i][0], htlc_cltv_timeout[i] if not output_htlc_tx[i][0] else 0,
|
||||
|
@ -680,3 +683,28 @@ class TestLNUtil(unittest.TestCase):
|
|||
def test_get_compressed_pubkey_from_bech32(self):
|
||||
self.assertEqual(b'\x03\x84\xef\x87\xd9d\xa2\xaaa7=\xff\xb8\xfe=t8[}>;\n\x13\xa8e\x8eo:\xf5Mi\xb5H',
|
||||
get_compressed_pubkey_from_bech32('ln1qwzwlp7evj325cfh8hlm3l3awsu9klf78v9p82r93ehn4a2ddx65s66awg5'))
|
||||
|
||||
def test_split_host_port(self):
|
||||
self.assertEqual(split_host_port("[::1]:8000"), ("::1", "8000"))
|
||||
self.assertEqual(split_host_port("[::1]"), ("::1", "9735"))
|
||||
self.assertEqual(split_host_port("kæn.guru:8000"), ("kæn.guru", "8000"))
|
||||
self.assertEqual(split_host_port("kæn.guru"), ("kæn.guru", "9735"))
|
||||
self.assertEqual(split_host_port("127.0.0.1:8000"), ("127.0.0.1", "8000"))
|
||||
self.assertEqual(split_host_port("127.0.0.1"), ("127.0.0.1", "9735"))
|
||||
# accepted by getaddrinfo but not ipaddress.ip_address
|
||||
self.assertEqual(split_host_port("127.0.0:8000"), ("127.0.0", "8000"))
|
||||
self.assertEqual(split_host_port("127.0.0"), ("127.0.0", "9735"))
|
||||
self.assertEqual(split_host_port("electrum.org:8000"), ("electrum.org", "8000"))
|
||||
self.assertEqual(split_host_port("electrum.org"), ("electrum.org", "9735"))
|
||||
|
||||
with self.assertRaises(ConnStringFormatError):
|
||||
split_host_port("electrum.org:8000:")
|
||||
with self.assertRaises(ConnStringFormatError):
|
||||
split_host_port("electrum.org:")
|
||||
|
||||
def test_extract_nodeid(self):
|
||||
with self.assertRaises(ConnStringFormatError):
|
||||
extract_nodeid("00" * 32 + "@localhost")
|
||||
with self.assertRaises(ConnStringFormatError):
|
||||
extract_nodeid("00" * 33 + "@")
|
||||
self.assertEqual(extract_nodeid("00" * 33 + "@localhost"), (b"\x00" * 33, "localhost"))
|
||||
|
|
Loading…
Add table
Reference in a new issue