mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
some more type annotations that needed conditional imports
This commit is contained in:
parent
f3d1f71e94
commit
f70e679aba
8 changed files with 65 additions and 33 deletions
|
@ -10,7 +10,7 @@ import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict, TYPE_CHECKING
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig,
|
||||||
funding_output_script, get_per_commitment_secret_from_seed,
|
funding_output_script, get_per_commitment_secret_from_seed,
|
||||||
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
|
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
|
||||||
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
|
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
|
||||||
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED)
|
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED,
|
||||||
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
|
LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr)
|
||||||
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
|
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
|
||||||
from .lntransport import LNTransport
|
from .lntransport import LNTransport, LNTransportBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .lnworker import LNWorker
|
||||||
|
|
||||||
|
|
||||||
def channel_id_from_funding_tx(funding_txid, funding_index):
|
def channel_id_from_funding_tx(funding_txid, funding_index):
|
||||||
|
@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
|
||||||
|
|
||||||
class Peer(PrintError):
|
class Peer(PrintError):
|
||||||
|
|
||||||
def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None):
|
def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr,
|
||||||
|
request_initial_sync=False, transport: LNTransportBase=None):
|
||||||
self.initialized = asyncio.Future()
|
self.initialized = asyncio.Future()
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.peer_addr = peer_addr
|
self.peer_addr = peer_addr
|
||||||
|
@ -357,7 +361,7 @@ class Peer(PrintError):
|
||||||
def close_and_cleanup(self):
|
def close_and_cleanup(self):
|
||||||
try:
|
try:
|
||||||
if self.transport:
|
if self.transport:
|
||||||
self.transport.writer.close()
|
self.transport.close()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
for chan in self.channels.values():
|
for chan in self.channels.values():
|
||||||
|
|
|
@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict
|
||||||
import binascii
|
import binascii
|
||||||
import json
|
import json
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Optional
|
from typing import Optional, Mapping, List
|
||||||
|
|
||||||
from .util import bfh, PrintError, bh2u
|
from .util import bfh, PrintError, bh2u
|
||||||
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS
|
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import aiorpcx
|
import aiorpcx
|
||||||
|
|
||||||
|
@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
|
||||||
from .transaction import Transaction
|
from .transaction import Transaction
|
||||||
from .interface import GracefulDisconnect
|
from .interface import GracefulDisconnect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .network import Network
|
||||||
|
from .lnrouter import ChannelDB
|
||||||
|
|
||||||
|
|
||||||
class LNChannelVerifier(NetworkJobOnDefaultServer):
|
class LNChannelVerifier(NetworkJobOnDefaultServer):
|
||||||
""" Verify channel announcements for the Channel DB """
|
""" Verify channel announcements for the Channel DB """
|
||||||
|
@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
||||||
# will start throttling us, making it even slower. one option would be to
|
# will start throttling us, making it even slower. one option would be to
|
||||||
# spread it over multiple servers.
|
# spread it over multiple servers.
|
||||||
|
|
||||||
def __init__(self, network, channel_db):
|
def __init__(self, network: 'Network', channel_db: 'ChannelDB'):
|
||||||
NetworkJobOnDefaultServer.__init__(self, network)
|
NetworkJobOnDefaultServer.__init__(self, network)
|
||||||
self.channel_db = channel_db
|
self.channel_db = channel_db
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
||||||
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
|
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
|
||||||
#self.print_error('requested short_channel_id', bh2u(short_channel_id))
|
#self.print_error('requested short_channel_id', bh2u(short_channel_id))
|
||||||
|
|
||||||
async def verify_channel(self, block_height, tx_pos, short_channel_id):
|
async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
|
||||||
# we are verifying channel announcements as they are from untrusted ln peers.
|
# we are verifying channel announcements as they are from untrusted ln peers.
|
||||||
# we use electrum servers to do this. however we don't trust electrum servers either...
|
# we use electrum servers to do this. however we don't trust electrum servers either...
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Sequence, List, Tuple, NamedTuple
|
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING
|
||||||
from enum import IntEnum, IntFlag
|
from enum import IntEnum, IntFlag
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||||
|
@ -34,7 +34,9 @@ from . import ecc
|
||||||
from .crypto import sha256, hmac_oneshot
|
from .crypto import sha256, hmac_oneshot
|
||||||
from .util import bh2u, profiler, xor_bytes, bfh
|
from .util import bh2u, profiler, xor_bytes, bfh
|
||||||
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
||||||
from .lnrouter import RouteEdge
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .lnrouter import RouteEdge
|
||||||
|
|
||||||
|
|
||||||
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
|
HOPS_DATA_SIZE = 1300 # also sometimes called routingInfoSize in bolt-04
|
||||||
|
@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
|
||||||
hmac=next_hmac)
|
hmac=next_hmac)
|
||||||
|
|
||||||
|
|
||||||
def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \
|
def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \
|
||||||
-> Tuple[List[OnionHopsDataSingle], int, int]:
|
-> Tuple[List[OnionHopsDataSingle], int, int]:
|
||||||
"""Returns the hops_data to be used for constructing an onion packet,
|
"""Returns the hops_data to be used for constructing an onion packet,
|
||||||
and the amount_msat and cltv to be used on our immediate channel.
|
and the amount_msat and cltv to be used on our immediate channel.
|
||||||
|
|
|
@ -27,8 +27,8 @@ import queue
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple, defaultdict
|
from collections import defaultdict
|
||||||
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING
|
||||||
import binascii
|
import binascii
|
||||||
import base64
|
import base64
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -41,6 +41,10 @@ from .crypto import Hash
|
||||||
from . import ecc
|
from . import ecc
|
||||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .lnchan import Channel
|
||||||
|
from .network import Network
|
||||||
|
|
||||||
|
|
||||||
class UnknownEvenFeatureBits(Exception): pass
|
class UnknownEvenFeatureBits(Exception): pass
|
||||||
|
|
||||||
|
@ -272,7 +276,7 @@ class ChannelDB(JsonDB):
|
||||||
|
|
||||||
NUM_MAX_RECENT_PEERS = 20
|
NUM_MAX_RECENT_PEERS = 20
|
||||||
|
|
||||||
def __init__(self, network):
|
def __init__(self, network: 'Network'):
|
||||||
self.network = network
|
self.network = network
|
||||||
|
|
||||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||||
|
@ -597,7 +601,7 @@ class LNPathFinder(PrintError):
|
||||||
@profiler
|
@profiler
|
||||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||||
invoice_amount_msat: int,
|
invoice_amount_msat: int,
|
||||||
my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
|
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
|
||||||
"""Return a path from nodeA to nodeB.
|
"""Return a path from nodeA to nodeB.
|
||||||
|
|
||||||
Returns a list of (node_id, short_channel_id) representing a path.
|
Returns a list of (node_id, short_channel_id) representing a path.
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import hmac
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from asyncio import StreamReader, StreamWriter
|
||||||
|
|
||||||
import cryptography.hazmat.primitives.ciphers.aead as AEAD
|
import cryptography.hazmat.primitives.ciphers.aead as AEAD
|
||||||
|
|
||||||
from .crypto import sha256
|
from .crypto import sha256, hmac_oneshot
|
||||||
from .lnutil import get_ecdh, privkey_to_pubkey
|
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
|
||||||
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
|
HandshakeFailed)
|
||||||
from . import ecc
|
from . import ecc
|
||||||
from .util import bh2u
|
from .util import bh2u
|
||||||
|
|
||||||
|
@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm):
|
||||||
Return as two 32 byte fields.
|
Return as two 32 byte fields.
|
||||||
"""
|
"""
|
||||||
#Extract
|
#Extract
|
||||||
prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest()
|
prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
|
||||||
assert len(prk) == 32
|
assert len(prk) == 32
|
||||||
#Expand
|
#Expand
|
||||||
info = b""
|
info = b""
|
||||||
T0 = b""
|
T0 = b""
|
||||||
T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest()
|
T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
|
||||||
T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest()
|
T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
|
||||||
assert len(T1 + T2) == 64
|
assert len(T1 + T2) == 64
|
||||||
return T1, T2
|
return T1, T2
|
||||||
|
|
||||||
|
@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes):
|
||||||
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
|
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
|
||||||
|
|
||||||
class LNTransportBase:
|
class LNTransportBase:
|
||||||
|
|
||||||
|
def __init__(self, reader: StreamReader, writer: StreamWriter):
|
||||||
|
self.reader = reader
|
||||||
|
self.writer = writer
|
||||||
|
|
||||||
def send_bytes(self, msg):
|
def send_bytes(self, msg):
|
||||||
l = len(msg).to_bytes(2, 'big')
|
l = len(msg).to_bytes(2, 'big')
|
||||||
lc = aead_encrypt(self.sk, self.sn(), b'', l)
|
lc = aead_encrypt(self.sk, self.sn(), b'', l)
|
||||||
|
@ -132,11 +138,14 @@ class LNTransportBase:
|
||||||
self.r_ck = ck
|
self.r_ck = ck
|
||||||
self.s_ck = ck
|
self.s_ck = ck
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.writer.close()
|
||||||
|
|
||||||
|
|
||||||
class LNResponderTransport(LNTransportBase):
|
class LNResponderTransport(LNTransportBase):
|
||||||
def __init__(self, privkey, reader, writer):
|
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
|
||||||
|
LNTransportBase.__init__(self, reader, writer)
|
||||||
self.privkey = privkey
|
self.privkey = privkey
|
||||||
self.reader = reader
|
|
||||||
self.writer = writer
|
|
||||||
|
|
||||||
async def handshake(self, **kwargs):
|
async def handshake(self, **kwargs):
|
||||||
hs = HandshakeState(privkey_to_pubkey(self.privkey))
|
hs = HandshakeState(privkey_to_pubkey(self.privkey))
|
||||||
|
@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase):
|
||||||
return rs
|
return rs
|
||||||
|
|
||||||
class LNTransport(LNTransportBase):
|
class LNTransport(LNTransportBase):
|
||||||
def __init__(self, privkey, remote_pubkey, reader, writer):
|
def __init__(self, privkey: bytes, remote_pubkey: bytes,
|
||||||
|
reader: StreamReader, writer: StreamWriter):
|
||||||
|
LNTransportBase.__init__(self, reader, writer)
|
||||||
assert type(privkey) is bytes and len(privkey) == 32
|
assert type(privkey) is bytes and len(privkey) == 32
|
||||||
self.privkey = privkey
|
self.privkey = privkey
|
||||||
self.remote_pubkey = remote_pubkey
|
self.remote_pubkey = remote_pubkey
|
||||||
self.reader = reader
|
|
||||||
self.writer = writer
|
|
||||||
|
|
||||||
async def handshake(self):
|
async def handshake(self):
|
||||||
hs = HandshakeState(self.remote_pubkey)
|
hs = HandshakeState(self.remote_pubkey)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import threading
|
import threading
|
||||||
from typing import NamedTuple, Iterable
|
from typing import NamedTuple, Iterable, TYPE_CHECKING
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -11,6 +11,9 @@ from . import wallet
|
||||||
from .storage import WalletStorage
|
from .storage import WalletStorage
|
||||||
from .address_synchronizer import AddressSynchronizer
|
from .address_synchronizer import AddressSynchronizer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .network import Network
|
||||||
|
|
||||||
|
|
||||||
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
|
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
|
||||||
|
|
||||||
|
@ -21,7 +24,7 @@ class LNWatcher(PrintError):
|
||||||
# maybe we should disconnect from server in these cases
|
# maybe we should disconnect from server in these cases
|
||||||
verbosity_filter = 'W'
|
verbosity_filter = 'W'
|
||||||
|
|
||||||
def __init__(self, network):
|
def __init__(self, network: 'Network'):
|
||||||
self.network = network
|
self.network = network
|
||||||
self.config = network.config
|
self.config = network.config
|
||||||
path = os.path.join(network.config.path, "watcher_db")
|
path = os.path.join(network.config.path, "watcher_db")
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Sequence, Tuple, List, Dict
|
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
|
||||||
import threading
|
import threading
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
@ -31,6 +31,11 @@ from .lnaddr import lndecode
|
||||||
from .i18n import _
|
from .i18n import _
|
||||||
from .lnrouter import RouteEdge, is_route_sane_to_use
|
from .lnrouter import RouteEdge, is_route_sane_to_use
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .network import Network
|
||||||
|
from .wallet import Abstract_Wallet
|
||||||
|
|
||||||
|
|
||||||
NUM_PEERS_TARGET = 4
|
NUM_PEERS_TARGET = 4
|
||||||
PEER_RETRY_INTERVAL = 600 # seconds
|
PEER_RETRY_INTERVAL = 600 # seconds
|
||||||
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
|
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
|
||||||
|
@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = (
|
||||||
|
|
||||||
class LNWorker(PrintError):
|
class LNWorker(PrintError):
|
||||||
|
|
||||||
def __init__(self, wallet, network):
|
def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
|
||||||
self.wallet = wallet
|
self.wallet = wallet
|
||||||
self.sweep_address = wallet.get_receiving_address()
|
self.sweep_address = wallet.get_receiving_address()
|
||||||
self.network = network
|
self.network = network
|
||||||
|
|
Loading…
Add table
Reference in a new issue