diff --git a/common_test.go b/common_test.go index 340718a7..22751ca6 100644 --- a/common_test.go +++ b/common_test.go @@ -9,12 +9,16 @@ import ( "github.com/conformal/btcchain" "github.com/conformal/btcdb" _ "github.com/conformal/btcdb/ldb" + _ "github.com/conformal/btcdb/memdb" "github.com/conformal/btcutil" "github.com/conformal/btcwire" "os" "path/filepath" ) +// testDbType is the database backend type to use for the tests. +const testDbType = "memdb" + // testDbRoot is the root directory used to create all test databases. const testDbRoot = "testdbs" @@ -28,41 +32,78 @@ func fileExists(name string) bool { return true } +// isSupportedDbType returns whether or not the passed database type is +// currently supported. +func isSupportedDbType(dbType string) bool { + supportedDBs := btcdb.SupportedDBs() + for _, sDbType := range supportedDBs { + if dbType == sDbType { + return true + } + } + + return false +} + // chainSetup is used to create a new db and chain instance with the genesis // block already inserted. In addition to the new chain instnce, it returns // a teardown function the caller should invoke when done testing to clean up. func chainSetup(dbName string) (*btcchain.BlockChain, func(), error) { - // Create the root directory for test databases. - if !fileExists(testDbRoot) { - if err := os.MkdirAll(testDbRoot, 0700); err != nil { - err := fmt.Errorf("unable to create test db root: %v", err) - return nil, nil, err + if !isSupportedDbType(testDbType) { + return nil, nil, fmt.Errorf("unsupported db type %v", testDbType) + } + + // Handle memory database specially since it doesn't need the disk + // specific handling. + var db btcdb.Db + var teardown func() + if testDbType == "memdb" { + ndb, err := btcdb.CreateDB(testDbType, "") + if err != nil { + return nil, nil, fmt.Errorf("error creating db: %v", err) } - } + db = ndb - // Create a new database to store the accepted blocks into. - dbPath := filepath.Join(testDbRoot, dbName) - _ = os.RemoveAll(dbPath) - db, err := btcdb.CreateDB("leveldb", dbPath) - if err != nil { - return nil, nil, fmt.Errorf("error creating db: %v", err) - } + // Setup a teardown function for cleaning up. This function is + // returned to the caller to be invoked when it is done testing. + teardown = func() { + db.Close() + } + } else { + // Create the root directory for test databases. + if !fileExists(testDbRoot) { + if err := os.MkdirAll(testDbRoot, 0700); err != nil { + err := fmt.Errorf("unable to create test db "+ + "root: %v", err) + return nil, nil, err + } + } - // Setup a teardown function for cleaning up. This function is returned - // to the caller to be invoked when it is done testing. - teardown := func() { - dbVersionPath := filepath.Join(testDbRoot, dbName+".ver") - db.Sync() - db.Close() - os.RemoveAll(dbPath) - os.Remove(dbVersionPath) - os.RemoveAll(testDbRoot) + // Create a new database to store the accepted blocks into. + dbPath := filepath.Join(testDbRoot, dbName) + _ = os.RemoveAll(dbPath) + ndb, err := btcdb.CreateDB(testDbType, dbPath) + if err != nil { + return nil, nil, fmt.Errorf("error creating db: %v", err) + } + db = ndb + + // Setup a teardown function for cleaning up. This function is + // returned to the caller to be invoked when it is done testing. + teardown = func() { + dbVersionPath := filepath.Join(testDbRoot, dbName+".ver") + db.Sync() + db.Close() + os.RemoveAll(dbPath) + os.Remove(dbVersionPath) + os.RemoveAll(testDbRoot) + } } // Insert the main network genesis block. This is part of the initial // database setup. genesisBlock := btcutil.NewBlock(&btcwire.GenesisBlock) - _, err = db.InsertBlock(genesisBlock) + _, err := db.InsertBlock(genesisBlock) if err != nil { teardown() err := fmt.Errorf("failed to insert genesis block: %v", err)