diff --git a/btcwallet.go b/btcwallet.go index d6ce549..1a5d410 100644 --- a/btcwallet.go +++ b/btcwallet.go @@ -16,7 +16,6 @@ import ( "github.com/btcsuite/btcwallet/chain" "github.com/btcsuite/btcwallet/rpc/legacyrpc" "github.com/btcsuite/btcwallet/wallet" - "github.com/btcsuite/btcwallet/walletdb" ) var ( @@ -77,18 +76,8 @@ func walletMain() error { go rpcClientConnectLoop(legacyRPCServer, loader) } - var closeDB func() error - defer func() { - if closeDB != nil { - err := closeDB() - if err != nil { - log.Errorf("Unable to close wallet database: %v", err) - } - } - }() - loader.RunAfterLoad(func(w *wallet.Wallet, db walletdb.DB) { + loader.RunAfterLoad(func(w *wallet.Wallet) { startWalletRPCServices(w, rpcs, legacyRPCServer) - closeDB = db.Close }) if !cfg.NoInitialLoad { @@ -101,7 +90,15 @@ func walletMain() error { } } - // Shutdown the server(s) when interrupt signal is received. + // Add interrupt handlers to shutdown the various process components + // before exiting. Interrupt handlers run in LIFO order, so the wallet + // (which should be closed last) is added first. + addInterruptHandler(func() { + err := loader.UnloadWallet() + if err != nil && err != wallet.ErrNotLoaded { + log.Errorf("Failed to close wallet: %v", err) + } + }) if rpcs != nil { addInterruptHandler(func() { // TODO: Does this need to wait for the grpc server to @@ -112,15 +109,15 @@ func walletMain() error { }) } if legacyRPCServer != nil { - go func() { - <-legacyRPCServer.RequestProcessShutdown() - simulateInterrupt() - }() addInterruptHandler(func() { log.Warn("Stopping legacy RPC server...") legacyRPCServer.Stop() log.Info("Legacy RPC server shutdown") }) + go func() { + <-legacyRPCServer.RequestProcessShutdown() + simulateInterrupt() + }() } <-interruptHandlersDone @@ -158,7 +155,7 @@ func rpcClientConnectLoop(legacyRPCServer *legacyrpc.Server, loader *wallet.Load } } mu := new(sync.Mutex) - loader.RunAfterLoad(func(w *wallet.Wallet, db walletdb.DB) { + loader.RunAfterLoad(func(w *wallet.Wallet) { mu.Lock() associate := associateRPCClient mu.Unlock() diff --git a/rpc/rpcserver/server.go b/rpc/rpcserver/server.go index c46a883..7d7bdd9 100644 --- a/rpc/rpcserver/server.go +++ b/rpc/rpcserver/server.go @@ -806,13 +806,13 @@ func (s *loaderServer) WalletExists(ctx context.Context, req *pb.WalletExistsReq func (s *loaderServer) CloseWallet(ctx context.Context, req *pb.CloseWalletRequest) ( *pb.CloseWalletResponse, error) { - loadedWallet, ok := s.loader.LoadedWallet() - if !ok { + err := s.loader.UnloadWallet() + if err == wallet.ErrNotLoaded { return nil, grpc.Errorf(codes.FailedPrecondition, "wallet is not loaded") } - - loadedWallet.Stop() - loadedWallet.WaitForShutdown() + if err != nil { + return nil, translateError(err) + } return &pb.CloseWalletResponse{}, nil } diff --git a/wallet/loader.go b/wallet/loader.go index 19eb974..20a8e9b 100644 --- a/wallet/loader.go +++ b/wallet/loader.go @@ -27,6 +27,10 @@ var ( // create a wallet when the loader has already done so. ErrLoaded = errors.New("wallet already loaded") + // ErrNotLoaded describes the error condition of attempting to close a + // loaded wallet when a wallet has not been loaded. + ErrNotLoaded = errors.New("wallet is not loaded") + // ErrExists describes the error condition of attempting to create a new // wallet when one exists already. ErrExists = errors.New("wallet already exists") @@ -40,7 +44,7 @@ var ( // // Loader is safe for concurrent access. type Loader struct { - callbacks []func(*Wallet, walletdb.DB) + callbacks []func(*Wallet) chainParams *chaincfg.Params dbDirPath string wallet *Wallet @@ -60,7 +64,7 @@ func NewLoader(chainParams *chaincfg.Params, dbDirPath string) *Loader { // additional wallets. Requires mutex to be locked. func (l *Loader) onLoaded(w *Wallet, db walletdb.DB) { for _, fn := range l.callbacks { - fn(w, db) + fn(w) } l.wallet = w @@ -71,13 +75,12 @@ func (l *Loader) onLoaded(w *Wallet, db walletdb.DB) { // RunAfterLoad adds a function to be executed when the loader creates or opens // a wallet. Functions are executed in a single goroutine in the order they are // added. -func (l *Loader) RunAfterLoad(fn func(*Wallet, walletdb.DB)) { +func (l *Loader) RunAfterLoad(fn func(*Wallet)) { l.mu.Lock() if l.wallet != nil { w := l.wallet - db := l.db l.mu.Unlock() - fn(w, db) + fn(w) } else { l.callbacks = append(l.callbacks, fn) l.mu.Unlock() @@ -157,6 +160,7 @@ func (l *Loader) CreateNewWallet(pubPassphrase, privPassphrase, seed []byte) (*W if err != nil { return nil, err } + w.Start() l.onLoaded(w, db) return w, nil @@ -240,6 +244,30 @@ func (l *Loader) LoadedWallet() (*Wallet, bool) { return w, w != nil } +// UnloadWallet stops the loaded wallet, if any, and closes the wallet database. +// This returns ErrNotLoaded if the wallet has not been loaded with +// CreateNewWallet or LoadExistingWallet. The Loader may be reused if this +// function returns without error. +func (l *Loader) UnloadWallet() error { + defer l.mu.Unlock() + l.mu.Lock() + + if l.wallet == nil { + return ErrNotLoaded + } + + l.wallet.Stop() + l.wallet.WaitForShutdown() + err := l.db.Close() + if err != nil { + return err + } + + l.wallet = nil + l.db = nil + return nil +} + func fileExists(filePath string) (bool, error) { _, err := os.Stat(filePath) if err != nil { diff --git a/walletsetup.go b/walletsetup.go index 42b4178..5a38747 100644 --- a/walletsetup.go +++ b/walletsetup.go @@ -148,7 +148,7 @@ func createWallet(cfg *config) error { // Import the addresses in the legacy keystore to the new wallet if // any exist, locking each wallet again when finished. - loader.RunAfterLoad(func(w *wallet.Wallet, db walletdb.DB) { + loader.RunAfterLoad(func(w *wallet.Wallet) { defer legacyKeyStore.Lock() fmt.Println("Importing addresses from existing wallet...")