lnchannel: make AbstractChannel inherit ABC

and add some type annotations, clean up method signatures
This commit is contained in:
SomberNight 2020-04-13 15:57:53 +02:00
parent 821431a239
commit 8e8ab775eb
No known key found for this signature in database
GPG key ID: B33B5F232C6271E9
6 changed files with 134 additions and 71 deletions

View file

@ -27,13 +27,14 @@ from typing import (Optional, Dict, List, Tuple, NamedTuple, Set, Callable,
Iterable, Sequence, TYPE_CHECKING, Iterator, Union)
import time
import threading
from abc import ABC, abstractmethod
from aiorpcx import NetAddress
import attr
from . import ecc
from . import constants
from .util import bfh, bh2u, chunks
from .util import bfh, bh2u, chunks, TxMinedInfo
from .bitcoin import redeem_script_to_address
from .crypto import sha256, sha256d
from .transaction import Transaction, PartialTransaction
@ -113,7 +114,9 @@ state_transitions = [
del cs # delete as name is ambiguous without context
RevokeAndAck = namedtuple("RevokeAndAck", ["per_commitment_secret", "next_per_commitment_point"])
class RevokeAndAck(NamedTuple):
per_commitment_secret: bytes
next_per_commitment_point: bytes
class RemoteCtnTooFarInFuture(Exception): pass
@ -123,7 +126,16 @@ def htlcsum(htlcs):
return sum([x.amount_msat for x in htlcs])
class AbstractChannel(Logger):
class AbstractChannel(Logger, ABC):
storage: Union['StoredDict', dict]
config: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]]
_sweep_info: Dict[str, Dict[str, 'SweepInfo']]
lnworker: Optional['LNWallet']
sweep_address: str
channel_id: bytes
funding_outpoint: Outpoint
node_id: bytes
_state: channel_states
def set_short_channel_id(self, short_id: ShortChannelID) -> None:
self.short_channel_id = short_id
@ -168,7 +180,7 @@ class AbstractChannel(Logger):
def is_redeemed(self):
return self.get_state() == channel_states.REDEEMED
def save_funding_height(self, txid, height, timestamp):
def save_funding_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None:
self.storage['funding_height'] = txid, height, timestamp
def get_funding_height(self):
@ -177,7 +189,7 @@ class AbstractChannel(Logger):
def delete_funding_height(self):
self.storage.pop('funding_height', None)
def save_closing_height(self, txid, height, timestamp):
def save_closing_height(self, *, txid: str, height: int, timestamp: Optional[int]) -> None:
self.storage['closing_height'] = txid, height, timestamp
def get_closing_height(self):
@ -197,30 +209,34 @@ class AbstractChannel(Logger):
def sweep_ctx(self, ctx: Transaction) -> Dict[str, SweepInfo]:
txid = ctx.txid()
if self.sweep_info.get(txid) is None:
if self._sweep_info.get(txid) is None:
our_sweep_info = self.create_sweeptxs_for_our_ctx(ctx)
their_sweep_info = self.create_sweeptxs_for_their_ctx(ctx)
if our_sweep_info is not None:
self.sweep_info[txid] = our_sweep_info
self._sweep_info[txid] = our_sweep_info
self.logger.info(f'we force closed')
elif their_sweep_info is not None:
self.sweep_info[txid] = their_sweep_info
self._sweep_info[txid] = their_sweep_info
self.logger.info(f'they force closed.')
else:
self.sweep_info[txid] = {}
self._sweep_info[txid] = {}
self.logger.info(f'not sure who closed.')
return self.sweep_info[txid]
return self._sweep_info[txid]
# ancestor for Channel and ChannelBackup
def update_onchain_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
def update_onchain_state(self, *, funding_txid: str, funding_height: TxMinedInfo,
closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None:
# note: state transitions are irreversible, but
# save_funding_height, save_closing_height are reversible
if funding_height.height == TX_HEIGHT_LOCAL:
self.update_unfunded_state()
elif closing_height.height == TX_HEIGHT_LOCAL:
self.update_funded_state(funding_txid, funding_height)
self.update_funded_state(funding_txid=funding_txid, funding_height=funding_height)
else:
self.update_closed_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching)
self.update_closed_state(funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
def update_unfunded_state(self):
self.delete_funding_height()
@ -249,8 +265,8 @@ class AbstractChannel(Logger):
if now - self.storage.get('init_timestamp', 0) > CHANNEL_OPENING_TIMEOUT:
self.lnworker.remove_channel(self.channel_id)
def update_funded_state(self, funding_txid, funding_height):
self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp)
def update_funded_state(self, *, funding_txid: str, funding_height: TxMinedInfo) -> None:
self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp)
self.delete_closing_height()
if funding_height.conf>0:
self.set_short_channel_id(ShortChannelID.from_components(
@ -259,9 +275,10 @@ class AbstractChannel(Logger):
if self.is_funding_tx_mined(funding_height):
self.set_state(channel_states.FUNDED)
def update_closed_state(self, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
self.save_funding_height(funding_txid, funding_height.height, funding_height.timestamp)
self.save_closing_height(closing_txid, closing_height.height, closing_height.timestamp)
def update_closed_state(self, *, funding_txid: str, funding_height: TxMinedInfo,
closing_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None:
self.save_funding_height(txid=funding_txid, height=funding_height.height, timestamp=funding_height.timestamp)
self.save_closing_height(txid=closing_txid, height=closing_height.height, timestamp=closing_height.timestamp)
if self.get_state() < channel_states.CLOSED:
conf = closing_height.conf
if conf > 0:
@ -273,6 +290,66 @@ class AbstractChannel(Logger):
if self.get_state() == channel_states.CLOSED and not keep_watching:
self.set_state(channel_states.REDEEMED)
@abstractmethod
def is_initiator(self) -> bool:
pass
@abstractmethod
def is_funding_tx_mined(self, funding_height: TxMinedInfo) -> bool:
pass
@abstractmethod
def get_funding_address(self) -> str:
pass
@abstractmethod
def get_state_for_GUI(self) -> str:
pass
@abstractmethod
def get_oldest_unrevoked_ctn(self, subject: HTLCOwner) -> int:
pass
@abstractmethod
def included_htlcs(self, subject: HTLCOwner, direction: Direction, ctn: int = None) -> Sequence[UpdateAddHtlc]:
pass
@abstractmethod
def funding_txn_minimum_depth(self) -> int:
pass
@abstractmethod
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
"""This balance (in msat) only considers HTLCs that have been settled by ctn.
It disregards reserve, fees, and pending HTLCs (in both directions).
"""
pass
@abstractmethod
def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *,
ctx_owner: HTLCOwner = HTLCOwner.LOCAL,
ctn: int = None) -> int:
"""This balance (in msat), which includes the value of
pending outgoing HTLCs, is used in the UI.
"""
pass
@abstractmethod
def is_frozen_for_sending(self) -> bool:
"""Whether the user has marked this channel as frozen for sending.
Frozen channels are not supposed to be used for new outgoing payments.
(note that payment-forwarding ignores this option)
"""
pass
@abstractmethod
def is_frozen_for_receiving(self) -> bool:
"""Whether the user has marked this channel as frozen for receiving.
Frozen channels are not supposed to be used for new incoming payments.
(note that payment-forwarding ignores this option)
"""
pass
class ChannelBackup(AbstractChannel):
"""
@ -288,7 +365,7 @@ class ChannelBackup(AbstractChannel):
self.name = None
Logger.__init__(self)
self.cb = cb
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._sweep_info = {}
self.sweep_address = sweep_address
self.storage = {} # dummy storage
self._state = channel_states.OPENING
@ -351,7 +428,7 @@ class ChannelBackup(AbstractChannel):
def get_oldest_unrevoked_ctn(self, who):
return -1
def included_htlcs(self, subject, direction, ctn):
def included_htlcs(self, subject, direction, ctn=None):
return []
def funding_txn_minimum_depth(self):
@ -381,16 +458,16 @@ class Channel(AbstractChannel):
def __init__(self, state: 'StoredDict', *, sweep_address=None, name=None, lnworker=None, initial_feerate=None):
self.name = name
Logger.__init__(self)
self.lnworker = lnworker # type: Optional[LNWallet]
self.lnworker = lnworker
self.sweep_address = sweep_address
self.storage = state
self.db_lock = self.storage.db.lock if self.storage.db else threading.RLock()
self.config = {} # type: Dict[HTLCOwner, Union[LocalConfig, RemoteConfig]]
self.config = {}
self.config[LOCAL] = state["local_config"]
self.config[REMOTE] = state["remote_config"]
self.channel_id = bfh(state["channel_id"])
self.constraints = state["constraints"] # type: ChannelConstraints
self.funding_outpoint = state["funding_outpoint"] # type: Outpoint
self.funding_outpoint = state["funding_outpoint"]
self.node_id = bfh(state["node_id"])
self.short_channel_id = ShortChannelID.normalize(state["short_channel_id"])
self.onion_keys = state['onion_keys'] # type: Dict[int, bytes]
@ -398,7 +475,7 @@ class Channel(AbstractChannel):
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
self._state = channel_states[state['state']]
self.peer_state = peer_states.DISCONNECTED
self.sweep_info = {} # type: Dict[str, Dict[str, SweepInfo]]
self._sweep_info = {}
self._outgoing_channel_update = None # type: Optional[bytes]
self._chan_ann_without_sigs = None # type: Optional[bytes]
self.revocation_store = RevocationStore(state["revocation_store"])
@ -596,10 +673,6 @@ class Channel(AbstractChannel):
return self.can_send_ctx_updates() and not self.is_closing()
def is_frozen_for_sending(self) -> bool:
"""Whether the user has marked this channel as frozen for sending.
Frozen channels are not supposed to be used for new outgoing payments.
(note that payment-forwarding ignores this option)
"""
return self.storage.get('frozen_for_sending', False)
def set_frozen_for_sending(self, b: bool) -> None:
@ -608,10 +681,6 @@ class Channel(AbstractChannel):
self.lnworker.network.trigger_callback('channel', self)
def is_frozen_for_receiving(self) -> bool:
"""Whether the user has marked this channel as frozen for receiving.
Frozen channels are not supposed to be used for new incoming payments.
(note that payment-forwarding ignores this option)
"""
return self.storage.get('frozen_for_receiving', False)
def set_frozen_for_receiving(self, b: bool) -> None:
@ -880,9 +949,6 @@ class Channel(AbstractChannel):
self.lnworker.payment_failed(self, htlc.payment_hash, payment_attempt)
def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int:
"""This balance (in msat) only considers HTLCs that have been settled by ctn.
It disregards reserve, fees, and pending HTLCs (in both directions).
"""
assert type(whose) is HTLCOwner
initial = self.config[whose].initial_msat
return self.hm.get_balance_msat(whose=whose,
@ -891,10 +957,7 @@ class Channel(AbstractChannel):
initial_balance_msat=initial)
def balance_minus_outgoing_htlcs(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = HTLCOwner.LOCAL,
ctn: int = None):
"""This balance (in msat), which includes the value of
pending outgoing HTLCs, is used in the UI.
"""
ctn: int = None) -> int:
assert type(whose) is HTLCOwner
if ctn is None:
ctn = self.get_next_ctn(ctx_owner)
@ -1282,11 +1345,6 @@ class Channel(AbstractChannel):
return total_value_sat > min_value_worth_closing_channel_over_sat
def is_funding_tx_mined(self, funding_height):
"""
Checks if Funding TX has been mined. If it has, save the short channel ID in chan;
if it's also deep enough, also save to disk.
Returns tuple (mined_deep_enough, num_confirmations).
"""
funding_txid = self.funding_outpoint.txid
funding_idx = self.funding_outpoint.output_index
conf = funding_height.conf

View file

@ -21,7 +21,7 @@ from .simple_config import SimpleConfig
from .logging import get_logger, Logger
if TYPE_CHECKING:
from .lnchannel import Channel
from .lnchannel import Channel, AbstractChannel
_logger = get_logger(__name__)
@ -169,7 +169,7 @@ def create_sweeptx_for_their_revoked_htlc(chan: 'Channel', ctx: Transaction, htl
def create_sweeptxs_for_our_ctx(*, chan: 'Channel', ctx: Transaction,
def create_sweeptxs_for_our_ctx(*, chan: 'AbstractChannel', ctx: Transaction,
sweep_address: str) -> Optional[Dict[str, SweepInfo]]:
"""Handle the case where we force close unilaterally with our latest ctx.
Construct sweep txns for 'to_local', and for all HTLCs (2 txns each).

View file

@ -89,6 +89,7 @@ def create_ephemeral_key() -> (bytes, bytes):
class LNTransportBase:
reader: StreamReader
writer: StreamWriter
privkey: bytes
def name(self) -> str:
raise NotImplementedError()

View file

@ -27,7 +27,7 @@ from .bip32 import BIP32Node, BIP32_PRIME
from .transaction import BCDataStream
if TYPE_CHECKING:
from .lnchannel import Channel
from .lnchannel import Channel, AbstractChannel
from .lnrouter import LNPaymentRoute
from .lnonion import OnionRoutingFailureMessage
@ -504,8 +504,8 @@ def make_htlc_output_witness_script(is_received_htlc: bool, remote_revocation_pu
payment_hash=payment_hash)
def get_ordered_channel_configs(chan: 'Channel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig],
Union[LocalConfig, RemoteConfig]]:
def get_ordered_channel_configs(chan: 'AbstractChannel', for_us: bool) -> Tuple[Union[LocalConfig, RemoteConfig],
Union[LocalConfig, RemoteConfig]]:
conf = chan.config[LOCAL] if for_us else chan.config[REMOTE]
other_conf = chan.config[LOCAL] if not for_us else chan.config[REMOTE]
return conf, other_conf
@ -781,7 +781,7 @@ def extract_ctn_from_tx(tx: Transaction, txin_index: int, funder_payment_basepoi
obs = ((sequence & 0xffffff) << 24) + (locktime & 0xffffff)
return get_obscured_ctn(obs, funder_payment_basepoint, fundee_payment_basepoint)
def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'Channel') -> int:
def extract_ctn_from_tx_and_chan(tx: Transaction, chan: 'AbstractChannel') -> int:
funder_conf = chan.config[LOCAL] if chan.is_initiator() else chan.config[REMOTE]
fundee_conf = chan.config[LOCAL] if not chan.is_initiator() else chan.config[REMOTE]
return extract_ctn_from_tx(tx, txin_index=0,

View file

@ -4,20 +4,13 @@
from typing import NamedTuple, Iterable, TYPE_CHECKING
import os
import queue
import threading
import concurrent
from collections import defaultdict
import asyncio
from enum import IntEnum, auto
from typing import NamedTuple, Dict
from .sql_db import SqlDB, sql
from .wallet_db import WalletDB
from .util import bh2u, bfh, log_exceptions, ignore_exceptions
from .lnutil import Outpoint
from . import wallet
from .storage import WalletStorage
from .util import bh2u, bfh, log_exceptions, ignore_exceptions, TxMinedInfo
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED
from .transaction import Transaction
@ -199,17 +192,22 @@ class LNWatcher(AddressSynchronizer):
else:
keep_watching = True
await self.update_channel_state(
funding_outpoint, funding_txid,
funding_height, closing_txid,
closing_height, keep_watching)
funding_outpoint=funding_outpoint,
funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders):
raise NotImplementedError() # implemented by subclasses
async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders) -> bool:
raise NotImplementedError() # implemented by subclasses
async def update_channel_state(self, *args):
raise NotImplementedError() # implemented by subclasses
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
raise NotImplementedError() # implemented by subclasses
def inspect_tx_candidate(self, outpoint, n):
prev_txid, index = outpoint.split(':')
@ -325,7 +323,7 @@ class WatchTower(LNWatcher):
if funding_outpoint in self.tx_progress:
self.tx_progress[funding_outpoint].all_done.set()
async def update_channel_state(self, *args):
async def update_channel_state(self, *args, **kwargs):
pass
@ -340,17 +338,23 @@ class LNWalletWatcher(LNWatcher):
@ignore_exceptions
@log_exceptions
async def update_channel_state(self, funding_outpoint, funding_txid, funding_height, closing_txid, closing_height, keep_watching):
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return
chan.update_onchain_state(funding_txid, funding_height, closing_txid, closing_height, keep_watching)
chan.update_onchain_state(funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
await self.lnworker.on_channel_update(chan)
async def do_breach_remedy(self, funding_outpoint, closing_tx, spenders):
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return
return False
# detect who closed and set sweep_info
sweep_info_dict = chan.sweep_ctx(closing_tx)
keep_watching = False if sweep_info_dict else not self.is_deeply_mined(closing_tx.txid())

View file

@ -432,7 +432,7 @@ class LNWallet(LNWorker):
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self.sweep_address = wallet.get_receiving_address()
self.lock = threading.RLock()
self.logs = defaultdict(list) # (not persisted) type: Dict[str, List[PaymentAttemptLog]] # key is RHASH
self.logs = defaultdict(list) # type: Dict[str, List[PaymentAttemptLog]] # key is RHASH # (not persisted)
self.is_routing = set() # (not persisted) keys of invoices that are in PR_ROUTING state
# used in tests
self.enable_htlc_settle = asyncio.Event()