diff --git a/wallet/chainntfns_test.go b/wallet/chainntfns_test.go index ba7f8aa..8d62390 100644 --- a/wallet/chainntfns_test.go +++ b/wallet/chainntfns_test.go @@ -13,14 +13,16 @@ import ( _ "github.com/btcsuite/btcwallet/walletdb/bdb" ) +const ( + // defaultBlockInterval is the default time interval between any two + // blocks in a mocked chain. + defaultBlockInterval = 10 * time.Minute +) + var ( // chainParams are the chain parameters used throughout the wallet // tests. chainParams = chaincfg.MainNetParams - - // blockInterval is the time interval between any two blocks in a mocked - // chain. - blockInterval = 10 * time.Minute ) // mockChainConn is a mock in-memory implementation of the chainConn interface @@ -36,9 +38,11 @@ type mockChainConn struct { var _ chainConn = (*mockChainConn)(nil) // createMockChainConn creates a new mock chain connection backed by a chain -// with N blocks. Each block has a timestamp that is exactly 10 minutes after +// with N blocks. Each block has a timestamp that is exactly blockInterval after // the previous block's timestamp. -func createMockChainConn(genesis *wire.MsgBlock, n uint32) *mockChainConn { +func createMockChainConn(genesis *wire.MsgBlock, n uint32, + blockInterval time.Duration) *mockChainConn { + c := &mockChainConn{ chainTip: n, blockHashes: make(map[uint32]chainhash.Hash), @@ -163,7 +167,9 @@ func TestBirthdaySanityCheckVerifiedBirthdayBlock(t *testing.T) { t.Parallel() const chainTip = 5000 - chainConn := createMockChainConn(chainParams.GenesisBlock, chainTip) + chainConn := createMockChainConn( + chainParams.GenesisBlock, chainTip, defaultBlockInterval, + ) expectedBirthdayBlock := waddrmgr.BlockStamp{Height: 1337} // Our birthday store reflects that our birthday block has already been @@ -205,10 +211,12 @@ func TestBirthdaySanityCheckLowerEstimate(t *testing.T) { // We'll start by defining our birthday timestamp to be around the // timestamp of the 1337th block. genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp - birthday := genesisTimestamp.Add(1337 * blockInterval) + birthday := genesisTimestamp.Add(1337 * defaultBlockInterval) // We'll establish a connection to a mock chain of 5000 blocks. - chainConn := createMockChainConn(chainParams.GenesisBlock, 5000) + chainConn := createMockChainConn( + chainParams.GenesisBlock, 5000, defaultBlockInterval, + ) // Our birthday store will reflect that our birthday block is currently // set as the genesis block. This value is too low and should be @@ -256,10 +264,12 @@ func TestBirthdaySanityCheckHigherEstimate(t *testing.T) { // We'll start by defining our birthday timestamp to be around the // timestamp of the 1337th block. genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp - birthday := genesisTimestamp.Add(1337 * blockInterval) + birthday := genesisTimestamp.Add(1337 * defaultBlockInterval) // We'll establish a connection to a mock chain of 5000 blocks. - chainConn := createMockChainConn(chainParams.GenesisBlock, 5000) + chainConn := createMockChainConn( + chainParams.GenesisBlock, 5000, defaultBlockInterval, + ) // Our birthday store will reflect that our birthday block is currently // set as the chain tip. This value is too high and should be adjusted diff --git a/wallet/wallet.go b/wallet/wallet.go index 5fad970..25e6f0f 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -343,11 +343,24 @@ func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error { log.Debug("Chain backend synced to tip!") } + // If we've yet to find our birthday block, we'll do so now. if birthdayStamp == nil { var err error - birthdayStamp, err = w.syncToBirthday() + birthdayStamp, err = locateBirthdayBlock( + chainClient, w.Manager.Birthday(), + ) if err != nil { - return err + return fmt.Errorf("unable to locate birthday block: %v", + err) + } + + err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error { + ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) + return w.Manager.SetBirthdayBlock(ns, *birthdayStamp, true) + }) + if err != nil { + return fmt.Errorf("unable to write birthday block: %v", + err) } } @@ -497,6 +510,85 @@ func (w *Wallet) waitUntilBackendSynced(chainClient chain.Interface) error { } } +// locateBirthdayBlock returns a block that meets the given birthday timestamp +// by a margin of +/-2 hours. This is safe to do as the timestamp is already 2 +// days in the past of the actual timestamp. +func locateBirthdayBlock(chainClient chainConn, + birthday time.Time) (*waddrmgr.BlockStamp, error) { + + // Retrieve the lookup range for our block. + startHeight := int32(0) + _, bestHeight, err := chainClient.GetBestBlock() + if err != nil { + return nil, err + } + + log.Debugf("Locating suitable block for birthday %v between blocks "+ + "%v-%v", birthday, startHeight, bestHeight) + + var ( + birthdayBlock *waddrmgr.BlockStamp + left, right = startHeight, bestHeight + ) + + // Binary search for a block that meets the birthday timestamp by a + // margin of +/-2 hours. + for { + // Retrieve the timestamp for the block halfway through our + // range. + mid := left + (right-left)/2 + hash, err := chainClient.GetBlockHash(int64(mid)) + if err != nil { + return nil, err + } + header, err := chainClient.GetBlockHeader(hash) + if err != nil { + return nil, err + } + + log.Debugf("Checking candidate block: height=%v, hash=%v, "+ + "timestamp=%v", mid, hash, header.Timestamp) + + // If the search happened to reach either of our range extremes, + // then we'll just use that as there's nothing left to search. + if mid == startHeight || mid == bestHeight || mid == left { + birthdayBlock = &waddrmgr.BlockStamp{ + Hash: *hash, + Height: int32(mid), + Timestamp: header.Timestamp, + } + break + } + + // The block's timestamp is more than 2 hours after the + // birthday, so look for a lower block. + if header.Timestamp.Sub(birthday) > birthdayBlockDelta { + right = mid + continue + } + + // The birthday is more than 2 hours before the block's + // timestamp, so look for a higher block. + if header.Timestamp.Sub(birthday) < -birthdayBlockDelta { + left = mid + continue + } + + birthdayBlock = &waddrmgr.BlockStamp{ + Hash: *hash, + Height: int32(mid), + Timestamp: header.Timestamp, + } + break + } + + log.Debugf("Found birthday block: height=%d, hash=%v, timestamp=%v", + birthdayBlock.Height, birthdayBlock.Hash, + birthdayBlock.Timestamp) + + return birthdayBlock, nil +} + // scanChain is a helper method that scans the chain from the starting height // until the tip of the chain. The onBlock callback can be used to perform // certain operations for every block that we process as we scan the chain. @@ -573,113 +665,6 @@ func (w *Wallet) scanChain(startHeight int32, return nil } -// syncToBirthday attempts to sync the wallet's point of view of the chain until -// it finds the first block whose timestamp is above the wallet's birthday. The -// wallet's birthday is already two days in the past of its actual birthday, so -// this is relatively safe to do. -func (w *Wallet) syncToBirthday() (*waddrmgr.BlockStamp, error) { - var birthdayStamp *waddrmgr.BlockStamp - birthday := w.Manager.Birthday() - - tx, err := w.db.BeginReadWriteTx() - if err != nil { - return nil, err - } - ns := tx.ReadWriteBucket(waddrmgrNamespaceKey) - - // We'll begin scanning the chain from our last sync point until finding - // the first block with a timestamp greater than our birthday. We'll use - // this block to represent our birthday stamp. errDone is an error we'll - // use to signal that we've found it and no longer need to keep scanning - // the chain. - errDone := errors.New("done") - err = w.scanChain(w.Manager.SyncedTo().Height, func(height int32, - hash *chainhash.Hash, header *wire.BlockHeader) error { - - if header.Timestamp.After(birthday) { - log.Debugf("Found birthday block: height=%d, hash=%v", - height, hash) - - birthdayStamp = &waddrmgr.BlockStamp{ - Hash: *hash, - Height: height, - Timestamp: header.Timestamp, - } - - err := w.Manager.SetBirthdayBlock( - ns, *birthdayStamp, true, - ) - if err != nil { - return err - } - } - - err = w.Manager.SetSyncedTo(ns, &waddrmgr.BlockStamp{ - Hash: *hash, - Height: height, - Timestamp: header.Timestamp, - }) - if err != nil { - return err - } - - // Checkpoint our state every 10K blocks. - if height%10000 == 0 { - if err := tx.Commit(); err != nil { - return err - } - - log.Infof("Caught up to height %d", height) - - tx, err = w.db.BeginReadWriteTx() - if err != nil { - return err - } - ns = tx.ReadWriteBucket(waddrmgrNamespaceKey) - } - - // If we've found our birthday, we can return errDone to signal - // that we should stop scanning the chain and persist our state. - if birthdayStamp != nil { - return errDone - } - - return nil - }) - if err != nil && err != errDone { - tx.Rollback() - return nil, err - } - - // If a birthday stamp has yet to be found, we'll return an error - // indicating so, but only if this is a live chain like it is the case - // with testnet and mainnet. - if birthdayStamp == nil && !w.isDevEnv() { - tx.Rollback() - return nil, fmt.Errorf("did not find a suitable birthday "+ - "block with a timestamp greater than %v", birthday) - } - - // Otherwise, if we're in a development environment and we've yet to - // find a birthday block due to the chain not being current, we'll - // use the last block we've synced to as our birthday to proceed. - if birthdayStamp == nil { - syncedTo := w.Manager.SyncedTo() - err := w.Manager.SetBirthdayBlock(ns, syncedTo, true) - if err != nil { - return nil, err - } - birthdayStamp = &syncedTo - } - - if err := tx.Commit(); err != nil { - tx.Rollback() - return nil, err - } - - return birthdayStamp, nil -} - // recovery attempts to recover any unspent outputs that pay to any of our // addresses starting from the specified height. // diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go new file mode 100644 index 0000000..f70f43c --- /dev/null +++ b/wallet/wallet_test.go @@ -0,0 +1,85 @@ +package wallet + +import ( + "testing" + "time" +) + +// TestLocateBirthdayBlock ensures we can properly map a block in the chain to a +//timestamp. +func TestLocateBirthdayBlock(t *testing.T) { + t.Parallel() + + // We'll use test chains of 30 blocks with a duration between two + // consecutive blocks being slightly greater than the largest margin + // allowed by locateBirthdayBlock. Doing so lets us test the method more + // effectively as there is only one block within the chain that can map + // to a timestamp (this does not apply to the first and last blocks, + // which can map to many timestamps beyond either end of chain). + const ( + numBlocks = 30 + blockInterval = birthdayBlockDelta + 1 + ) + + genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp + + testCases := []struct { + name string + birthday time.Time + birthdayHeight int32 + }{ + { + name: "left-right-left-left", + birthday: genesisTimestamp.Add(8 * blockInterval), + birthdayHeight: 8, + }, + { + name: "right-right-right-left", + birthday: genesisTimestamp.Add(27 * blockInterval), + birthdayHeight: 27, + }, + { + name: "before start height", + birthday: genesisTimestamp.Add(-blockInterval), + birthdayHeight: 0, + }, + { + name: "start height", + birthday: genesisTimestamp, + birthdayHeight: 0, + }, + { + name: "end height", + birthday: genesisTimestamp.Add(numBlocks * blockInterval), + birthdayHeight: numBlocks - 1, + }, + { + name: "after end height", + birthday: genesisTimestamp.Add(2 * numBlocks * blockInterval), + birthdayHeight: numBlocks - 1, + }, + } + + for _, testCase := range testCases { + success := t.Run(testCase.name, func(t *testing.T) { + chainConn := createMockChainConn( + chainParams.GenesisBlock, numBlocks, blockInterval, + ) + birthdayBlock, err := locateBirthdayBlock( + chainConn, testCase.birthday, + ) + if err != nil { + t.Fatalf("unable to locate birthday block: %v", + err) + } + if birthdayBlock.Height != testCase.birthdayHeight { + t.Fatalf("expected birthday block with height "+ + "%d, got %d", testCase.birthdayHeight, + birthdayBlock.Height) + } + }) + if !success { + break + } + } +}