diff --git a/wallet/psbt.go b/wallet/psbt.go index 1b5b211..e2d4fe8 100644 --- a/wallet/psbt.go +++ b/wallet/psbt.go @@ -19,7 +19,9 @@ import ( // FundPsbt creates a fully populated PSBT packet that contains enough inputs to // fund the outputs specified in the passed in packet with the specified fee -// rate. If there is change left, a change output from the wallet is added. +// rate. If there is change left, a change output from the wallet is added and +// the index of the change output is returned. Otherwise no additional output +// is created and the index -1 is returned. // // NOTE: If the packet doesn't contain any inputs, coin selection is performed // automatically. If the packet does contain any inputs, it is assumed that full @@ -32,13 +34,13 @@ import ( // selected/validated inputs by this method. It is in the caller's // responsibility to lock the inputs before handing the partial transaction out. func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, - feeSatPerKB btcutil.Amount) error { + feeSatPerKB btcutil.Amount) (int32, error) { // Make sure the packet is well formed. We only require there to be at // least one output but not necessarily any inputs. err := psbt.VerifyInputOutputLen(packet, false, true) if err != nil { - return err + return 0, err } txOut := packet.UnsignedTx.TxOut @@ -53,7 +55,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // dust. err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb) if err != nil { - return err + return 0, err } } @@ -108,7 +110,8 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, false, ) if err != nil { - return fmt.Errorf("error creating funding TX: %v", err) + return 0, fmt.Errorf("error creating funding TX: %v", + err) } // Copy over the inputs now then collect all UTXO information @@ -118,7 +121,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, packet.UnsignedTx.TxIn = tx.Tx.TxIn err = addInputInfo(tx.Tx.TxIn) if err != nil { - return err + return 0, err } // If there are inputs, we need to check if they're sufficient and add @@ -127,7 +130,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // Make sure all inputs provided are actually ours. err = addInputInfo(txIn) if err != nil { - return err + return 0, err } // We can leverage the fee calculation of the txauthor package @@ -147,7 +150,7 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // a new change addresse into the database. dbtx, err := w.db.BeginReadWriteTx() if err != nil { - return err + return 0, err } _, changeSource := w.addrMgrWithChangeSource(dbtx, account) @@ -159,24 +162,25 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, ) if err != nil { _ = dbtx.Rollback() - return fmt.Errorf("fee estimation not successful: %v", - err) + return 0, fmt.Errorf("fee estimation not successful: "+ + "%v", err) } // The transaction could be created, let's commit the DB TX to // store the change address (if one was created). err = dbtx.Commit() if err != nil { - return fmt.Errorf("could not add change address to "+ + return 0, fmt.Errorf("could not add change address to "+ "database: %v", err) } } // If there is a change output, we need to copy it over to the PSBT now. + var changeTxOut *wire.TxOut if tx.ChangeIndex >= 0 { + changeTxOut = tx.Tx.TxOut[tx.ChangeIndex] packet.UnsignedTx.TxOut = append( - packet.UnsignedTx.TxOut, - tx.Tx.TxOut[tx.ChangeIndex], + packet.UnsignedTx.TxOut, changeTxOut, ) packet.Outputs = append(packet.Outputs, psbt.POutput{}) } @@ -186,10 +190,22 @@ func (w *Wallet) FundPsbt(packet *psbt.Packet, account uint32, // partial inputs and outputs accordingly. err = psbt.InPlaceSort(packet) if err != nil { - return fmt.Errorf("could not sort PSBT: %v", err) + return 0, fmt.Errorf("could not sort PSBT: %v", err) } - return nil + // The change output index might have changed after the sorting. We need + // to find our index again. + changeIndex := int32(-1) + if changeTxOut != nil { + for idx, txOut := range packet.UnsignedTx.TxOut { + if psbt.TxOutsEqual(changeTxOut, txOut) { + changeIndex = int32(idx) + break + } + } + } + + return changeIndex, nil } // FinalizePsbt expects a partial transaction with all inputs and outputs fully diff --git a/wallet/psbt_test.go b/wallet/psbt_test.go index 5d77808..dfb67cd 100644 --- a/wallet/psbt_test.go +++ b/wallet/psbt_test.go @@ -142,7 +142,9 @@ func TestFundPsbt(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - err := w.FundPsbt(tc.packet, 0, tc.feeRateSatPerKB) + changeIndex, err := w.FundPsbt( + tc.packet, 0, tc.feeRateSatPerKB, + ) // Make sure the error is what we expected. if err == nil && tc.expectedErr != "" { @@ -258,9 +260,10 @@ func TestFundPsbt(t *testing.T) { } p2wkhIndex := -1 p2wshIndex := -1 - changeIndex := -1 + totalOut := int64(0) for idx, txOut := range txOuts { script := txOut.PkScript + totalOut += txOut.Value switch { case bytes.Equal(script, testScriptP2WKH): @@ -269,10 +272,12 @@ func TestFundPsbt(t *testing.T) { case bytes.Equal(script, testScriptP2WSH): p2wshIndex = idx - default: - changeIndex = idx } } + totalIn := int64(0) + for _, txIn := range packet.Inputs { + totalIn += txIn.WitnessUtxo.Value + } // All outputs must be found. if p2wkhIndex < 0 || p2wshIndex < 0 || changeIndex < 0 {