mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-09-01 09:45:18 +00:00
lnhtlc: save settled htlc amounts separately
This commit is contained in:
parent
d480d52c1c
commit
7cb76fee84
2 changed files with 24 additions and 36 deletions
|
@ -63,7 +63,7 @@ class FeeUpdate:
|
|||
return self.rate
|
||||
# implicit return None
|
||||
|
||||
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'settled', 'locked_in', 'htlc_id'])):
|
||||
class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash', 'cltv_expiry', 'locked_in', 'htlc_id'])):
|
||||
__slots__ = ()
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
|
@ -71,7 +71,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash',
|
|||
if type(args[1]) is str:
|
||||
args[1] = bfh(args[1])
|
||||
args[3] = {HTLCOwner(int(x)): y for x,y in args[3].items()}
|
||||
args[4] = {HTLCOwner(int(x)): y for x,y in args[4].items()}
|
||||
return super().__new__(cls, *args)
|
||||
if type(kwargs['payment_hash']) is str:
|
||||
kwargs['payment_hash'] = bfh(kwargs['payment_hash'])
|
||||
|
@ -79,10 +78,6 @@ class UpdateAddHtlc(namedtuple('UpdateAddHtlc', ['amount_msat', 'payment_hash',
|
|||
kwargs['locked_in'] = {LOCAL: None, REMOTE: None}
|
||||
else:
|
||||
kwargs['locked_in'] = {HTLCOwner(int(x)): y for x,y in kwargs['locked_in']}
|
||||
if 'settled' not in kwargs:
|
||||
kwargs['settled'] = {LOCAL: None, REMOTE: None}
|
||||
else:
|
||||
kwargs['settled'] = {HTLCOwner(int(x)): y for x,y in kwargs['settled']}
|
||||
return super().__new__(cls, **kwargs)
|
||||
|
||||
is_key = lambda k: k.endswith("_basepoint") or k.endswith("_key")
|
||||
|
@ -176,6 +171,8 @@ class HTLCStateMachine(PrintError):
|
|||
|
||||
self.lnwatcher = None
|
||||
|
||||
self.settled = {LOCAL: state.get('settled_local', []), REMOTE: state.get('settled_remote', [])}
|
||||
|
||||
def set_state(self, state: str):
|
||||
self._state = state
|
||||
|
||||
|
@ -429,10 +426,14 @@ class HTLCStateMachine(PrintError):
|
|||
"""
|
||||
old_amount = self.htlcsum(self.gen_htlc_indices(subject, False))
|
||||
|
||||
removed = []
|
||||
for x in self.log[-subject]:
|
||||
if type(x) is not SettleHtlc: continue
|
||||
htlc = self.lookup_htlc(self.log[subject], x.htlc_id)
|
||||
htlc.settled[subject] = self.current_height[subject]
|
||||
self.settled[subject].append(htlc.amount_msat)
|
||||
self.log[subject].remove(htlc)
|
||||
removed.append(x)
|
||||
for x in removed: self.log[-subject].remove(x)
|
||||
|
||||
return old_amount - self.htlcsum(self.gen_htlc_indices(subject, False))
|
||||
|
||||
|
@ -465,15 +466,8 @@ class HTLCStateMachine(PrintError):
|
|||
def balance(self, subject):
|
||||
initial = self.local_config.initial_msat if subject == LOCAL else self.remote_config.initial_msat
|
||||
|
||||
for direction in (SENT, RECEIVED):
|
||||
for x in self.log[-direction]:
|
||||
if type(x) is not SettleHtlc: continue
|
||||
htlc = self.lookup_htlc(self.log[direction], x.htlc_id)
|
||||
htlc_height = htlc.settled[direction]
|
||||
if htlc_height is not None and htlc_height <= self.current_height[direction]:
|
||||
# so we will subtract when direction == subject.
|
||||
# example subject=LOCAL, direction=SENT: we subtract
|
||||
initial -= htlc.amount_msat * subject * direction
|
||||
initial -= sum(self.settled[subject])
|
||||
initial += sum(self.settled[-subject])
|
||||
|
||||
assert initial == (self.local_state.amount_msat if subject == LOCAL else self.remote_state.amount_msat)
|
||||
return initial
|
||||
|
@ -528,11 +522,10 @@ class HTLCStateMachine(PrintError):
|
|||
_, this_point, _ = self.points
|
||||
return self.make_commitment(LOCAL, this_point)
|
||||
|
||||
@property
|
||||
def total_msat(self):
|
||||
return {LOCAL: self.htlcsum(self.gen_htlc_indices(LOCAL, False, True)), REMOTE: self.htlcsum(self.gen_htlc_indices(REMOTE, False, True))}
|
||||
def total_msat(self, sub):
|
||||
return sum(self.settled[sub])
|
||||
|
||||
def gen_htlc_indices(self, subject, only_pending, include_settled=False):
|
||||
def gen_htlc_indices(self, subject, only_pending):
|
||||
"""
|
||||
only_pending: require the htlc's settlement to be pending (needs additional signatures/acks)
|
||||
include_settled: include settled (totally done with) htlcs
|
||||
|
@ -543,17 +536,10 @@ class HTLCStateMachine(PrintError):
|
|||
for htlc in update_log:
|
||||
if type(htlc) is not UpdateAddHtlc:
|
||||
continue
|
||||
height = self.current_height[-subject]
|
||||
locked_in = htlc.locked_in[subject]
|
||||
|
||||
if locked_in is None or only_pending == (SettleHtlc(htlc.htlc_id) in other_log):
|
||||
continue
|
||||
|
||||
settled_cutoff = self.local_state.ctn if subject == LOCAL else self.remote_state.ctn
|
||||
|
||||
if not include_settled and htlc.settled[subject] is not None and settled_cutoff >= htlc.settled[subject]:
|
||||
continue
|
||||
|
||||
res.append(htlc)
|
||||
return res
|
||||
|
||||
|
@ -651,6 +637,8 @@ class HTLCStateMachine(PrintError):
|
|||
"remote_log": [(type(x).__name__, x) for x in remote_filtered],
|
||||
"local_log": [(type(x).__name__, x) for x in local_filtered],
|
||||
"onion_keys": {str(k): bh2u(v) for k, v in self.onion_keys.items()},
|
||||
"settled_local": self.settled[LOCAL],
|
||||
"settled_remote": self.settled[REMOTE],
|
||||
}
|
||||
|
||||
# htlcs number must be monotonically increasing,
|
||||
|
|
|
@ -201,10 +201,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
|
|||
aliceSent = 0
|
||||
bobSent = 0
|
||||
|
||||
self.assertEqual(alice_channel.total_msat[SENT], aliceSent, "alice has incorrect milli-satoshis sent")
|
||||
self.assertEqual(alice_channel.total_msat[RECEIVED], bobSent, "alice has incorrect milli-satoshis received")
|
||||
self.assertEqual(bob_channel.total_msat[SENT], bobSent, "bob has incorrect milli-satoshis sent")
|
||||
self.assertEqual(bob_channel.total_msat[RECEIVED], aliceSent, "bob has incorrect milli-satoshis received")
|
||||
self.assertEqual(alice_channel.total_msat(SENT), aliceSent, "alice has incorrect milli-satoshis sent")
|
||||
self.assertEqual(alice_channel.total_msat(RECEIVED), bobSent, "alice has incorrect milli-satoshis received")
|
||||
self.assertEqual(bob_channel.total_msat(SENT), bobSent, "bob has incorrect milli-satoshis sent")
|
||||
self.assertEqual(bob_channel.total_msat(RECEIVED), aliceSent, "bob has incorrect milli-satoshis received")
|
||||
self.assertEqual(bob_channel.local_state.ctn, 1, "bob has incorrect commitment height")
|
||||
self.assertEqual(alice_channel.local_state.ctn, 1, "alice has incorrect commitment height")
|
||||
|
||||
|
@ -242,10 +242,10 @@ class TestLNBaseHTLCStateMachine(unittest.TestCase):
|
|||
# should show 1 BTC received. They should also be at commitment height
|
||||
# two, with the revocation window extended by 1 (5).
|
||||
mSatTransferred = one_bitcoin_in_msat
|
||||
self.assertEqual(alice_channel.total_msat[SENT], mSatTransferred, "alice satoshis sent incorrect")
|
||||
self.assertEqual(alice_channel.total_msat[RECEIVED], 0, "alice satoshis received incorrect")
|
||||
self.assertEqual(bob_channel.total_msat[RECEIVED], mSatTransferred, "bob satoshis received incorrect")
|
||||
self.assertEqual(bob_channel.total_msat[SENT], 0, "bob satoshis sent incorrect")
|
||||
self.assertEqual(alice_channel.total_msat(SENT), mSatTransferred, "alice satoshis sent incorrect")
|
||||
self.assertEqual(alice_channel.total_msat(RECEIVED), 0, "alice satoshis received incorrect")
|
||||
self.assertEqual(bob_channel.total_msat(RECEIVED), mSatTransferred, "bob satoshis received incorrect")
|
||||
self.assertEqual(bob_channel.total_msat(SENT), 0, "bob satoshis sent incorrect")
|
||||
self.assertEqual(bob_channel.current_height[LOCAL], 2, "bob has incorrect commitment height")
|
||||
self.assertEqual(alice_channel.current_height[LOCAL], 2, "alice has incorrect commitment height")
|
||||
|
||||
|
@ -348,7 +348,7 @@ class TestLNHTLCDust(unittest.TestCase):
|
|||
alice_channel.receive_htlc_settle(paymentPreimage, aliceHtlcIndex)
|
||||
force_state_transition(bob_channel, alice_channel)
|
||||
self.assertEqual(len(alice_channel.local_commitment.outputs()), 2)
|
||||
self.assertEqual(alice_channel.total_msat[SENT] // 1000, htlcAmt)
|
||||
self.assertEqual(alice_channel.total_msat(SENT) // 1000, htlcAmt)
|
||||
|
||||
def force_state_transition(chanA, chanB):
|
||||
chanB.receive_new_commitment(*chanA.sign_next_commitment())
|
||||
|
|
Loading…
Add table
Reference in a new issue