wtxmgr: refactor dependencySort to use wire.MsgTx

This commit is contained in:
Wilmer Paulino 2019-01-31 16:41:19 -08:00
parent ba03278a64
commit fe56fdb828
No known key found for this signature in database
GPG key ID: 6DF57B9F9514972F
2 changed files with 38 additions and 34 deletions

View file

@ -4,29 +4,33 @@
package wtxmgr package wtxmgr
import "github.com/btcsuite/btcd/chaincfg/chainhash" import (
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
)
type graphNode struct { type graphNode struct {
value *TxRecord value *wire.MsgTx
outEdges []*chainhash.Hash outEdges []*chainhash.Hash
inDegree int inDegree int
} }
type hashGraph map[chainhash.Hash]graphNode type hashGraph map[chainhash.Hash]graphNode
func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph { func makeGraph(set map[chainhash.Hash]*wire.MsgTx) hashGraph {
graph := make(hashGraph) graph := make(hashGraph)
for _, rec := range set { for _, tx := range set {
// Add a node for every transaction record. The output edges // Add a node for every transaction. The output edges and input
// and input degree are set by iterating over each record's // degree are set by iterating over each transaction's inputs
// inputs below. // below.
if _, ok := graph[rec.Hash]; !ok { txHash := tx.TxHash()
graph[rec.Hash] = graphNode{value: rec} if _, ok := graph[txHash]; !ok {
graph[txHash] = graphNode{value: tx}
} }
inputLoop: inputLoop:
for _, input := range rec.MsgTx.TxIn { for _, input := range tx.TxIn {
// Transaction inputs that reference transactions not // Transaction inputs that reference transactions not
// included in the set do not create any (local) graph // included in the set do not create any (local) graph
// edges. // edges.
@ -44,20 +48,20 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph {
} }
// Mark a directed edge from the previous transaction // Mark a directed edge from the previous transaction
// hash to this transaction record and increase the // hash to this transaction and increase the input
// input degree for this record's node. // degree for this transaction's node.
inputRec := inputNode.value inputTx := inputNode.value
if inputRec == nil { if inputTx == nil {
inputRec = set[input.PreviousOutPoint.Hash] inputTx = set[input.PreviousOutPoint.Hash]
} }
graph[input.PreviousOutPoint.Hash] = graphNode{ graph[input.PreviousOutPoint.Hash] = graphNode{
value: inputRec, value: inputTx,
outEdges: append(inputNode.outEdges, &rec.Hash), outEdges: append(inputNode.outEdges, &txHash),
inDegree: inputNode.inDegree, inDegree: inputNode.inDegree,
} }
node := graph[rec.Hash] node := graph[txHash]
graph[rec.Hash] = graphNode{ graph[txHash] = graphNode{
value: rec, value: tx,
outEdges: node.outEdges, outEdges: node.outEdges,
inDegree: node.inDegree + 1, inDegree: node.inDegree + 1,
} }
@ -69,8 +73,8 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph {
// graphRoots returns the roots of the graph. That is, it returns the node's // graphRoots returns the roots of the graph. That is, it returns the node's
// values for all nodes which contain an input degree of 0. // values for all nodes which contain an input degree of 0.
func graphRoots(graph hashGraph) []*TxRecord { func graphRoots(graph hashGraph) []*wire.MsgTx {
roots := make([]*TxRecord, 0, len(graph)) roots := make([]*wire.MsgTx, 0, len(graph))
for _, node := range graph { for _, node := range graph {
if node.inDegree == 0 { if node.inDegree == 0 {
roots = append(roots, node.value) roots = append(roots, node.value)
@ -79,9 +83,9 @@ func graphRoots(graph hashGraph) []*TxRecord {
return roots return roots
} }
// dependencySort topologically sorts a set of transaction records by their // dependencySort topologically sorts a set of transactions by their dependency
// dependency order. It is implemented using Kahn's algorithm. // order. It is implemented using Kahn's algorithm.
func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord { func dependencySort(txs map[chainhash.Hash]*wire.MsgTx) []*wire.MsgTx {
graph := makeGraph(txs) graph := makeGraph(txs)
s := graphRoots(graph) s := graphRoots(graph)
@ -91,13 +95,13 @@ func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord {
return s return s
} }
sorted := make([]*TxRecord, 0, len(txs)) sorted := make([]*wire.MsgTx, 0, len(txs))
for len(s) != 0 { for len(s) != 0 {
rec := s[0] tx := s[0]
s = s[1:] s = s[1:]
sorted = append(sorted, rec) sorted = append(sorted, tx)
n := graph[rec.Hash] n := graph[tx.TxHash()]
for _, mHash := range n.outEdges { for _, mHash := range n.outEdges {
m := graph[*mHash] m := graph[*mHash]
if m.inDegree != 0 { if m.inDegree != 0 {

View file

@ -164,12 +164,12 @@ func (s *Store) UnminedTxs(ns walletdb.ReadBucket) ([]*wire.MsgTx, error) {
return nil, err return nil, err
} }
recs := dependencySort(recSet) txSet := make(map[chainhash.Hash]*wire.MsgTx, len(recSet))
txs := make([]*wire.MsgTx, 0, len(recs)) for txHash, txRec := range recSet {
for _, rec := range recs { txSet[txHash] = &txRec.MsgTx
txs = append(txs, &rec.MsgTx)
} }
return txs, nil
return dependencySort(txSet), nil
} }
func (s *Store) unminedTxRecords(ns walletdb.ReadBucket) (map[chainhash.Hash]*TxRecord, error) { func (s *Store) unminedTxRecords(ns walletdb.ReadBucket) (map[chainhash.Hash]*TxRecord, error) {