mirror of
https://github.com/LBRYFoundation/LBRY-Vault.git
synced 2025-08-23 17:47:31 +00:00
coinchooser: refactor so that penalty_func has access to change outputs
This commit is contained in:
parent
6424163d4b
commit
f409b5da40
1 changed files with 88 additions and 76 deletions
|
@ -24,7 +24,7 @@
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from math import floor, log10
|
from math import floor, log10
|
||||||
from typing import NamedTuple, List
|
from typing import NamedTuple, List, Callable
|
||||||
|
|
||||||
from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address
|
from .bitcoin import sha256, COIN, TYPE_ADDRESS, is_address
|
||||||
from .transaction import Transaction, TxOutput
|
from .transaction import Transaction, TxOutput
|
||||||
|
@ -79,6 +79,12 @@ class Bucket(NamedTuple):
|
||||||
witness: bool # whether any coin uses segwit
|
witness: bool # whether any coin uses segwit
|
||||||
|
|
||||||
|
|
||||||
|
class ScoredCandidate(NamedTuple):
|
||||||
|
penalty: float
|
||||||
|
tx: Transaction
|
||||||
|
buckets: List[Bucket]
|
||||||
|
|
||||||
|
|
||||||
def strip_unneeded(bkts, sufficient_funds):
|
def strip_unneeded(bkts, sufficient_funds):
|
||||||
'''Remove buckets that are unnecessary in achieving the spend amount'''
|
'''Remove buckets that are unnecessary in achieving the spend amount'''
|
||||||
if sufficient_funds([], bucket_value_sum=0):
|
if sufficient_funds([], bucket_value_sum=0):
|
||||||
|
@ -121,12 +127,10 @@ class CoinChooserBase(Logger):
|
||||||
|
|
||||||
return list(map(make_Bucket, buckets.keys(), buckets.values()))
|
return list(map(make_Bucket, buckets.keys(), buckets.values()))
|
||||||
|
|
||||||
def penalty_func(self, tx, *, fee_for_buckets):
|
def penalty_func(self, base_tx, *, tx_from_buckets) -> Callable[[List[Bucket]], ScoredCandidate]:
|
||||||
def penalty(candidate):
|
raise NotImplementedError
|
||||||
return 0
|
|
||||||
return penalty
|
|
||||||
|
|
||||||
def change_amounts(self, tx, count, fee_estimator, dust_threshold):
|
def _change_amounts(self, tx, count, fee_estimator):
|
||||||
# Break change up if bigger than max_change
|
# Break change up if bigger than max_change
|
||||||
output_amounts = [o.value for o in tx.outputs()]
|
output_amounts = [o.value for o in tx.outputs()]
|
||||||
# Don't split change of less than 0.02 BTC
|
# Don't split change of less than 0.02 BTC
|
||||||
|
@ -180,22 +184,60 @@ class CoinChooserBase(Logger):
|
||||||
|
|
||||||
return amounts
|
return amounts
|
||||||
|
|
||||||
def change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
|
def _change_outputs(self, tx, change_addrs, fee_estimator, dust_threshold):
|
||||||
amounts = self.change_amounts(tx, len(change_addrs), fee_estimator,
|
amounts = self._change_amounts(tx, len(change_addrs), fee_estimator)
|
||||||
dust_threshold)
|
|
||||||
assert min(amounts) >= 0
|
assert min(amounts) >= 0
|
||||||
assert len(change_addrs) >= len(amounts)
|
assert len(change_addrs) >= len(amounts)
|
||||||
# If change is above dust threshold after accounting for the
|
# If change is above dust threshold after accounting for the
|
||||||
# size of the change output, add it to the transaction.
|
# size of the change output, add it to the transaction.
|
||||||
dust = sum(amount for amount in amounts if amount < dust_threshold)
|
|
||||||
amounts = [amount for amount in amounts if amount >= dust_threshold]
|
amounts = [amount for amount in amounts if amount >= dust_threshold]
|
||||||
change = [TxOutput(TYPE_ADDRESS, addr, amount)
|
change = [TxOutput(TYPE_ADDRESS, addr, amount)
|
||||||
for addr, amount in zip(change_addrs, amounts)]
|
for addr, amount in zip(change_addrs, amounts)]
|
||||||
self.logger.info(f'change: {change}')
|
|
||||||
if dust:
|
|
||||||
self.logger.info(f'not keeping dust {dust}')
|
|
||||||
return change
|
return change
|
||||||
|
|
||||||
|
def _construct_tx_from_selected_buckets(self, *, buckets, base_tx, change_addrs,
|
||||||
|
fee_estimator_w, dust_threshold, base_weight):
|
||||||
|
# make a copy of base_tx so it won't get mutated
|
||||||
|
tx = Transaction.from_io(base_tx.inputs()[:], base_tx.outputs()[:])
|
||||||
|
|
||||||
|
tx.add_inputs([coin for b in buckets for coin in b.coins])
|
||||||
|
tx_weight = self._get_tx_weight(buckets, base_weight=base_weight)
|
||||||
|
|
||||||
|
# change is sent back to sending address unless specified
|
||||||
|
if not change_addrs:
|
||||||
|
change_addrs = [tx.inputs()[0]['address']]
|
||||||
|
# note: this is not necessarily the final "first input address"
|
||||||
|
# because the inputs had not been sorted at this point
|
||||||
|
assert is_address(change_addrs[0])
|
||||||
|
|
||||||
|
# This takes a count of change outputs and returns a tx fee
|
||||||
|
output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
|
||||||
|
fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
|
||||||
|
change = self._change_outputs(tx, change_addrs, fee, dust_threshold)
|
||||||
|
tx.add_outputs(change)
|
||||||
|
|
||||||
|
return tx, change
|
||||||
|
|
||||||
|
def _get_tx_weight(self, buckets, *, base_weight) -> int:
|
||||||
|
"""Given a collection of buckets, return the total weight of the
|
||||||
|
resulting transaction.
|
||||||
|
base_weight is the weight of the tx that includes the fixed (non-change)
|
||||||
|
outputs and potentially some fixed inputs. Note that the change outputs
|
||||||
|
at this point are not yet known so they are NOT accounted for.
|
||||||
|
"""
|
||||||
|
total_weight = base_weight + sum(bucket.weight for bucket in buckets)
|
||||||
|
is_segwit_tx = any(bucket.witness for bucket in buckets)
|
||||||
|
if is_segwit_tx:
|
||||||
|
total_weight += 2 # marker and flag
|
||||||
|
# non-segwit inputs were previously assumed to have
|
||||||
|
# a witness of '' instead of '00' (hex)
|
||||||
|
# note that mixed legacy/segwit buckets are already ok
|
||||||
|
num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
|
||||||
|
for bucket in buckets)
|
||||||
|
total_weight += num_legacy_inputs
|
||||||
|
|
||||||
|
return total_weight
|
||||||
|
|
||||||
def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator,
|
def make_tx(self, coins, inputs, outputs, change_addrs, fee_estimator,
|
||||||
dust_threshold):
|
dust_threshold):
|
||||||
"""Select unspent coins to spend to pay outputs. If the change is
|
"""Select unspent coins to spend to pay outputs. If the change is
|
||||||
|
@ -211,34 +253,20 @@ class CoinChooserBase(Logger):
|
||||||
self.p = PRNG(''.join(sorted(utxos)))
|
self.p = PRNG(''.join(sorted(utxos)))
|
||||||
|
|
||||||
# Copy the outputs so when adding change we don't modify "outputs"
|
# Copy the outputs so when adding change we don't modify "outputs"
|
||||||
tx = Transaction.from_io(inputs[:], outputs[:])
|
base_tx = Transaction.from_io(inputs[:], outputs[:])
|
||||||
input_value = tx.input_value()
|
input_value = base_tx.input_value()
|
||||||
|
|
||||||
# Weight of the transaction with no inputs and no change
|
# Weight of the transaction with no inputs and no change
|
||||||
# Note: this will use legacy tx serialization as the need for "segwit"
|
# Note: this will use legacy tx serialization as the need for "segwit"
|
||||||
# would be detected from inputs. The only side effect should be that the
|
# would be detected from inputs. The only side effect should be that the
|
||||||
# marker and flag are excluded, which is compensated in get_tx_weight()
|
# marker and flag are excluded, which is compensated in get_tx_weight()
|
||||||
# FIXME calculation will be off by this (2 wu) in case of RBF batching
|
# FIXME calculation will be off by this (2 wu) in case of RBF batching
|
||||||
base_weight = tx.estimated_weight()
|
base_weight = base_tx.estimated_weight()
|
||||||
spent_amount = tx.output_value()
|
spent_amount = base_tx.output_value()
|
||||||
|
|
||||||
def fee_estimator_w(weight):
|
def fee_estimator_w(weight):
|
||||||
return fee_estimator(Transaction.virtual_size_from_weight(weight))
|
return fee_estimator(Transaction.virtual_size_from_weight(weight))
|
||||||
|
|
||||||
def get_tx_weight(buckets):
|
|
||||||
total_weight = base_weight + sum(bucket.weight for bucket in buckets)
|
|
||||||
is_segwit_tx = any(bucket.witness for bucket in buckets)
|
|
||||||
if is_segwit_tx:
|
|
||||||
total_weight += 2 # marker and flag
|
|
||||||
# non-segwit inputs were previously assumed to have
|
|
||||||
# a witness of '' instead of '00' (hex)
|
|
||||||
# note that mixed legacy/segwit buckets are already ok
|
|
||||||
num_legacy_inputs = sum((not bucket.witness) * len(bucket.coins)
|
|
||||||
for bucket in buckets)
|
|
||||||
total_weight += num_legacy_inputs
|
|
||||||
|
|
||||||
return total_weight
|
|
||||||
|
|
||||||
def sufficient_funds(buckets, *, bucket_value_sum):
|
def sufficient_funds(buckets, *, bucket_value_sum):
|
||||||
'''Given a list of buckets, return True if it has enough
|
'''Given a list of buckets, return True if it has enough
|
||||||
value to pay for the transaction'''
|
value to pay for the transaction'''
|
||||||
|
@ -248,45 +276,30 @@ class CoinChooserBase(Logger):
|
||||||
return False
|
return False
|
||||||
# note re performance: so far this was constant time
|
# note re performance: so far this was constant time
|
||||||
# what follows is linear in len(buckets)
|
# what follows is linear in len(buckets)
|
||||||
total_weight = get_tx_weight(buckets)
|
total_weight = self._get_tx_weight(buckets, base_weight=base_weight)
|
||||||
return total_input >= spent_amount + fee_estimator_w(total_weight)
|
return total_input >= spent_amount + fee_estimator_w(total_weight)
|
||||||
|
|
||||||
def fee_for_buckets(buckets) -> int:
|
def tx_from_buckets(buckets):
|
||||||
"""Given a list of buckets, return the total fee paid by the
|
return self._construct_tx_from_selected_buckets(buckets=buckets,
|
||||||
transaction, in satoshis.
|
base_tx=base_tx,
|
||||||
Note that the change output(s) are not yet known here,
|
change_addrs=change_addrs,
|
||||||
so fees for those are excluded and hence this is a lower bound.
|
fee_estimator_w=fee_estimator_w,
|
||||||
"""
|
dust_threshold=dust_threshold,
|
||||||
total_weight = get_tx_weight(buckets)
|
base_weight=base_weight)
|
||||||
return fee_estimator_w(total_weight)
|
|
||||||
|
|
||||||
# Collect the coins into buckets, choose a subset of the buckets
|
# Collect the coins into buckets, choose a subset of the buckets
|
||||||
buckets = self.bucketize_coins(coins)
|
all_buckets = self.bucketize_coins(coins)
|
||||||
buckets = self.choose_buckets(buckets, sufficient_funds,
|
scored_candidate = self.choose_buckets(all_buckets, sufficient_funds,
|
||||||
self.penalty_func(tx, fee_for_buckets=fee_for_buckets))
|
self.penalty_func(base_tx, tx_from_buckets=tx_from_buckets))
|
||||||
|
tx = scored_candidate.tx
|
||||||
tx.add_inputs([coin for b in buckets for coin in b.coins])
|
|
||||||
tx_weight = get_tx_weight(buckets)
|
|
||||||
|
|
||||||
# change is sent back to sending address unless specified
|
|
||||||
if not change_addrs:
|
|
||||||
change_addrs = [tx.inputs()[0]['address']]
|
|
||||||
# note: this is not necessarily the final "first input address"
|
|
||||||
# because the inputs had not been sorted at this point
|
|
||||||
assert is_address(change_addrs[0])
|
|
||||||
|
|
||||||
# This takes a count of change outputs and returns a tx fee
|
|
||||||
output_weight = 4 * Transaction.estimated_output_size(change_addrs[0])
|
|
||||||
fee = lambda count: fee_estimator_w(tx_weight + count * output_weight)
|
|
||||||
change = self.change_outputs(tx, change_addrs, fee, dust_threshold)
|
|
||||||
tx.add_outputs(change)
|
|
||||||
|
|
||||||
self.logger.info(f"using {len(tx.inputs())} inputs")
|
self.logger.info(f"using {len(tx.inputs())} inputs")
|
||||||
self.logger.info(f"using buckets: {[bucket.desc for bucket in buckets]}")
|
self.logger.info(f"using buckets: {[bucket.desc for bucket in scored_candidate.buckets]}")
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
|
|
||||||
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
|
def choose_buckets(self, buckets, sufficient_funds,
|
||||||
|
penalty_func: Callable[[List[Bucket]], ScoredCandidate]) -> ScoredCandidate:
|
||||||
raise NotImplemented('To be subclassed')
|
raise NotImplemented('To be subclassed')
|
||||||
|
|
||||||
|
|
||||||
|
@ -368,12 +381,14 @@ class CoinChooserRandom(CoinChooserBase):
|
||||||
|
|
||||||
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
|
def choose_buckets(self, buckets, sufficient_funds, penalty_func):
|
||||||
candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
|
candidates = self.bucket_candidates_prefer_confirmed(buckets, sufficient_funds)
|
||||||
penalties = [penalty_func(cand) for cand in candidates]
|
scored_candidates = [penalty_func(cand) for cand in candidates]
|
||||||
winner = candidates[penalties.index(min(penalties))]
|
winner = min(scored_candidates, key=lambda x: x.penalty)
|
||||||
self.logger.info(f"Bucket sets: {len(buckets)}")
|
self.logger.info(f"Total number of buckets: {len(buckets)}")
|
||||||
self.logger.info(f"Winning penalty: {min(penalties)}")
|
self.logger.info(f"Num candidates considered: {len(candidates)}. "
|
||||||
|
f"Winning penalty: {winner.penalty}")
|
||||||
return winner
|
return winner
|
||||||
|
|
||||||
|
|
||||||
class CoinChooserPrivacy(CoinChooserRandom):
|
class CoinChooserPrivacy(CoinChooserRandom):
|
||||||
"""Attempts to better preserve user privacy.
|
"""Attempts to better preserve user privacy.
|
||||||
First, if any coin is spent from a user address, all coins are.
|
First, if any coin is spent from a user address, all coins are.
|
||||||
|
@ -388,18 +403,15 @@ class CoinChooserPrivacy(CoinChooserRandom):
|
||||||
def keys(self, coins):
|
def keys(self, coins):
|
||||||
return [coin['address'] for coin in coins]
|
return [coin['address'] for coin in coins]
|
||||||
|
|
||||||
def penalty_func(self, tx, *, fee_for_buckets):
|
def penalty_func(self, base_tx, *, tx_from_buckets):
|
||||||
min_change = min(o.value for o in tx.outputs()) * 0.75
|
min_change = min(o.value for o in base_tx.outputs()) * 0.75
|
||||||
max_change = max(o.value for o in tx.outputs()) * 1.33
|
max_change = max(o.value for o in base_tx.outputs()) * 1.33
|
||||||
spent_amount = sum(o.value for o in tx.outputs())
|
|
||||||
|
|
||||||
def penalty(buckets):
|
def penalty(buckets) -> ScoredCandidate:
|
||||||
|
# Penalize using many buckets (~inputs)
|
||||||
badness = len(buckets) - 1
|
badness = len(buckets) - 1
|
||||||
total_input = sum(bucket.value for bucket in buckets)
|
tx, change_outputs = tx_from_buckets(buckets)
|
||||||
# FIXME fee_for_buckets does not include fees needed to cover the change output(s)
|
change = sum(o.value for o in change_outputs)
|
||||||
# so fee here is a lower bound
|
|
||||||
fee = fee_for_buckets(buckets)
|
|
||||||
change = float(total_input - spent_amount - fee)
|
|
||||||
# Penalize change not roughly in output range
|
# Penalize change not roughly in output range
|
||||||
if change < min_change:
|
if change < min_change:
|
||||||
badness += (min_change - change) / (min_change + 10000)
|
badness += (min_change - change) / (min_change + 10000)
|
||||||
|
@ -407,7 +419,7 @@ class CoinChooserPrivacy(CoinChooserRandom):
|
||||||
badness += (change - max_change) / (max_change + 10000)
|
badness += (change - max_change) / (max_change + 10000)
|
||||||
# Penalize large change; 5 BTC excess ~= using 1 more input
|
# Penalize large change; 5 BTC excess ~= using 1 more input
|
||||||
badness += change / (COIN * 5)
|
badness += change / (COIN * 5)
|
||||||
return badness
|
return ScoredCandidate(badness, tx, buckets)
|
||||||
|
|
||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue