mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
some more type annotations that needed conditional imports
This commit is contained in:
parent
f3d1f71e94
commit
f70e679aba
8 changed files with 65 additions and 33 deletions
|
@ -10,7 +10,7 @@ import asyncio
|
|||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import List, Tuple, Dict
|
||||
from typing import List, Tuple, Dict, TYPE_CHECKING
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
|
@ -31,10 +31,13 @@ from .lnutil import (Outpoint, LocalConfig, ChannelConfig,
|
|||
funding_output_script, get_per_commitment_secret_from_seed,
|
||||
secret_to_pubkey, LNPeerAddr, PaymentFailure, LnLocalFeatures,
|
||||
LOCAL, REMOTE, HTLCOwner, generate_keypair, LnKeyFamily,
|
||||
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED)
|
||||
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
|
||||
get_ln_flag_pair_of_bit, privkey_to_pubkey, UnknownPaymentHash, MIN_FINAL_CLTV_EXPIRY_ACCEPTED,
|
||||
LightningPeerConnectionClosed, HandshakeFailed, LNPeerAddr)
|
||||
from .lnrouter import NotFoundChanAnnouncementForUpdate, RouteEdge
|
||||
from .lntransport import LNTransport
|
||||
from .lntransport import LNTransport, LNTransportBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lnworker import LNWorker
|
||||
|
||||
|
||||
def channel_id_from_funding_tx(funding_txid, funding_index):
|
||||
|
@ -191,7 +194,8 @@ def gen_msg(msg_type: str, **kwargs) -> bytes:
|
|||
|
||||
class Peer(PrintError):
|
||||
|
||||
def __init__(self, lnworker, peer_addr, request_initial_sync=False, transport=None):
|
||||
def __init__(self, lnworker: 'LNWorker', peer_addr: LNPeerAddr,
|
||||
request_initial_sync=False, transport: LNTransportBase=None):
|
||||
self.initialized = asyncio.Future()
|
||||
self.transport = transport
|
||||
self.peer_addr = peer_addr
|
||||
|
@ -357,7 +361,7 @@ class Peer(PrintError):
|
|||
def close_and_cleanup(self):
|
||||
try:
|
||||
if self.transport:
|
||||
self.transport.writer.close()
|
||||
self.transport.close()
|
||||
except:
|
||||
pass
|
||||
for chan in self.channels.values():
|
||||
|
|
|
@ -3,7 +3,7 @@ from collections import namedtuple, defaultdict
|
|||
import binascii
|
||||
import json
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
from typing import Optional, Mapping, List
|
||||
|
||||
from .util import bfh, PrintError, bh2u
|
||||
from .bitcoin import Hash, TYPE_SCRIPT, TYPE_ADDRESS
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiorpcx
|
||||
|
||||
|
@ -38,6 +39,10 @@ from .verifier import verify_tx_is_in_block, MerkleVerificationFailure
|
|||
from .transaction import Transaction
|
||||
from .interface import GracefulDisconnect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
from .lnrouter import ChannelDB
|
||||
|
||||
|
||||
class LNChannelVerifier(NetworkJobOnDefaultServer):
|
||||
""" Verify channel announcements for the Channel DB """
|
||||
|
@ -46,7 +51,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||
# will start throttling us, making it even slower. one option would be to
|
||||
# spread it over multiple servers.
|
||||
|
||||
def __init__(self, network, channel_db):
|
||||
def __init__(self, network: 'Network', channel_db: 'ChannelDB'):
|
||||
NetworkJobOnDefaultServer.__init__(self, network)
|
||||
self.channel_db = channel_db
|
||||
self.lock = threading.Lock()
|
||||
|
@ -105,7 +110,7 @@ class LNChannelVerifier(NetworkJobOnDefaultServer):
|
|||
await self.group.spawn(self.verify_channel(block_height, tx_pos, short_channel_id))
|
||||
#self.print_error('requested short_channel_id', bh2u(short_channel_id))
|
||||
|
||||
async def verify_channel(self, block_height, tx_pos, short_channel_id):
|
||||
async def verify_channel(self, block_height: int, tx_pos: int, short_channel_id: bytes):
|
||||
# we are verifying channel announcements as they are from untrusted ln peers.
|
||||
# we use electrum servers to do this. however we don't trust electrum servers either...
|
||||
try:
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
# SOFTWARE.
|
||||
|
||||
import hashlib
|
||||
from typing import Sequence, List, Tuple, NamedTuple
|
||||
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING
|
||||
from enum import IntEnum, IntFlag
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||||
|
@ -34,6 +34,8 @@ from . import ecc
|
|||
from .crypto import sha256, hmac_oneshot
|
||||
from .util import bh2u, profiler, xor_bytes, bfh
|
||||
from .lnutil import get_ecdh, PaymentFailure, NUM_MAX_HOPS_IN_PAYMENT_PATH, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lnrouter import RouteEdge
|
||||
|
||||
|
||||
|
@ -186,7 +188,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
|
|||
hmac=next_hmac)
|
||||
|
||||
|
||||
def calc_hops_data_for_payment(route: List[RouteEdge], amount_msat: int, final_cltv: int) \
|
||||
def calc_hops_data_for_payment(route: List['RouteEdge'], amount_msat: int, final_cltv: int) \
|
||||
-> Tuple[List[OnionHopsDataSingle], int, int]:
|
||||
"""Returns the hops_data to be used for constructing an onion packet,
|
||||
and the amount_msat and cltv to be used on our immediate channel.
|
||||
|
|
|
@ -27,8 +27,8 @@ import queue
|
|||
import os
|
||||
import json
|
||||
import threading
|
||||
from collections import namedtuple, defaultdict
|
||||
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple
|
||||
from collections import defaultdict
|
||||
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING
|
||||
import binascii
|
||||
import base64
|
||||
import asyncio
|
||||
|
@ -41,6 +41,10 @@ from .crypto import Hash
|
|||
from . import ecc
|
||||
from .lnutil import LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lnchan import Channel
|
||||
from .network import Network
|
||||
|
||||
|
||||
class UnknownEvenFeatureBits(Exception): pass
|
||||
|
||||
|
@ -272,7 +276,7 @@ class ChannelDB(JsonDB):
|
|||
|
||||
NUM_MAX_RECENT_PEERS = 20
|
||||
|
||||
def __init__(self, network):
|
||||
def __init__(self, network: 'Network'):
|
||||
self.network = network
|
||||
|
||||
path = os.path.join(get_headers_dir(network.config), 'channel_db')
|
||||
|
@ -597,7 +601,7 @@ class LNPathFinder(PrintError):
|
|||
@profiler
|
||||
def find_path_for_payment(self, nodeA: bytes, nodeB: bytes,
|
||||
invoice_amount_msat: int,
|
||||
my_channels: List=None) -> Sequence[Tuple[bytes, bytes]]:
|
||||
my_channels: List['Channel']=None) -> Sequence[Tuple[bytes, bytes]]:
|
||||
"""Return a path from nodeA to nodeB.
|
||||
|
||||
Returns a list of (node_id, short_channel_id) representing a path.
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import hmac
|
||||
import hashlib
|
||||
from asyncio import StreamReader, StreamWriter
|
||||
|
||||
import cryptography.hazmat.primitives.ciphers.aead as AEAD
|
||||
|
||||
from .crypto import sha256
|
||||
from .lnutil import get_ecdh, privkey_to_pubkey
|
||||
from .lnutil import LightningPeerConnectionClosed, HandshakeFailed
|
||||
from .crypto import sha256, hmac_oneshot
|
||||
from .lnutil import (get_ecdh, privkey_to_pubkey, LightningPeerConnectionClosed,
|
||||
HandshakeFailed)
|
||||
from . import ecc
|
||||
from .util import bh2u
|
||||
|
||||
|
@ -49,13 +50,13 @@ def get_bolt8_hkdf(salt, ikm):
|
|||
Return as two 32 byte fields.
|
||||
"""
|
||||
#Extract
|
||||
prk = hmac.new(salt, msg=ikm, digestmod=hashlib.sha256).digest()
|
||||
prk = hmac_oneshot(salt, msg=ikm, digest=hashlib.sha256)
|
||||
assert len(prk) == 32
|
||||
#Expand
|
||||
info = b""
|
||||
T0 = b""
|
||||
T1 = hmac.new(prk, T0 + info + b"\x01", digestmod=hashlib.sha256).digest()
|
||||
T2 = hmac.new(prk, T1 + info + b"\x02", digestmod=hashlib.sha256).digest()
|
||||
T1 = hmac_oneshot(prk, T0 + info + b"\x01", digest=hashlib.sha256)
|
||||
T2 = hmac_oneshot(prk, T1 + info + b"\x02", digest=hashlib.sha256)
|
||||
assert len(T1 + T2) == 64
|
||||
return T1, T2
|
||||
|
||||
|
@ -76,6 +77,11 @@ def create_ephemeral_key() -> (bytes, bytes):
|
|||
return privkey.get_secret_bytes(), privkey.get_public_key_bytes()
|
||||
|
||||
class LNTransportBase:
|
||||
|
||||
def __init__(self, reader: StreamReader, writer: StreamWriter):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
def send_bytes(self, msg):
|
||||
l = len(msg).to_bytes(2, 'big')
|
||||
lc = aead_encrypt(self.sk, self.sn(), b'', l)
|
||||
|
@ -132,11 +138,14 @@ class LNTransportBase:
|
|||
self.r_ck = ck
|
||||
self.s_ck = ck
|
||||
|
||||
def close(self):
|
||||
self.writer.close()
|
||||
|
||||
|
||||
class LNResponderTransport(LNTransportBase):
|
||||
def __init__(self, privkey, reader, writer):
|
||||
def __init__(self, privkey: bytes, reader: StreamReader, writer: StreamWriter):
|
||||
LNTransportBase.__init__(self, reader, writer)
|
||||
self.privkey = privkey
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
async def handshake(self, **kwargs):
|
||||
hs = HandshakeState(privkey_to_pubkey(self.privkey))
|
||||
|
@ -187,12 +196,12 @@ class LNResponderTransport(LNTransportBase):
|
|||
return rs
|
||||
|
||||
class LNTransport(LNTransportBase):
|
||||
def __init__(self, privkey, remote_pubkey, reader, writer):
|
||||
def __init__(self, privkey: bytes, remote_pubkey: bytes,
|
||||
reader: StreamReader, writer: StreamWriter):
|
||||
LNTransportBase.__init__(self, reader, writer)
|
||||
assert type(privkey) is bytes and len(privkey) == 32
|
||||
self.privkey = privkey
|
||||
self.remote_pubkey = remote_pubkey
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
|
||||
async def handshake(self):
|
||||
hs = HandshakeState(self.remote_pubkey)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import threading
|
||||
from typing import NamedTuple, Iterable
|
||||
from typing import NamedTuple, Iterable, TYPE_CHECKING
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
|
@ -11,6 +11,9 @@ from . import wallet
|
|||
from .storage import WalletStorage
|
||||
from .address_synchronizer import AddressSynchronizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
|
||||
|
||||
TX_MINED_STATUS_DEEP, TX_MINED_STATUS_SHALLOW, TX_MINED_STATUS_MEMPOOL, TX_MINED_STATUS_FREE = range(0, 4)
|
||||
|
||||
|
@ -21,7 +24,7 @@ class LNWatcher(PrintError):
|
|||
# maybe we should disconnect from server in these cases
|
||||
verbosity_filter = 'W'
|
||||
|
||||
def __init__(self, network):
|
||||
def __init__(self, network: 'Network'):
|
||||
self.network = network
|
||||
self.config = network.config
|
||||
path = os.path.join(network.config.path, "watcher_db")
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
from decimal import Decimal
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Sequence, Tuple, List, Dict
|
||||
from typing import Optional, Sequence, Tuple, List, Dict, TYPE_CHECKING
|
||||
import threading
|
||||
import socket
|
||||
|
||||
|
@ -31,6 +31,11 @@ from .lnaddr import lndecode
|
|||
from .i18n import _
|
||||
from .lnrouter import RouteEdge, is_route_sane_to_use
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import Network
|
||||
from .wallet import Abstract_Wallet
|
||||
|
||||
|
||||
NUM_PEERS_TARGET = 4
|
||||
PEER_RETRY_INTERVAL = 600 # seconds
|
||||
PEER_RETRY_INTERVAL_FOR_CHANNELS = 30 # seconds
|
||||
|
@ -45,7 +50,7 @@ FALLBACK_NODE_LIST_MAINNET = (
|
|||
|
||||
class LNWorker(PrintError):
|
||||
|
||||
def __init__(self, wallet, network):
|
||||
def __init__(self, wallet: 'Abstract_Wallet', network: 'Network'):
|
||||
self.wallet = wallet
|
||||
self.sweep_address = wallet.get_receiving_address()
|
||||
self.network = network
|
||||
|
|
Loading…
Add table
Reference in a new issue