diff --git a/spvsvc/spvchain/blockmanager.go b/spvsvc/spvchain/blockmanager.go index ee12f57..288aa54 100644 --- a/spvsvc/spvchain/blockmanager.go +++ b/spvsvc/spvchain/blockmanager.go @@ -35,6 +35,13 @@ const ( maxTimeOffset = 2 * time.Hour ) +var ( + // WaitForMoreCFHeaders is a configurable time to wait for CFHeaders + // messages from peers. It defaults to 3 seconds but can be increased + // for higher security and decreased for faster synchronization. + WaitForMoreCFHeaders = 3 * time.Second +) + // zeroHash is the zero value hash (all zeros). It is defined as a convenience. var zeroHash chainhash.Hash @@ -947,20 +954,26 @@ func (b *blockManager) handleHeadersMsg(hmsg *headersMsg) { b.syncPeer = hmsg.peer b.server.rollbackToHeight(backHeight) b.server.putBlock(*blockHeader, backHeight+1) - b.mapMutex.Lock() - b.basicHeaders[node.header.BlockHash()] = make( - map[chainhash.Hash][]*serverPeer, - ) - b.extendedHeaders[node.header.BlockHash()] = make( - map[chainhash.Hash][]*serverPeer, - ) - b.mapMutex.Unlock() b.server.putMaxBlockHeight(backHeight + 1) b.resetHeaderState(&backHead, int32(backHeight)) b.headerList.PushBack(&headerNode{ header: blockHeader, height: int32(backHeight + 1), }) + b.mapMutex.Lock() + b.basicHeaders[blockHeader.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) + b.extendedHeaders[blockHeader.BlockHash()] = make( + map[chainhash.Hash][]*serverPeer, + ) + b.mapMutex.Unlock() + if b.lastBasicCFHeaderHeight > int32(backHeight) { + b.lastBasicCFHeaderHeight = int32(backHeight) + } + if b.lastExtCFHeaderHeight > int32(backHeight) { + b.lastExtCFHeaderHeight = int32(backHeight) + } } // Verify the header at the next checkpoint height matches. @@ -1139,6 +1152,7 @@ func (b *blockManager) handleCFHeadersMsg(cfhmsg *cfheadersMsg) { b.mapMutex.Lock() if _, ok := headerMap[hash]; !ok { b.mapMutex.Unlock() + log.Tracef("Breaking at %d (%s)", node.height, hash) break } // Process this header and set up the next iteration. @@ -1179,6 +1193,8 @@ func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { pendingMsgs = &b.numExtCFHeadersMsgs } + stopHash := msg.earliestNode.header.PrevBlock + // If we have started receiving cfheaders messages for blocks farther // than the last set we haven't made a decision on, it's time to make // a decision. @@ -1186,11 +1202,33 @@ func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { ready = true } + // If we have fewer processed cfheaders messages for the earliest node + // than the number of connected peers, give the other peers some time to + // catch up before checking if we've processed all of the queued + // cfheaders messages. + numHeaders := 0 + blockMap := headerMap[msg.earliestNode.header.BlockHash()] + for headerHash := range blockMap { + numHeaders += len(blockMap[headerHash]) + } + // Sleep for a bit if we have more peers than cfheaders messages for the + // earliest node for which we're trying to get cfheaders. This lets us + // wait for other peers to send cfheaders messages before making any + // decisions about whether we should write the headers in this message. + connCount := int(b.server.ConnectedCount()) + log.Tracef("Number of peers for which we've processed a cfheaders for "+ + "block %s: %d of %d", msg.earliestNode.header.BlockHash(), + numHeaders, connCount) + if numHeaders <= connCount { + time.Sleep(WaitForMoreCFHeaders) + } + // If there are no other cfheaders messages left for this type (basic vs // extended), we should go ahead and make a decision because we have all // the info we're going to get. if atomic.LoadInt32(pendingMsgs) == 0 { ready = true + stopHash = msg.stopHash } // Do nothing if we're not ready to make a decision yet. @@ -1220,6 +1258,7 @@ func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { log.Warnf("Somehow we have 0 cfheaders"+ " for block %d (%s)", node.height, hash) + b.mapMutex.Unlock() return } // This is the normal case when nobody's trying to @@ -1252,7 +1291,7 @@ func (b *blockManager) handleProcessCFHeadersMsg(msg *processCFHeadersMsg) { //b.startHeader = el // If we've reached the end, we can return - if hash == msg.stopHash { + if hash == stopHash { log.Tracef("Finished processing cfheaders messages up "+ "to height %d/hash %s, extended: %t", node.height, hash, msg.extended) diff --git a/spvsvc/spvchain/sync_test.go b/spvsvc/spvchain/sync_test.go index b684289..2d6b3e8 100644 --- a/spvsvc/spvchain/sync_test.go +++ b/spvsvc/spvchain/sync_test.go @@ -18,7 +18,11 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" ) -var logLevel = btclog.TraceLvl +const ( + logLevel = btclog.TraceLvl + syncTimeout = 30 * time.Second + syncUpdate = time.Second +) func TestSetup(t *testing.T) { // Create a btcd SimNet node and generate 500 blocks @@ -135,6 +139,7 @@ func TestSetup(t *testing.T) { spvchain.MaxPeers = 3 spvchain.BanDuration = 5 * time.Second spvchain.RequiredServices = wire.SFNodeNetwork + spvchain.WaitForMoreCFHeaders = time.Second logger, err := btctestlog.NewTestLogger(t) if err != nil { t.Fatalf("Could not set up logger: %s", err) @@ -150,7 +155,7 @@ func TestSetup(t *testing.T) { defer svc.Stop() // Make sure the client synchronizes with the correct node - err = waitForSync(t, svc, h1, time.Second, 30*time.Second) + err = waitForSync(t, svc, h1) if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) } @@ -158,7 +163,7 @@ func TestSetup(t *testing.T) { // Generate 125 blocks on h1 to make sure it reorgs the other nodes. // Ensure the ChainService instance stays caught up. h1.Node.Generate(125) - err = waitForSync(t, svc, h1, time.Second, 30*time.Second) + err = waitForSync(t, svc, h1) if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) } @@ -173,7 +178,7 @@ func TestSetup(t *testing.T) { // ChainService instance stays caught up. for i := 0; i < 3; i++ { h1.Node.Generate(1) - err = waitForSync(t, svc, h1, time.Second, 30*time.Second) + err = waitForSync(t, svc, h1) if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) } @@ -181,8 +186,8 @@ func TestSetup(t *testing.T) { // Generate 5 blocks on h2 and wait for ChainService to sync to the // newly-best chain on h2. - /*h2.Node.Generate(5) - err = waitForSync(t, svc, h2, time.Second, 30*time.Second) + h2.Node.Generate(5) + err = waitForSync(t, svc, h2) if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) } @@ -190,10 +195,10 @@ func TestSetup(t *testing.T) { // Generate 7 blocks on h1 and wait for ChainService to sync to the // newly-best chain on h1. h1.Node.Generate(7) - err = waitForSync(t, svc, h1, time.Second, 30*time.Second) + err = waitForSync(t, svc, h1) if err != nil { t.Fatalf("Couldn't sync ChainService: %s", err) - }*/ + } } // csd does a connect-sync-disconnect between nodes in order to support @@ -221,8 +226,7 @@ func csd(harnesses []*rpctest.Harness) error { // waitForSync waits for the ChainService to sync to the current chain state. func waitForSync(t *testing.T, svc *spvchain.ChainService, - correctSyncNode *rpctest.Harness, checkInterval, - timeout time.Duration) error { + correctSyncNode *rpctest.Harness) error { knownBestHash, knownBestHeight, err := correctSyncNode.Node.GetBestBlock() if err != nil { @@ -239,15 +243,15 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, } var total time.Duration for haveBest.Hash != *knownBestHash { - if total > timeout { + if total > syncTimeout { return fmt.Errorf("Timed out after %v waiting for "+ - "header synchronization.", timeout) + "header synchronization.", syncTimeout) } if haveBest.Height > knownBestHeight { return fmt.Errorf("Synchronized to the wrong chain.") } - time.Sleep(checkInterval) - total += checkInterval + time.Sleep(syncUpdate) + total += syncUpdate haveBest, err = svc.BestSnapshot() if err != nil { return fmt.Errorf("Couldn't get best snapshot from "+ @@ -275,9 +279,9 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, return fmt.Errorf("Couldn't get latest extended header from "+ "%s: %s", correctSyncNode.P2PAddress(), err) } - for total <= timeout { - time.Sleep(checkInterval) - total += checkInterval + for total <= syncTimeout { + time.Sleep(syncUpdate) + total += syncUpdate haveBasicHeader, err := svc.GetBasicHeader(*knownBestHash) if err != nil { if logLevel != btclog.Off { @@ -369,5 +373,5 @@ func waitForSync(t *testing.T, svc *spvchain.ChainService, return nil } return fmt.Errorf("Timeout waiting for cfheaders synchronization after"+ - " %v", timeout) + " %v", syncTimeout) }