made payload interface

This commit is contained in:
Brannon King 2021-10-01 23:08:59 -04:00
parent 55abc9d25c
commit 742881a29a
3 changed files with 103 additions and 44 deletions

View file

@ -1,16 +1,17 @@
package merkletrie package merkletrie
import (
"github.com/lbryio/lbcd/chaincfg/chainhash"
)
type KeyType []byte type KeyType []byte
type VertexPayload interface {
clear()
childModified()
isEmpty() bool
}
type collapsedVertex struct { type collapsedVertex struct {
children []*collapsedVertex children []*collapsedVertex
key KeyType key KeyType
merkleHash *chainhash.Hash payload VertexPayload // TODO: remove all stupid null checks on this once we have generics
claimHash *chainhash.Hash
} }
// insertAt inserts v into s at index i and returns the new slice. // insertAt inserts v into s at index i and returns the new slice.
@ -96,7 +97,9 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co
index, child := node.findNearest(value) index, child := node.findNearest(value)
match := 0 match := 0
if index >= 0 { // if we found a child if index >= 0 { // if we found a child
child.merkleHash = nil if child.payload != nil {
child.payload.childModified()
}
match = matchLength(value, child.key) match = matchLength(value, child.key)
if len(value) == match && len(child.key) == match { if len(value) == match && len(child.key) == match {
return false, child return false, child
@ -107,8 +110,7 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co
return true, node.Insert(&collapsedVertex{key: value}) return true, node.Insert(&collapsedVertex{key: value})
} }
if match < len(child.key) { if match < len(child.key) {
grandChild := collapsedVertex{key: child.key[match:], children: child.children, grandChild := collapsedVertex{key: child.key[match:], children: child.children, payload: child.payload}
claimHash: child.claimHash, merkleHash: child.merkleHash}
newChild := collapsedVertex{key: child.key[0:match], children: []*collapsedVertex{&grandChild}} newChild := collapsedVertex{key: child.key[0:match], children: []*collapsedVertex{&grandChild}}
child = &newChild child = &newChild
node.children[index] = child node.children[index] = child
@ -121,7 +123,9 @@ func (pt *collapsedTrie) insert(value KeyType, node *collapsedVertex) (bool, *co
} }
func (pt *collapsedTrie) InsertOrFind(value KeyType) (bool, *collapsedVertex) { func (pt *collapsedTrie) InsertOrFind(value KeyType) (bool, *collapsedVertex) {
pt.Root.merkleHash = nil if pt.Root.payload != nil {
pt.Root.payload.childModified()
}
if len(value) <= 0 { if len(value) <= 0 {
return false, pt.Root return false, pt.Root
} }
@ -200,27 +204,30 @@ func iterateFrom(name KeyType, node *collapsedVertex, handler func(name KeyType,
func (pt *collapsedTrie) Erase(value KeyType) bool { func (pt *collapsedTrie) Erase(value KeyType) bool {
indexes, path := pt.FindPath(value) indexes, path := pt.FindPath(value)
if path == nil || len(path) <= 1 { if path == nil || len(path) <= 1 {
if len(path) == 1 { if len(path) == 1 && path[0].payload != nil {
path[0].merkleHash = nil path[0].payload.clear()
path[0].claimHash = nil
} }
return false return false
} }
nodes := pt.Nodes nodes := pt.Nodes
i := len(path) - 1 i := len(path) - 1
path[i].claimHash = nil // this is the thing we are erasing; the rest is book-keeping if path[i].payload != nil {
path[i].payload.clear() // this is the thing we are erasing; the rest is book-keeping
}
for ; i > 0; i-- { for ; i > 0; i-- {
childCount := len(path[i].children) childCount := len(path[i].children)
noClaimData := path[i].claimHash == nil emptyPayload := path[i].payload == nil || path[i].payload.isEmpty()
path[i].merkleHash = nil if path[i].payload != nil {
if childCount == 1 && noClaimData { path[i].payload.childModified()
}
if childCount == 1 && emptyPayload {
path[i].key = append(path[i].key, path[i].children[0].key...) path[i].key = append(path[i].key, path[i].children[0].key...)
path[i].claimHash = path[i].children[0].claimHash path[i].payload = path[i].children[0].payload
path[i].children = path[i].children[0].children path[i].children = path[i].children[0].children
pt.Nodes-- pt.Nodes--
continue continue
} }
if childCount == 0 && noClaimData { if childCount == 0 && emptyPayload {
index := indexes[i] index := indexes[i]
path[i-1].children = append(path[i-1].children[:index], path[i-1].children[index+1:]...) path[i-1].children = append(path[i-1].children[:index], path[i-1].children[index+1:]...)
pt.Nodes-- pt.Nodes--
@ -229,7 +236,9 @@ func (pt *collapsedTrie) Erase(value KeyType) bool {
break break
} }
for ; i >= 0; i-- { for ; i >= 0; i-- {
path[i].merkleHash = nil if path[i].payload != nil {
path[i].payload.childModified()
}
} }
return nodes > pt.Nodes return nodes > pt.Nodes
} }

View file

@ -46,24 +46,43 @@ func TestInsertAndErase(t *testing.T) {
assert.Equal(t, 1, trie.NodeCount()) assert.Equal(t, 1, trie.NodeCount())
} }
type testPayload struct {
modifies int
full bool
}
func (t *testPayload) clear() {
t.full = false
}
func (t *testPayload) childModified() {
t.modifies++
}
func (t *testPayload) isEmpty() bool {
return !t.full
}
var _ VertexPayload = &testPayload{}
func TestNilNameHandling(t *testing.T) { func TestNilNameHandling(t *testing.T) {
trie := NewCollapsedTrie() trie := NewCollapsedTrie()
inserted, n := trie.InsertOrFind([]byte("test")) inserted, n := trie.InsertOrFind([]byte("test"))
assert.True(t, inserted) assert.True(t, inserted)
n.claimHash = EmptyTrieHash p := testPayload{}
inserted, n = trie.InsertOrFind(nil) inserted, n = trie.InsertOrFind(nil)
assert.False(t, inserted) assert.False(t, inserted)
n.claimHash = EmptyTrieHash n.payload = &p
n.merkleHash = EmptyTrieHash p.full = true
inserted, n = trie.InsertOrFind(nil) inserted, n = trie.InsertOrFind(nil)
assert.False(t, inserted) assert.False(t, inserted)
assert.NotNil(t, n.claimHash) assert.NotNil(t, n.payload)
assert.Nil(t, n.merkleHash) assert.True(t, p.modifies > 0)
nodeRemoved := trie.Erase(nil) nodeRemoved := trie.Erase(nil)
assert.False(t, nodeRemoved) assert.False(t, nodeRemoved)
inserted, n = trie.InsertOrFind(nil) inserted, n = trie.InsertOrFind(nil)
assert.False(t, inserted) assert.False(t, inserted)
assert.Nil(t, n.claimHash) assert.True(t, n.payload.isEmpty())
} }
func TestCollapsedTriePerformance(t *testing.T) { func TestCollapsedTriePerformance(t *testing.T) {

View file

@ -34,10 +34,39 @@ func NewRamTrie() *RamTrie {
} }
} }
type ramTriePayload struct {
merkleHash *chainhash.Hash
claimHash *chainhash.Hash
}
func (r *ramTriePayload) clear() {
r.claimHash = nil
r.merkleHash = nil
}
func (r *ramTriePayload) childModified() {
r.merkleHash = nil
}
func (r *ramTriePayload) isEmpty() bool {
return r.claimHash == nil
}
func getOrMakePayload(v *collapsedVertex) *ramTriePayload {
if v.payload == nil {
r := &ramTriePayload{}
v.payload = r
return r
}
return v.payload.(*ramTriePayload)
}
var _ VertexPayload = &ramTriePayload{}
var ErrFullRebuildRequired = errors.New("a full rebuild is required") var ErrFullRebuildRequired = errors.New("a full rebuild is required")
func (rt *RamTrie) SetRoot(h *chainhash.Hash) error { func (rt *RamTrie) SetRoot(h *chainhash.Hash) error {
if rt.Root.merkleHash.IsEqual(h) { if getOrMakePayload(rt.Root).merkleHash.IsEqual(h) {
runtime.GC() runtime.GC()
return nil return nil
} }
@ -51,7 +80,7 @@ func (rt *RamTrie) Update(name []byte, h *chainhash.Hash, _ bool) {
rt.Erase(name) rt.Erase(name)
} else { } else {
_, n := rt.InsertOrFind(name) _, n := rt.InsertOrFind(name)
n.claimHash = h getOrMakePayload(n).claimHash = h
} }
} }
@ -59,12 +88,13 @@ func (rt *RamTrie) MerkleHash() *chainhash.Hash {
if h := rt.merkleHash(rt.Root); h == nil { if h := rt.merkleHash(rt.Root); h == nil {
return EmptyTrieHash return EmptyTrieHash
} }
return rt.Root.merkleHash return getOrMakePayload(rt.Root).merkleHash
} }
func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash { func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash {
if v.merkleHash != nil { p := getOrMakePayload(v)
return v.merkleHash if p.merkleHash != nil {
return p.merkleHash
} }
b := rt.bufs.Get().(*bytes.Buffer) b := rt.bufs.Get().(*bytes.Buffer)
@ -77,16 +107,16 @@ func (rt *RamTrie) merkleHash(v *collapsedVertex) *chainhash.Hash {
b.Write(rt.completeHash(h, ch.key)) // nolint : errchk b.Write(rt.completeHash(h, ch.key)) // nolint : errchk
} }
if v.claimHash != nil { if p.claimHash != nil {
b.Write(v.claimHash[:]) b.Write(p.claimHash[:])
} }
if b.Len() > 0 { if b.Len() > 0 {
h := chainhash.DoubleHashH(b.Bytes()) h := chainhash.DoubleHashH(b.Bytes())
v.merkleHash = &h p.merkleHash = &h
} }
return v.merkleHash return p.merkleHash
} }
func (rt *RamTrie) completeHash(h *chainhash.Hash, childKey KeyType) []byte { func (rt *RamTrie) completeHash(h *chainhash.Hash, childKey KeyType) []byte {
@ -103,12 +133,13 @@ func (rt *RamTrie) MerkleHashAllClaims() *chainhash.Hash {
if h := rt.merkleHashAllClaims(rt.Root); h == nil { if h := rt.merkleHashAllClaims(rt.Root); h == nil {
return EmptyTrieHash return EmptyTrieHash
} }
return rt.Root.merkleHash return getOrMakePayload(rt.Root).merkleHash
} }
func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash { func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash {
if v.merkleHash != nil { p := getOrMakePayload(v)
return v.merkleHash if p.merkleHash != nil {
return p.merkleHash
} }
childHashes := make([]*chainhash.Hash, 0, len(v.children)) childHashes := make([]*chainhash.Hash, 0, len(v.children))
@ -118,8 +149,8 @@ func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash {
} }
claimHash := NoClaimsHash claimHash := NoClaimsHash
if v.claimHash != nil { if p.claimHash != nil {
claimHash = v.claimHash claimHash = p.claimHash
} else if len(childHashes) == 0 { } else if len(childHashes) == 0 {
return nil return nil
} }
@ -130,8 +161,8 @@ func (rt *RamTrie) merkleHashAllClaims(v *collapsedVertex) *chainhash.Hash {
childHash = node.ComputeMerkleRoot(childHashes) childHash = node.ComputeMerkleRoot(childHashes)
} }
v.merkleHash = node.HashMerkleBranches(childHash, claimHash) p.merkleHash = node.HashMerkleBranches(childHash, claimHash)
return v.merkleHash return p.merkleHash
} }
func (rt *RamTrie) Flush() error { func (rt *RamTrie) Flush() error {