diff --git a/waddrmgr/internal_test.go b/waddrmgr/internal_test.go index 211a8c7..3ddb75d 100644 --- a/waddrmgr/internal_test.go +++ b/waddrmgr/internal_test.go @@ -23,7 +23,11 @@ interface. The functions are only exported while the tests are being run. package waddrmgr -import "github.com/conformal/btcwallet/snacl" +import ( + "errors" + + "github.com/conformal/btcwallet/snacl" +) // TstMaxRecentHashes makes the unexported maxRecentHashes constant available // when tests are run. @@ -51,3 +55,24 @@ func (m *Manager) TstCheckPublicPassphrase(pubPassphrase []byte) bool { err := secretKey.DeriveKey(&pubPassphrase) return err == nil } + +type failingCryptoKey struct { + cryptoKey +} + +func (c *failingCryptoKey) Encrypt(in []byte) ([]byte, error) { + return nil, errors.New("failed to encrypt") +} + +func (c *failingCryptoKey) Decrypt(in []byte) ([]byte, error) { + return nil, errors.New("failed to decrypt") +} + +func TstRunWithFailingCryptoKeyPriv(m *Manager, callback func()) { + orig := m.cryptoKeyPriv + defer func() { + m.cryptoKeyPriv = orig + }() + m.cryptoKeyPriv = &failingCryptoKey{} + callback() +} diff --git a/waddrmgr/manager_test.go b/waddrmgr/manager_test.go index 1afbabb..66d6f26 100644 --- a/waddrmgr/manager_test.go +++ b/waddrmgr/manager_test.go @@ -19,6 +19,7 @@ package waddrmgr_test import ( "encoding/hex" "fmt" + "io/ioutil" "os" "reflect" "testing" @@ -1508,3 +1509,105 @@ func TestManager(t *testing.T) { t.Errorf("Unlock: unexpected error: %v", err) } } + +func setUp(t *testing.T) (tearDownFunc func(), mgr *waddrmgr.Manager) { + t.Parallel() + // Create a new manager. + // We create the file and immediately delete it, as the waddrmgr + // needs to be doing the creating. + file, err := ioutil.TempDir("", "pool_test") + if err != nil { + t.Fatalf("Failed to create db file: %v", err) + } + os.Remove(file) + mgr, err = waddrmgr.Create(file, seed, pubPassphrase, privPassphrase, + &btcnet.MainNetParams, fastScrypt) + if err != nil { + t.Fatalf("Failed to create Manager: %v", err) + } + tearDownFunc = func() { + os.Remove(file) + mgr.Close() + } + return tearDownFunc, mgr +} + +func TestEncryptDecryptErrors(t *testing.T) { + teardown, mgr := setUp(t) + defer teardown() + + invalidKeyType := waddrmgr.CryptoKeyType(0xff) + if _, err := mgr.Encrypt(invalidKeyType, []byte{}); err == nil { + t.Fatalf("Encrypt accepted an invalid key type!") + } + + if _, err := mgr.Decrypt(invalidKeyType, []byte{}); err == nil { + t.Fatalf("Encrypt accepted an invalid key type!") + } + + if !mgr.IsLocked() { + t.Fatal("Manager should be locked at this point.") + } + + var err error + // Now the mgr is locked and encrypting/decrypting with private + // keys should fail. + _, err = mgr.Encrypt(waddrmgr.CKTPrivate, []byte{}) + checkManagerError(t, "encryption with private key fails when manager is locked", + err, waddrmgr.ErrLocked) + + _, err = mgr.Decrypt(waddrmgr.CKTPrivate, []byte{}) + checkManagerError(t, "decryption with private key fails when manager is locked", + err, waddrmgr.ErrLocked) + + // Unlock the manager for these tests + if err = mgr.Unlock(privPassphrase); err != nil { + t.Fatal("Attempted to unlock the manager, but failed:", err) + } + + // Make sure to cover the ErrCrypto error path in Encrypt. + waddrmgr.TstRunWithFailingCryptoKeyPriv(mgr, func() { + _, err = mgr.Encrypt(waddrmgr.CKTPrivate, []byte{}) + }) + checkManagerError(t, "failed encryption", err, waddrmgr.ErrCrypto) + + // Make sure to cover the ErrCrypto error path in Decrypt. + waddrmgr.TstRunWithFailingCryptoKeyPriv(mgr, func() { + _, err = mgr.Decrypt(waddrmgr.CKTPrivate, []byte{}) + }) + checkManagerError(t, "failed decryption", err, waddrmgr.ErrCrypto) +} + +func TestEncryptDecrypt(t *testing.T) { + teardown, mgr := setUp(t) + defer teardown() + + plainText := []byte("this is a plaintext") + + // Make sure address manager is unlocked + if err := mgr.Unlock(privPassphrase); err != nil { + t.Fatal("Attempted to unlock the manager, but failed:", err) + } + + keyTypes := []waddrmgr.CryptoKeyType{ + waddrmgr.CKTPublic, + waddrmgr.CKTPrivate, + waddrmgr.CKTScript, + } + + for _, keyType := range keyTypes { + cipherText, err := mgr.Encrypt(keyType, plainText) + if err != nil { + t.Fatalf("Failed to encrypt plaintext: %v", err) + } + + decryptedCipherText, err := mgr.Decrypt(keyType, cipherText) + if err != nil { + t.Fatalf("Failed to decrypt plaintext: %v", err) + } + + if !reflect.DeepEqual(decryptedCipherText, plainText) { + t.Fatal("Got:", decryptedCipherText, ", want:", plainText) + } + } +}