From a795db6b12ffab7daa34f726d3cf11c059d1e9a7 Mon Sep 17 00:00:00 2001 From: Andras Banki-Horvath Date: Tue, 27 Apr 2021 22:18:50 +0200 Subject: [PATCH] wallet: support for external wallet DB --- wallet/loader.go | 123 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 86 insertions(+), 37 deletions(-) diff --git a/wallet/loader.go b/wallet/loader.go index 029114c..32f3bd5 100644 --- a/wallet/loader.go +++ b/wallet/loader.go @@ -6,6 +6,7 @@ package wallet import ( "errors" + "fmt" "os" "path/filepath" "sync" @@ -55,6 +56,8 @@ type Loader struct { timeout time.Duration recoveryWindow uint32 wallet *Wallet + localDB bool + walletExists func() (bool, error) db walletdb.DB mu sync.Mutex } @@ -72,18 +75,42 @@ func NewLoader(chainParams *chaincfg.Params, dbDirPath string, noFreelistSync: noFreelistSync, timeout: timeout, recoveryWindow: recoveryWindow, + localDB: true, } } +// NewLoaderWithDB constructs a Loader with an externally provided DB. This way +// users are free to use their own walletdb implementation (eg. leveldb, etcd) +// to store the wallet. Given that the external DB may be shared an additional +// function is also passed which will override Loader.WalletExists(). +func NewLoaderWithDB(chainParams *chaincfg.Params, recoveryWindow uint32, + db walletdb.DB, walletExists func() (bool, error)) (*Loader, error) { + + if db == nil { + return nil, fmt.Errorf("no DB provided") + } + + if walletExists == nil { + return nil, fmt.Errorf("unable to check if wallet exists") + } + + return &Loader{ + chainParams: chainParams, + recoveryWindow: recoveryWindow, + localDB: false, + walletExists: walletExists, + db: db, + }, nil +} + // onLoaded executes each added callback and prevents loader from loading any // additional wallets. Requires mutex to be locked. -func (l *Loader) onLoaded(w *Wallet, db walletdb.DB) { +func (l *Loader) onLoaded(w *Wallet) { for _, fn := range l.callbacks { fn(w) } l.wallet = w - l.db = db l.callbacks = nil // not needed anymore } @@ -134,8 +161,7 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase, return nil, ErrLoaded } - dbPath := filepath.Join(l.dbDirPath, WalletDBName) - exists, err := fileExists(dbPath) + exists, err := l.WalletExists() if err != nil { return nil, err } @@ -143,25 +169,34 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase, return nil, ErrExists } - // Create the wallet database backed by bolt db. - err = os.MkdirAll(l.dbDirPath, 0700) - if err != nil { - return nil, err - } - db, err := walletdb.Create("bdb", dbPath, l.noFreelistSync, l.timeout) - if err != nil { - return nil, err + if l.localDB { + dbPath := filepath.Join(l.dbDirPath, WalletDBName) + + // Create the wallet database backed by bolt db. + err = os.MkdirAll(l.dbDirPath, 0700) + if err != nil { + return nil, err + } + l.db, err = walletdb.Create( + "bdb", dbPath, l.noFreelistSync, l.timeout, + ) + if err != nil { + return nil, err + } } // Initialize the newly created database for the wallet before opening. if isWatchingOnly { - err = CreateWatchingOnly(db, pubPassphrase, l.chainParams, bday) + err := CreateWatchingOnly( + l.db, pubPassphrase, l.chainParams, bday, + ) if err != nil { return nil, err } } else { - err = Create( - db, pubPassphrase, privPassphrase, seed, l.chainParams, bday, + err := Create( + l.db, pubPassphrase, privPassphrase, seed, + l.chainParams, bday, ) if err != nil { return nil, err @@ -169,13 +204,13 @@ func (l *Loader) createNewWallet(pubPassphrase, privPassphrase, } // Open the newly-created wallet. - w, err := Open(db, pubPassphrase, nil, l.chainParams, l.recoveryWindow) + w, err := Open(l.db, pubPassphrase, nil, l.chainParams, l.recoveryWindow) if err != nil { return nil, err } w.Start() - l.onLoaded(w, db) + l.onLoaded(w) return w, nil } @@ -197,17 +232,22 @@ func (l *Loader) OpenExistingWallet(pubPassphrase []byte, canConsolePrompt bool) return nil, ErrLoaded } - // Ensure that the network directory exists. - if err := checkCreateDir(l.dbDirPath); err != nil { - return nil, err - } + if l.localDB { + var err error + // Ensure that the network directory exists. + if err = checkCreateDir(l.dbDirPath); err != nil { + return nil, err + } - // Open the database using the boltdb backend. - dbPath := filepath.Join(l.dbDirPath, WalletDBName) - db, err := walletdb.Open("bdb", dbPath, l.noFreelistSync, l.timeout) - if err != nil { - log.Errorf("Failed to open database: %v", err) - return nil, err + // Open the database using the boltdb backend. + dbPath := filepath.Join(l.dbDirPath, WalletDBName) + l.db, err = walletdb.Open( + "bdb", dbPath, l.noFreelistSync, l.timeout, + ) + if err != nil { + log.Errorf("Failed to open database: %v", err) + return nil, err + } } var cbs *waddrmgr.OpenCallbacks @@ -222,28 +262,35 @@ func (l *Loader) OpenExistingWallet(pubPassphrase []byte, canConsolePrompt bool) ObtainPrivatePass: noConsole, } } - w, err := Open(db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow) + w, err := Open(l.db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow) if err != nil { // If opening the wallet fails (e.g. because of wrong // passphrase), we must close the backing database to // allow future calls to walletdb.Open(). - e := db.Close() - if e != nil { - log.Warnf("Error closing database: %v", e) + if l.localDB { + e := l.db.Close() + if e != nil { + log.Warnf("Error closing database: %v", e) + } } + return nil, err } w.Start() - l.onLoaded(w, db) + l.onLoaded(w) return w, nil } // WalletExists returns whether a file exists at the loader's database path. // This may return an error for unexpected I/O failures. func (l *Loader) WalletExists() (bool, error) { - dbPath := filepath.Join(l.dbDirPath, WalletDBName) - return fileExists(dbPath) + if l.localDB { + dbPath := filepath.Join(l.dbDirPath, WalletDBName) + return fileExists(dbPath) + } + + return l.walletExists() } // LoadedWallet returns the loaded wallet, if any, and a bool for whether the @@ -270,9 +317,11 @@ func (l *Loader) UnloadWallet() error { l.wallet.Stop() l.wallet.WaitForShutdown() - err := l.db.Close() - if err != nil { - return err + if l.localDB { + err := l.db.Close() + if err != nil { + return err + } } l.wallet = nil