diff --git a/claimtrie/merkletrie/collapsedtrie.go b/claimtrie/merkletrie/collapsedtrie.go index 18af30a0..1c6f2c5f 100644 --- a/claimtrie/merkletrie/collapsedtrie.go +++ b/claimtrie/merkletrie/collapsedtrie.go @@ -1,16 +1,17 @@ package merkletrie -import ( - "github.com/lbryio/lbcd/chaincfg/chainhash" -) - type KeyType []byte +type VertexPayload interface { + clear() + childModified() + isEmpty() bool +} + type collapsedVertex struct { - children []*collapsedVertex - key KeyType - merkleHash *chainhash.Hash - claimHash *chainhash.Hash + children []*collapsedVertex + key KeyType + payload VertexPayload // TODO: remove all stupid null checks on this once we have generics } // insertAt inserts v into s at index i and returns the new slice. @@ -96,7 +97,9 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co index, child := node.findNearest(value) match := 0 if index >= 0 { // if we found a child - child.merkleHash = nil + if child.payload != nil { + child.payload.childModified() + } match = matchLength(value, child.key) if len(value) == match && len(child.key) == match { return false, child @@ -107,8 +110,7 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co return true, node.Insert(&collapsedVertex{key: value}) } if match < len(child.key) { - grandChild := collapsedVertex{key: child.key[match:], children: child.children, - claimHash: child.claimHash, merkleHash: child.merkleHash} + grandChild := collapsedVertex{key: child.key[match:], children: child.children, payload: child.payload} newChild := collapsedVertex{key: child.key[0:match], children: []*collapsedVertex{&grandChild}} child = &newChild node.children[index] = child @@ -121,7 +123,9 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co } func (pt *collapsedTrie) InsertOrFind(value KeyType) (bool, *collapsedVertex) { - pt.Root.merkleHash = nil + if pt.Root.payload != nil { + pt.Root.payload.childModified() + } if len(value) <= 0 { return false, pt.Root } @@ -200,27 +204,30 @@ func iterateFrom(name KeyType, node *collapsedVertex, handler func(name KeyType, func (pt *collapsedTrie) Erase(value KeyType) bool { indexes, path := pt.FindPath(value) if path == nil || len(path) <= 1 { - if len(path) == 1 { - path[0].merkleHash = nil - path[0].claimHash = nil + if len(path) == 1 && path[0].payload != nil { + path[0].payload.clear() } return false } nodes := pt.Nodes i := len(path) - 1 - path[i].claimHash = nil // this is the thing we are erasing; the rest is book-keeping + if path[i].payload != nil { + path[i].payload.clear() // this is the thing we are erasing; the rest is book-keeping + } for ; i > 0; i-- { childCount := len(path[i].children) - noClaimData := path[i].claimHash == nil - path[i].merkleHash = nil - if childCount == 1 && noClaimData { + emptyPayload := path[i].payload == nil || path[i].payload.isEmpty() + if path[i].payload != nil { + path[i].payload.childModified() + } + if childCount == 1 && emptyPayload { path[i].key = append(path[i].key, path[i].children[0].key...) - path[i].claimHash = path[i].children[0].claimHash + path[i].payload = path[i].children[0].payload path[i].children = path[i].children[0].children pt.Nodes-- continue } - if childCount == 0 && noClaimData { + if childCount == 0 && emptyPayload { index := indexes[i] path[i-1].children = append(path[i-1].children[:index], path[i-1].children[index+1:]...) pt.Nodes-- @@ -229,7 +236,9 @@ func (pt *collapsedTrie) Erase(value KeyType) bool { break } for ; i >= 0; i-- { - path[i].merkleHash = nil + if path[i].payload != nil { + path[i].payload.childModified() + } } return nodes > pt.Nodes } diff --git a/claimtrie/merkletrie/collapsedtrie_test.go b/claimtrie/merkletrie/collapsedtrie_test.go index ce41c35f..e2f0959a 100644 --- a/claimtrie/merkletrie/collapsedtrie_test.go +++ b/claimtrie/merkletrie/collapsedtrie_test.go @@ -46,24 +46,43 @@ func TestInsertAndErase(t *testing.T) { assert.Equal(t, 1, trie.NodeCount()) } +type testPayload struct { + modifies int + full bool +} + +func (t *testPayload) clear() { + t.full = false +} + +func (t *testPayload) childModified() { + t.modifies++ +} + +func (t *testPayload) isEmpty() bool { + return !t.full +} + +var _ VertexPayload = &testPayload{} + func TestNilNameHandling(t *testing.T) { trie := NewCollapsedTrie() inserted, n := trie.InsertOrFind([]byte("test")) assert.True(t, inserted) - n.claimHash = EmptyTrieHash + p := testPayload{} inserted, n = trie.InsertOrFind(nil) assert.False(t, inserted) - n.claimHash = EmptyTrieHash - n.merkleHash = EmptyTrieHash + n.payload = &p + p.full = true inserted, n = trie.InsertOrFind(nil) assert.False(t, inserted) - assert.NotNil(t, n.claimHash) - assert.Nil(t, n.merkleHash) + assert.NotNil(t, n.payload) + assert.True(t, p.modifies > 0) nodeRemoved := trie.Erase(nil) assert.False(t, nodeRemoved) inserted, n = trie.InsertOrFind(nil) assert.False(t, inserted) - assert.Nil(t, n.claimHash) + assert.True(t, n.payload.isEmpty()) } func TestCollapsedTriePerformance(t *testing.T) { diff --git a/claimtrie/merkletrie/ramtrie.go b/claimtrie/merkletrie/ramtrie.go index 2ae6bcc1..646a6b32 100644 --- a/claimtrie/merkletrie/ramtrie.go +++ b/claimtrie/merkletrie/ramtrie.go @@ -34,10 +34,39 @@ func NewRamTrie() *RamTrie { } } +type ramTriePayload struct { + merkleHash *chainhash.Hash + claimHash *chainhash.Hash +} + +func (r *ramTriePayload) clear() { + r.claimHash = nil + r.merkleHash = nil +} + +func (r *ramTriePayload) childModified() { + r.merkleHash = nil +} + +func (r *ramTriePayload) isEmpty() bool { + return r.claimHash == nil +} + +func getOrMakePayload(v *collapsedVertex) *ramTriePayload { + if v.payload == nil { + r := &ramTriePayload{} + v.payload = r + return r + } + return v.payload.(*ramTriePayload) +} + +var _ VertexPayload = &ramTriePayload{} + var ErrFullRebuildRequired = errors.New("a full rebuild is required") func (rt *RamTrie) SetRoot(h *chainhash.Hash) error { - if rt.Root.merkleHash.IsEqual(h) { + if getOrMakePayload(rt.Root).merkleHash.IsEqual(h) { runtime.GC() return nil } @@ -51,7 +80,7 @@ func (rt *RamTrie) Update(name []byte, h *chainhash.Hash, _ bool) { rt.Erase(name) } else { _, n := rt.InsertOrFind(name) - n.claimHash = h + getOrMakePayload(n).claimHash = h } } @@ -59,12 +88,13 @@ func (rt *RamTrie) MerkleHash() *chainhash.Hash { if h := rt.merkleHash(rt.Root); h == nil { return EmptyTrieHash } - return rt.Root.merkleHash + return getOrMakePayload(rt.Root).merkleHash } func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash { - if v.merkleHash != nil { - return v.merkleHash + p := getOrMakePayload(v) + if p.merkleHash != nil { + return p.merkleHash } b := rt.bufs.Get().(*bytes.Buffer) @@ -77,16 +107,16 @@ func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash { b.Write(rt.completeHash(h, ch.key)) // nolint : errchk } - if v.claimHash != nil { - b.Write(v.claimHash[:]) + if p.claimHash != nil { + b.Write(p.claimHash[:]) } if b.Len() > 0 { h := chainhash.DoubleHashH(b.Bytes()) - v.merkleHash = &h + p.merkleHash = &h } - return v.merkleHash + return p.merkleHash } func (rt *RamTrie) completeHash(h *chainhash.Hash, childKey KeyType) []byte { @@ -103,12 +133,13 @@ func (rt *RamTrie) MerkleHashAllClaims() *chainhash.Hash { if h := rt.merkleHashAllClaims(rt.Root); h == nil { return EmptyTrieHash } - return rt.Root.merkleHash + return getOrMakePayload(rt.Root).merkleHash } func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash { - if v.merkleHash != nil { - return v.merkleHash + p := getOrMakePayload(v) + if p.merkleHash != nil { + return p.merkleHash } childHashes := make([]*chainhash.Hash, 0, len(v.children)) @@ -118,8 +149,8 @@ func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash { } claimHash := NoClaimsHash - if v.claimHash != nil { - claimHash = v.claimHash + if p.claimHash != nil { + claimHash = p.claimHash } else if len(childHashes) == 0 { return nil } @@ -130,8 +161,8 @@ func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash { childHash = node.ComputeMerkleRoot(childHashes) } - v.merkleHash = node.HashMerkleBranches(childHash, claimHash) - return v.merkleHash + p.merkleHash = node.HashMerkleBranches(childHash, claimHash) + return p.merkleHash } func (rt *RamTrie) Flush() error {