From 0b85b11dcc3bc6fac444529545cece109d3c70e5 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Fri, 19 Jun 2020 14:44:50 +0200 Subject: [PATCH] psbt: add BIP 69 in-place sort --- psbt/sort.go | 102 ++++++++++++++++++++++++++++ psbt/sort_test.go | 167 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 psbt/sort.go create mode 100644 psbt/sort_test.go diff --git a/psbt/sort.go b/psbt/sort.go new file mode 100644 index 0000000..2232d68 --- /dev/null +++ b/psbt/sort.go @@ -0,0 +1,102 @@ +package psbt + +import ( + "bytes" + "sort" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// InPlaceSort modifies the passed packet's wire TX inputs and outputs to be +// sorted based on BIP 69. The sorting happens in a way that the packet's +// partial inputs and outputs are also modified to match the sorted TxIn and +// TxOuts of the wire transaction. +// +// WARNING: This function must NOT be called with packages that already contain +// (partial) witness data since it will mutate the transaction if it's not +// already sorted. This can cause issues if you mutate a tx in a block, for +// example, which would invalidate the block. It could also cause cached hashes, +// such as in a btcutil.Tx to become invalidated. +// +// The function should only be used if the caller is creating the transaction or +// is otherwise 100% positive mutating will not cause adverse affects due to +// other dependencies. +func InPlaceSort(packet *Packet) error { + // To make sure we don't run into any nil pointers or array index + // violations during sorting, do a very basic sanity check first. + err := VerifyInputOutputLen(packet, false, false) + if err != nil { + return err + } + + sort.Sort(&sortableInputs{p: packet}) + sort.Sort(&sortableOutputs{p: packet}) + + return nil +} + +// sortableInputs is a simple wrapper around a packet that implements the +// sort.Interface for sorting the wire and partial inputs of a packet. +type sortableInputs struct { + p *Packet +} + +// sortableOutputs is a simple wrapper around a packet that implements the +// sort.Interface for sorting the wire and partial outputs of a packet. +type sortableOutputs struct { + p *Packet +} + +// For sortableInputs and sortableOutputs, three functions are needed to make +// them sortable with sort.Sort() -- Len, Less, and Swap. +// Len and Swap are trivial. Less is BIP 69 specific. +func (s *sortableInputs) Len() int { return len(s.p.UnsignedTx.TxIn) } +func (s sortableOutputs) Len() int { return len(s.p.UnsignedTx.TxOut) } + +// Swap swaps two inputs. +func (s *sortableInputs) Swap(i, j int) { + tx := s.p.UnsignedTx + tx.TxIn[i], tx.TxIn[j] = tx.TxIn[j], tx.TxIn[i] + s.p.Inputs[i], s.p.Inputs[j] = s.p.Inputs[j], s.p.Inputs[i] +} + +// Swap swaps two outputs. +func (s *sortableOutputs) Swap(i, j int) { + tx := s.p.UnsignedTx + tx.TxOut[i], tx.TxOut[j] = tx.TxOut[j], tx.TxOut[i] + s.p.Outputs[i], s.p.Outputs[j] = s.p.Outputs[j], s.p.Outputs[i] +} + +// Less is the input comparison function. First sort based on input hash +// (reversed / rpc-style), then index. +func (s *sortableInputs) Less(i, j int) bool { + ins := s.p.UnsignedTx.TxIn + + // Input hashes are the same, so compare the index. + ihash := ins[i].PreviousOutPoint.Hash + jhash := ins[j].PreviousOutPoint.Hash + if ihash == jhash { + return ins[i].PreviousOutPoint.Index < + ins[j].PreviousOutPoint.Index + } + + // At this point, the hashes are not equal, so reverse them to + // big-endian and return the result of the comparison. + const hashSize = chainhash.HashSize + for b := 0; b < hashSize/2; b++ { + ihash[b], ihash[hashSize-1-b] = ihash[hashSize-1-b], ihash[b] + jhash[b], jhash[hashSize-1-b] = jhash[hashSize-1-b], jhash[b] + } + return bytes.Compare(ihash[:], jhash[:]) == -1 +} + +// Less is the output comparison function. First sort based on amount (smallest +// first), then PkScript. +func (s *sortableOutputs) Less(i, j int) bool { + outs := s.p.UnsignedTx.TxOut + + if outs[i].Value == outs[j].Value { + return bytes.Compare(outs[i].PkScript, outs[j].PkScript) < 0 + } + return outs[i].Value < outs[j].Value +} diff --git a/psbt/sort_test.go b/psbt/sort_test.go new file mode 100644 index 0000000..3dee0f4 --- /dev/null +++ b/psbt/sort_test.go @@ -0,0 +1,167 @@ +package psbt + +import ( + "reflect" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +func TestInPlaceSort(t *testing.T) { + testCases := []struct { + name string + packet *Packet + expectedTxIn []*wire.TxIn + expectedTxOut []*wire.TxOut + expectedPIn []PInput + expectedPOut []POutput + expectErr bool + }{{ + name: "packet nil", + packet: nil, + expectErr: true, + }, { + name: "no inputs or outputs", + packet: &Packet{UnsignedTx: &wire.MsgTx{}}, + expectErr: false, + }, { + name: "inputs only", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{ + TxIn: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{99, 88}, + Index: 7, + }, + }, { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{77, 88}, + Index: 12, + }, + }, { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{77, 88}, + Index: 7, + }, + }}, + }, + // Abuse the SighashType as an index to make sure the + // partial inputs are also sorted together with the wire + // inputs. + Inputs: []PInput{{ + SighashType: 0, + }, { + SighashType: 1, + }, { + SighashType: 2, + }}, + }, + expectedTxIn: []*wire.TxIn{{ + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{77, 88}, + Index: 7, + }, + }, { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{77, 88}, + Index: 12, + }, + }, { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{99, 88}, + Index: 7, + }, + }}, + expectedPIn: []PInput{{ + SighashType: 2, + }, { + SighashType: 1, + }, { + SighashType: 0, + }}, + expectErr: false, + }, { + name: "outputs only", + packet: &Packet{ + UnsignedTx: &wire.MsgTx{ + TxOut: []*wire.TxOut{{ + PkScript: []byte{99, 88}, + Value: 7, + }, { + PkScript: []byte{77, 88}, + Value: 12, + }, { + PkScript: []byte{77, 88}, + Value: 7, + }}, + }, + // Abuse the RedeemScript as an index to make sure the + // partial inputs are also sorted together with the wire + // inputs. + Outputs: []POutput{{ + RedeemScript: []byte{0}, + }, { + RedeemScript: []byte{1}, + }, { + RedeemScript: []byte{2}, + }}, + }, + expectedTxOut: []*wire.TxOut{{ + PkScript: []byte{77, 88}, + Value: 7, + }, { + PkScript: []byte{99, 88}, + Value: 7, + }, { + PkScript: []byte{77, 88}, + Value: 12, + }}, + expectedPOut: []POutput{{ + RedeemScript: []byte{2}, + }, { + RedeemScript: []byte{0}, + }, { + RedeemScript: []byte{1}, + }}, + expectErr: false, + }} + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + p := tc.packet + err := InPlaceSort(p) + if (tc.expectErr && err == nil) || + (!tc.expectErr && err != nil) { + + t.Fatalf("got error '%v' but wanted it to be "+ + "nil: %v", err, tc.expectErr) + } + + // Don't continue on this special test case. + if p == nil { + return + } + + tx := p.UnsignedTx + if !reflect.DeepEqual(tx.TxIn, tc.expectedTxIn) { + t.Fatalf("unexpected txin, got %#v wanted %#v", + tx.TxIn, tc.expectedTxIn) + } + if !reflect.DeepEqual(tx.TxOut, tc.expectedTxOut) { + t.Fatalf("unexpected txout, got %#v wanted %#v", + tx.TxOut, tc.expectedTxOut) + } + + if !reflect.DeepEqual(p.Inputs, tc.expectedPIn) { + t.Fatalf("unexpected pin, got %#v wanted %#v", + p.Inputs, tc.expectedPIn) + } + if !reflect.DeepEqual(p.Outputs, tc.expectedPOut) { + t.Fatalf("unexpected pout, got %#v wanted %#v", + p.Inputs, tc.expectedPOut) + } + }) + } +}