mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-27 15:31:31 +00:00
replace await_local/remote
This commit is contained in:
parent
dae842e2ad
commit
b9eaba3e85
6 changed files with 48 additions and 70 deletions
|
@ -152,6 +152,7 @@ class Channel(Logger):
|
|||
self._chan_ann_without_sigs = None # type: Optional[bytes]
|
||||
self.revocation_store = RevocationStore(state["revocation_store"])
|
||||
self._can_send_ctx_updates = True # type: bool
|
||||
self._receive_fail_reasons = {}
|
||||
|
||||
def get_id_for_log(self) -> str:
|
||||
scid = self.short_channel_id
|
||||
|
@ -562,11 +563,15 @@ class Channel(Logger):
|
|||
self.hm.send_rev()
|
||||
received = self.hm.received_in_ctn(new_ctn)
|
||||
sent = self.hm.sent_in_ctn(new_ctn)
|
||||
failed = self.hm.failed_in_ctn(new_ctn)
|
||||
if self.lnworker:
|
||||
for htlc in received:
|
||||
self.lnworker.payment_completed(self, RECEIVED, htlc)
|
||||
for htlc in sent:
|
||||
self.lnworker.payment_completed(self, SENT, htlc)
|
||||
for htlc in failed:
|
||||
reason = self._receive_fail_reasons.get(htlc.htlc_id)
|
||||
self.lnworker.payment_failed(htlc.payment_hash, reason)
|
||||
received_this_batch = htlcsum(received)
|
||||
sent_this_batch = htlcsum(sent)
|
||||
last_secret, last_point = self.get_secret_and_point(LOCAL, new_ctn - 1)
|
||||
|
@ -575,12 +580,10 @@ class Channel(Logger):
|
|||
|
||||
def receive_revocation(self, revocation: RevokeAndAck):
|
||||
self.logger.info("receive_revocation")
|
||||
|
||||
cur_point = self.config[REMOTE].current_per_commitment_point
|
||||
derived_point = ecc.ECPrivkey(revocation.per_commitment_secret).get_public_key_bytes(compressed=True)
|
||||
if cur_point != derived_point:
|
||||
raise Exception('revoked secret not for current point')
|
||||
|
||||
with self.db_lock:
|
||||
self.revocation_store.add_next_entry(revocation.per_commitment_secret)
|
||||
##### start applying fee/htlc changes
|
||||
|
@ -763,10 +766,11 @@ class Channel(Logger):
|
|||
with self.db_lock:
|
||||
self.hm.send_fail(htlc_id)
|
||||
|
||||
def receive_fail_htlc(self, htlc_id):
|
||||
def receive_fail_htlc(self, htlc_id, reason):
|
||||
self.logger.info("receive_fail_htlc")
|
||||
with self.db_lock:
|
||||
self.hm.recv_fail(htlc_id)
|
||||
self._receive_fail_reasons[htlc_id] = reason
|
||||
|
||||
def pending_local_fee(self):
|
||||
return self.constraints.capacity - sum(x.value for x in self.get_next_commitment(LOCAL).outputs())
|
||||
|
|
|
@ -298,6 +298,11 @@ class HTLCManager:
|
|||
for htlc_id, ctns in self.log[LOCAL]['settles'].items()
|
||||
if ctns[LOCAL] == ctn]
|
||||
|
||||
def failed_in_ctn(self, ctn: int) -> Sequence[UpdateAddHtlc]:
|
||||
return [self.log[LOCAL]['adds'][htlc_id]
|
||||
for htlc_id, ctns in self.log[LOCAL]['fails'].items()
|
||||
if ctns[LOCAL] == ctn]
|
||||
|
||||
##### Queries re Fees:
|
||||
|
||||
def get_feerate(self, subject: HTLCOwner, ctn: int) -> int:
|
||||
|
|
|
@ -93,8 +93,6 @@ class Peer(Logger):
|
|||
self.shutdown_received = {}
|
||||
self.announcement_signatures = defaultdict(asyncio.Queue)
|
||||
self.orphan_channel_updates = OrderedDict()
|
||||
self._local_changed_events = defaultdict(asyncio.Event)
|
||||
self._remote_changed_events = defaultdict(asyncio.Event)
|
||||
Logger.__init__(self)
|
||||
self.taskgroup = SilentTaskGroup()
|
||||
|
||||
|
@ -1006,16 +1004,8 @@ class Peer(Logger):
|
|||
reason = payload["reason"]
|
||||
chan = self.channels[channel_id]
|
||||
self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
|
||||
chan.receive_fail_htlc(htlc_id)
|
||||
local_ctn = chan.get_latest_ctn(LOCAL)
|
||||
asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn, reason))
|
||||
|
||||
@log_exceptions
|
||||
async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn, reason):
|
||||
chan = self.channels[channel_id]
|
||||
await self.await_local(chan, local_ctn)
|
||||
payment_hash = chan.get_payment_hash(htlc_id)
|
||||
self.lnworker.payment_failed(payment_hash, reason)
|
||||
chan.receive_fail_htlc(htlc_id, reason)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
def maybe_send_commitment(self, chan: Channel):
|
||||
# REMOTE should revoke first before we can sign a new ctx
|
||||
|
@ -1028,27 +1018,9 @@ class Peer(Logger):
|
|||
sig_64, htlc_sigs = chan.sign_next_commitment()
|
||||
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
|
||||
|
||||
async def await_remote(self, chan: Channel, ctn: int):
|
||||
"""Wait until remote 'ctn' gets revoked."""
|
||||
# if 'ctn' is too high, we risk waiting "forever", hence assert:
|
||||
assert chan.get_latest_ctn(REMOTE) >= ctn, (chan.get_latest_ctn(REMOTE), ctn)
|
||||
self.maybe_send_commitment(chan)
|
||||
while chan.get_oldest_unrevoked_ctn(REMOTE) <= ctn:
|
||||
await self._remote_changed_events[chan.channel_id].wait()
|
||||
|
||||
async def await_local(self, chan: Channel, ctn: int):
|
||||
"""Wait until local 'ctn' gets revoked."""
|
||||
# if 'ctn' is too high, we risk waiting "forever", hence assert:
|
||||
assert chan.get_latest_ctn(LOCAL) >= ctn, (chan.get_latest_ctn(LOCAL), ctn)
|
||||
self.maybe_send_commitment(chan)
|
||||
while chan.get_oldest_unrevoked_ctn(LOCAL) <= ctn:
|
||||
await self._local_changed_events[chan.channel_id].wait()
|
||||
|
||||
async def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int,
|
||||
payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc:
|
||||
def pay(self, route: 'LNPaymentRoute', chan: Channel, amount_msat: int,
|
||||
payment_hash: bytes, min_final_cltv_expiry: int) -> UpdateAddHtlc:
|
||||
assert amount_msat > 0, "amount_msat is not greater zero"
|
||||
await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT)
|
||||
# TODO also wait for channel reestablish to finish. (combine timeout with waiting for init?)
|
||||
if not chan.can_send_update_add_htlc():
|
||||
raise PaymentFailure("Channel cannot send update_add_htlc")
|
||||
# create onion packet
|
||||
|
@ -1060,25 +1032,23 @@ class Peer(Logger):
|
|||
# create htlc
|
||||
htlc = UpdateAddHtlc(amount_msat=amount_msat, payment_hash=payment_hash, cltv_expiry=cltv, timestamp=int(time.time()))
|
||||
htlc = chan.add_htlc(htlc)
|
||||
remote_ctn = chan.get_latest_ctn(REMOTE)
|
||||
chan.set_onion_key(htlc.htlc_id, secret_key)
|
||||
self.logger.info(f"starting payment. len(route)={len(route)}. route: {route}. htlc: {htlc}")
|
||||
self.send_message("update_add_htlc",
|
||||
channel_id=chan.channel_id,
|
||||
id=htlc.htlc_id,
|
||||
cltv_expiry=htlc.cltv_expiry,
|
||||
amount_msat=htlc.amount_msat,
|
||||
payment_hash=htlc.payment_hash,
|
||||
onion_routing_packet=onion.to_bytes())
|
||||
await self.await_remote(chan, remote_ctn)
|
||||
self.send_message(
|
||||
"update_add_htlc",
|
||||
channel_id=chan.channel_id,
|
||||
id=htlc.htlc_id,
|
||||
cltv_expiry=htlc.cltv_expiry,
|
||||
amount_msat=htlc.amount_msat,
|
||||
payment_hash=htlc.payment_hash,
|
||||
onion_routing_packet=onion.to_bytes())
|
||||
self.maybe_send_commitment(chan)
|
||||
return htlc
|
||||
|
||||
def send_revoke_and_ack(self, chan: Channel):
|
||||
self.logger.info(f'send_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(LOCAL)}')
|
||||
rev, _ = chan.revoke_current_commitment()
|
||||
self.lnworker.save_channel(chan)
|
||||
self._local_changed_events[chan.channel_id].set()
|
||||
self._local_changed_events[chan.channel_id].clear()
|
||||
self.send_message("revoke_and_ack",
|
||||
channel_id=chan.channel_id,
|
||||
per_commitment_secret=rev.per_commitment_secret,
|
||||
|
@ -1113,13 +1083,7 @@ class Peer(Logger):
|
|||
self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
|
||||
chan.receive_htlc_settle(preimage, htlc_id)
|
||||
self.lnworker.save_preimage(payment_hash, preimage)
|
||||
local_ctn = chan.get_latest_ctn(LOCAL)
|
||||
asyncio.ensure_future(self._on_update_fulfill_htlc(chan, local_ctn, payment_hash))
|
||||
|
||||
@log_exceptions
|
||||
async def _on_update_fulfill_htlc(self, chan, local_ctn, payment_hash):
|
||||
await self.await_local(chan, local_ctn)
|
||||
self.lnworker.payment_sent(payment_hash)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
def on_update_fail_malformed_htlc(self, payload):
|
||||
self.logger.info(f"on_update_fail_malformed_htlc. error {payload['data'].decode('ascii')}")
|
||||
|
@ -1272,8 +1236,6 @@ class Peer(Logger):
|
|||
self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}')
|
||||
rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])
|
||||
chan.receive_revocation(rev)
|
||||
self._remote_changed_events[chan.channel_id].set()
|
||||
self._remote_changed_events[chan.channel_id].clear()
|
||||
self.lnworker.save_channel(chan)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
|
@ -1303,11 +1265,11 @@ class Peer(Logger):
|
|||
self.logger.info(f"(chan: {chan.get_id_for_log()}) current pending feerate {chan_fee}. "
|
||||
f"new feerate {feerate_per_kw}")
|
||||
chan.update_fee(feerate_per_kw, True)
|
||||
remote_ctn = chan.get_latest_ctn(REMOTE)
|
||||
self.send_message("update_fee",
|
||||
channel_id=chan.channel_id,
|
||||
feerate_per_kw=feerate_per_kw)
|
||||
await self.await_remote(chan, remote_ctn)
|
||||
self.send_message(
|
||||
"update_fee",
|
||||
channel_id=chan.channel_id,
|
||||
feerate_per_kw=feerate_per_kw)
|
||||
self.maybe_send_commitment(chan)
|
||||
|
||||
@log_exceptions
|
||||
async def close_channel(self, chan_id: bytes):
|
||||
|
@ -1351,9 +1313,8 @@ class Peer(Logger):
|
|||
scriptpubkey = bfh(bitcoin.address_to_script(chan.sweep_address))
|
||||
# wait until no more pending updates (bolt2)
|
||||
chan.set_can_send_ctx_updates(False)
|
||||
ctn = chan.get_latest_ctn(REMOTE)
|
||||
if chan.has_pending_changes(REMOTE):
|
||||
await self.await_remote(chan, ctn)
|
||||
while chan.has_pending_changes(REMOTE):
|
||||
await asyncio.sleep(0.1)
|
||||
self.send_message('shutdown', channel_id=chan.channel_id, len=len(scriptpubkey), scriptpubkey=scriptpubkey)
|
||||
chan.set_state(channel_states.CLOSING)
|
||||
# can fullfill or fail htlcs. cannot add htlcs, because of CLOSING state
|
||||
|
|
|
@ -523,6 +523,8 @@ class LNWallet(LNWorker):
|
|||
preimage = self.get_preimage(htlc.payment_hash)
|
||||
timestamp = int(time.time())
|
||||
self.network.trigger_callback('ln_payment_completed', timestamp, direction, htlc, preimage, chan_id)
|
||||
if direction == SENT:
|
||||
self.payment_sent(htlc.payment_hash)
|
||||
|
||||
def get_settled_payments(self):
|
||||
# return one item per payment_hash
|
||||
|
@ -952,7 +954,8 @@ class LNWallet(LNWorker):
|
|||
peer = self.peers.get(route[0].node_id)
|
||||
if not peer:
|
||||
raise Exception('Dropped peer')
|
||||
htlc = await peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
|
||||
await peer.initialized
|
||||
htlc = peer.pay(route, chan, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
|
||||
self.network.trigger_callback('htlc_added', htlc, lnaddr, SENT)
|
||||
success, preimage, reason = await self.await_payment(lnaddr.paymenthash)
|
||||
if success:
|
||||
|
@ -1207,12 +1210,16 @@ class LNWallet(LNWorker):
|
|||
|
||||
def payment_failed(self, payment_hash: bytes, reason):
|
||||
self.set_payment_status(payment_hash, PR_UNPAID)
|
||||
self.pending_payments[payment_hash].set_result((False, None, reason))
|
||||
f = self.pending_payments[payment_hash]
|
||||
if not f.cancelled():
|
||||
f.set_result((False, None, reason))
|
||||
|
||||
def payment_sent(self, payment_hash: bytes):
|
||||
self.set_payment_status(payment_hash, PR_PAID)
|
||||
preimage = self.get_preimage(payment_hash)
|
||||
self.pending_payments[payment_hash].set_result((True, preimage, None))
|
||||
f = self.pending_payments[payment_hash]
|
||||
if not f.cancelled():
|
||||
f.set_result((True, preimage, None))
|
||||
|
||||
def payment_received(self, payment_hash: bytes):
|
||||
self.set_payment_status(payment_hash, PR_PAID)
|
||||
|
|
|
@ -618,7 +618,7 @@ class TestAvailableToSpend(ElectrumTestCase):
|
|||
bob_idx = bob_channel.receive_htlc(htlc_dict).htlc_id
|
||||
force_state_transition(alice_channel, bob_channel)
|
||||
bob_channel.fail_htlc(bob_idx)
|
||||
alice_channel.receive_fail_htlc(alice_idx)
|
||||
alice_channel.receive_fail_htlc(alice_idx, None)
|
||||
# Alice now has gotten all her original balance (5 BTC) back, however,
|
||||
# adding a new HTLC at this point SHOULD fail, since if she adds the
|
||||
# HTLC and signs the next state, Bob cannot assume she received the
|
||||
|
|
|
@ -131,6 +131,7 @@ class MockLNWallet:
|
|||
_create_route_from_invoice = LNWallet._create_route_from_invoice
|
||||
_check_invoice = staticmethod(LNWallet._check_invoice)
|
||||
_pay_to_route = LNWallet._pay_to_route
|
||||
_pay = LNWallet._pay
|
||||
force_close_channel = LNWallet.force_close_channel
|
||||
get_first_timestamp = lambda self: 0
|
||||
payment_completed = LNWallet.payment_completed
|
||||
|
@ -250,7 +251,7 @@ class TestPeer(ElectrumTestCase):
|
|||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
pay_req = self.prepare_invoice(w2)
|
||||
async def pay():
|
||||
result = await LNWallet._pay(w1, pay_req)
|
||||
result = await w1._pay(pay_req)
|
||||
self.assertEqual(result, True)
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
|
@ -282,7 +283,7 @@ class TestPeer(ElectrumTestCase):
|
|||
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
|
||||
pay_req = self.prepare_invoice(w2)
|
||||
async def pay():
|
||||
result = await LNWallet._pay(w1, pay_req)
|
||||
result = await w1._pay(pay_req)
|
||||
self.assertTrue(result)
|
||||
gath.cancel()
|
||||
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
|
||||
|
@ -306,7 +307,7 @@ class TestPeer(ElectrumTestCase):
|
|||
await asyncio.wait_for(p2.initialized, 1)
|
||||
# alice sends htlc
|
||||
route = await w1._create_route_from_invoice(decoded_invoice=lnaddr)
|
||||
htlc = await p1.pay(route, alice_channel, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
|
||||
htlc = p1.pay(route, alice_channel, int(lnaddr.amount * COIN * 1000), lnaddr.paymenthash, lnaddr.get_min_final_cltv_expiry())
|
||||
# alice closes
|
||||
await p1.close_channel(alice_channel.channel_id)
|
||||
gath.cancel()
|
||||
|
|
Loading…
Add table
Reference in a new issue