From 63e1bc44d9c37db938f0917bd3314dfb02bd2bbc Mon Sep 17 00:00:00 2001 From: MohitKatare-protean Date: Tue, 20 May 2025 12:01:14 +0530 Subject: [PATCH] added test cases for keymanager and plugin --- .../implementation/keymanager/cmd/plugin.go | 9 +- .../keymanager/cmd/plugin_test.go | 164 +++ .../implementation/keymanager/keymanager.go | 25 +- .../keymanager/keymanager_test.go | 1182 +++++++++++++++++ 4 files changed, 1373 insertions(+), 7 deletions(-) create mode 100644 pkg/plugin/implementation/keymanager/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/keymanager/keymanager_test.go diff --git a/pkg/plugin/implementation/keymanager/cmd/plugin.go b/pkg/plugin/implementation/keymanager/cmd/plugin.go index 9326a8b..5e37af4 100644 --- a/pkg/plugin/implementation/keymanager/cmd/plugin.go +++ b/pkg/plugin/implementation/keymanager/cmd/plugin.go @@ -9,7 +9,12 @@ import ( ) // keyManagerProvider implements the plugin provider for the KeyManager plugin. -type keyManagerProvider struct{} +type keyManagerProvider struct { + newFunc func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (definition.KeyManager, func() error, error) +} + +// newKeyManagerFunc is a function type that creates a new KeyManager instance. +var newKeyManagerFunc = keymanager.New // New creates and initializes a new KeyManager instance using the provided cache, registry lookup, and configuration. func (k *keyManagerProvider) New(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg map[string]string) (definition.KeyManager, func() error, error) { @@ -18,7 +23,7 @@ func (k *keyManagerProvider) New(ctx context.Context, cache definition.Cache, re KVVersion: cfg["kv_version"], } log.Debugf(ctx, "Keymanager config mapped: %+v", cfg) - km, cleanup, err := keymanager.New(ctx, cache, registry, config) + km, cleanup, err := newKeyManagerFunc(ctx, cache, registry, config) if err != nil { log.Error(ctx, err, "Failed to initialize KeyManager") return nil, nil, err diff --git a/pkg/plugin/implementation/keymanager/cmd/plugin_test.go b/pkg/plugin/implementation/keymanager/cmd/plugin_test.go new file mode 100644 index 0000000..bec7c6f --- /dev/null +++ b/pkg/plugin/implementation/keymanager/cmd/plugin_test.go @@ -0,0 +1,164 @@ +package main + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/beckn/beckn-onix/pkg/model" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager" +) + +// Mock KeyManager implementation +type mockKeyManager struct{} + +func (m *mockKeyManager) SigningPublicKey(ctx context.Context, subscriberID, keyID string) (string, error) { + return "mock-signing-public-key", nil +} + +func (m *mockKeyManager) SigningPrivateKey(ctx context.Context, subscriberID string) (string, string, error) { + return "mock-key-id", "mock-signing-private-key", nil +} + +func (m *mockKeyManager) EncrPublicKey(ctx context.Context, subscriberID, keyID string) (string, error) { + return "mock-encryption-public-key", nil +} + +func (m *mockKeyManager) EncrPrivateKey(ctx context.Context, subscriberID string) (string, string, error) { + return "mock-key-id", "mock-encryption-private-key", nil +} + +func (m *mockKeyManager) DeletePrivateKeys(ctx context.Context, subscriberID string) error { + return nil +} + +func (m *mockKeyManager) StorePrivateKeys(ctx context.Context, subscriberID string, keys *model.Keyset) error { + return nil +} + +func (m *mockKeyManager) GenerateKeyPairs() (*model.Keyset, error) { + return &model.Keyset{ + UniqueKeyID: "mock-key-id", + SigningPrivate: "mock-signing-private-key", + SigningPublic: "mock-signing-public-key", + EncrPrivate: "mock-encryption-private-key", + EncrPublic: "mock-encryption-public-key", + }, nil +} + +type mockRegistry struct { + LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) +} + +func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + if m.LookupFunc != nil { + return m.LookupFunc(ctx, sub) + } + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + URL: "https://mock.registry/subscriber", + Type: "BPP", + Domain: "retail", + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now().Add(-2 * time.Hour), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +type mockCache struct{} + +func (m *mockCache) Get(ctx context.Context, key string) (string, error) { + return "", nil +} +func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error { + return nil +} +func (m *mockCache) Clear(ctx context.Context) error { + return nil +} + +func (m *mockCache) Delete(ctx context.Context, key string) error { + return nil +} + +func TestNewSuccess(t *testing.T) { + // Setup dummy implementations and variables + ctx := context.Background() + cache := &mockCache{} + registry := &mockRegistry{} + cfg := map[string]string{ + "vault_addr": "http://dummy-vault", + "kv_version": "2", + } + + cleanupCalled := false + fakeCleanup := func() error { + cleanupCalled = true + return nil + } + + newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) { + // return a mock struct pointer of *keymanager.KeyMgr or a stub instance + return &keymanager.KeyMgr{}, fakeCleanup, nil + } + + // Create provider and call New + provider := &keyManagerProvider{} + km, cleanup, err := provider.New(ctx, cache, registry, cfg) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if km == nil { + t.Fatal("Expected non-nil KeyManager instance") + } + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + // Call cleanup and check if it behaves correctly + if err := cleanup(); err != nil { + t.Fatalf("Expected no error from cleanup, got %v", err) + } + if !cleanupCalled { + t.Error("Expected cleanup function to be called") + } +} + +func TestNewFailure(t *testing.T) { + // Setup dummy variables + ctx := context.Background() + cache := &mockCache{} + registry := &mockRegistry{} + cfg := map[string]string{ + "vault_addr": "http://dummy-vault", + "kv_version": "2", + } + + newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) { + return nil, nil, fmt.Errorf("some error") + } + + provider := &keyManagerProvider{} + km, cleanup, err := provider.New(ctx, cache, registry, cfg) + if err == nil { + t.Fatal("Expected error, got nil") + } + if km != nil { + t.Error("Expected nil KeyManager on error") + } + if cleanup != nil { + t.Error("Expected nil cleanup function on error") + } +} diff --git a/pkg/plugin/implementation/keymanager/keymanager.go b/pkg/plugin/implementation/keymanager/keymanager.go index 229fa2b..c50079d 100644 --- a/pkg/plugin/implementation/keymanager/keymanager.go +++ b/pkg/plugin/implementation/keymanager/keymanager.go @@ -70,6 +70,11 @@ func ValidateCfg(cfg *Config) error { return nil } +// getVaultClient is a function that creates a new Vault client. +// This is exported for testing purposes. +var getVaultClient = GetVaultClient + +// New creates a new KeyMgr instance with the provided configuration, cache, and registry lookup. func New(ctx context.Context, cache definition.Cache, registryLookup definition.RegistryLookup, cfg *Config) (*KeyMgr, func() error, error) { log.Info(ctx, "Initializing KeyManager plugin") // Validate configuration. @@ -91,7 +96,7 @@ func New(ctx context.Context, cache definition.Cache, registryLookup definition. // Initialize Vault client. log.Debugf(ctx, "Creating Vault client with address: %s", cfg.VaultAddr) - vaultClient, err := GetVaultClient(ctx, cfg.VaultAddr) + vaultClient, err := getVaultClient(ctx, cfg.VaultAddr) if err != nil { log.Errorf(ctx, err, "Failed to create Vault client at address: %s", cfg.VaultAddr) return nil, nil, fmt.Errorf("failed to create vault client: %w", err) @@ -120,6 +125,10 @@ func New(ctx context.Context, cache definition.Cache, registryLookup definition. return km, cleanup, nil } +// NewVaultClient creates a new Vault client instance. +// This function is exported for testing purposes. +var NewVaultClient = vault.NewClient + // GetVaultClient creates and authenticates a Vault client using AppRole. func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error) { roleID := os.Getenv("VAULT_ROLE_ID") @@ -133,7 +142,7 @@ func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error config := vault.DefaultConfig() config.Address = vaultAddr - client, err := vault.NewClient(config) + client, err := NewVaultClient(config) if err != nil { log.Error(ctx, err, "failed to create Vault client") return nil, fmt.Errorf("failed to create Vault client: %w", err) @@ -160,19 +169,25 @@ func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error return client, nil } +var ( + ed25519KeyGenFunc = ed25519.GenerateKey + x25519KeyGenFunc = ecdh.X25519().GenerateKey + uuidGenFunc = uuid.NewRandom +) + // GenerateKeyPairs generates a new signing (Ed25519) and encryption (X25519) key pair. func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) { - signingPublic, signingPrivate, err := ed25519.GenerateKey(rand.Reader) + signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader) if err != nil { return nil, fmt.Errorf("failed to generate signing key pair: %w", err) } - encrPrivateKey, err := ecdh.X25519().GenerateKey(rand.Reader) + encrPrivateKey, err := x25519KeyGenFunc(rand.Reader) if err != nil { return nil, fmt.Errorf("failed to generate encryption key pair: %w", err) } encrPublicKey := encrPrivateKey.PublicKey().Bytes() - uuid, err := uuid.NewRandom() + uuid, err := uuidGenFunc() if err != nil { return nil, fmt.Errorf("failed to generate unique key id uuid: %w", err) } diff --git a/pkg/plugin/implementation/keymanager/keymanager_test.go b/pkg/plugin/implementation/keymanager/keymanager_test.go new file mode 100644 index 0000000..b08215e --- /dev/null +++ b/pkg/plugin/implementation/keymanager/keymanager_test.go @@ -0,0 +1,1182 @@ +package keymanager + +import ( + "context" + "crypto/ecdh" + "crypto/ed25519" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/beckn/beckn-onix/pkg/model" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/google/uuid" + "github.com/hashicorp/vault/api" + vault "github.com/hashicorp/vault/api" +) + +type mockRegistry struct { + LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) +} + +func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + if m.LookupFunc != nil { + return m.LookupFunc(ctx, sub) + } + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + URL: "https://mock.registry/subscriber", + Type: "BPP", + Domain: "retail", + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now().Add(-2 * time.Hour), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +type mockCache struct{} + +func (m *mockCache) Get(ctx context.Context, key string) (string, error) { + return "", nil +} +func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error { + return nil +} +func (m *mockCache) Clear(ctx context.Context) error { + return nil +} + +func (m *mockCache) Delete(ctx context.Context, key string) error { + return nil +} + +func TestValidateCfgSuccess(t *testing.T) { + tests := []struct { + name string + cfg *Config + wantKV string + }{ + { + name: "valid config with v1", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v1"}, + wantKV: "v1", + }, + { + name: "valid config with v2", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v2"}, + wantKV: "v2", + }, + { + name: "default KV version applied", + cfg: &Config{VaultAddr: "http://localhost:8200"}, + wantKV: "v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCfg(tt.cfg) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if tt.cfg.KVVersion != tt.wantKV { + t.Errorf("expected KVVersion %s, got %s", tt.wantKV, tt.cfg.KVVersion) + } + }) + } +} + +func TestValidateCfgFailure(t *testing.T) { + tests := []struct { + name string + cfg *Config + }{ + { + name: "missing Vault address", + cfg: &Config{VaultAddr: "", KVVersion: "v1"}, + }, + { + name: "invalid KV version", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCfg(tt.cfg) + if err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} + +func TestGenerateKeyPairs(t *testing.T) { + originalEd25519 := ed25519KeyGenFunc + originalX25519 := x25519KeyGenFunc + originalUUID := uuidGenFunc + + defer func() { + ed25519KeyGenFunc = originalEd25519 + x25519KeyGenFunc = originalX25519 + uuidGenFunc = originalUUID + }() + + tests := []struct { + name string + mockEd25519Err error + mockX25519Err error + mockUUIDErr error + expectErr bool + }{ + { + name: "success case", + expectErr: false, + }, + { + name: "ed25519 key generation failure", + mockEd25519Err: errors.New("mock ed25519 failure"), + expectErr: true, + }, + { + name: "x25519 key generation failure", + mockX25519Err: errors.New("mock x25519 failure"), + expectErr: true, + }, + { + name: "UUID generation failure", + mockUUIDErr: errors.New("mock uuid failure"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mockEd25519Err != nil { + ed25519KeyGenFunc = func(_ io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) { + return nil, nil, tt.mockEd25519Err + } + } else { + ed25519KeyGenFunc = ed25519.GenerateKey + } + + if tt.mockX25519Err != nil { + x25519KeyGenFunc = func(_ io.Reader) (*ecdh.PrivateKey, error) { + return nil, tt.mockX25519Err + } + } else { + x25519KeyGenFunc = ecdh.X25519().GenerateKey + } + + if tt.mockUUIDErr != nil { + uuidGenFunc = func() (uuid.UUID, error) { + return uuid.Nil, tt.mockUUIDErr + } + } else { + uuidGenFunc = uuid.NewRandom + } + + km := &KeyMgr{} + keyset, err := km.GenerateKeyPairs() + + if tt.expectErr { + if err == nil { + t.Errorf("expected error, got nil") + } + if keyset != nil { + t.Errorf("expected nil keyset, got non-nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if keyset == nil { + t.Fatal("expected keyset, got nil") + } + if keyset.SigningPrivate == "" || keyset.SigningPublic == "" || keyset.EncrPrivate == "" || keyset.EncrPublic == "" { + t.Error("expected all keys to be populated and base64-encoded") + } + if keyset.UniqueKeyID == "" { + t.Error("expected UniqueKeyID to be non-empty") + } + } + }) + } +} + +type mockLogical struct { + writeFn func(path string, data map[string]interface{}) (*vault.Secret, error) +} + +func (m *mockLogical) Write(path string, data map[string]interface{}) (*vault.Secret, error) { + return m.writeFn(path, data) +} + +type mockClient struct { + *vault.Client + setTokenFn func(string) + logicalFn func() *vault.Logical +} + +func (m *mockClient) SetToken(token string) { + if m.setTokenFn != nil { + m.setTokenFn(token) + } +} + +func (m *mockClient) Logical() *vault.Logical { + if m.logicalFn != nil { + return m.logicalFn() + } + return &vault.Logical{} +} + +func TestGetVaultClient_Failures(t *testing.T) { + originalNewVaultClient := NewVaultClient + defer func() { NewVaultClient = originalNewVaultClient }() + + ctx := context.Background() + + tests := []struct { + name string + roleID string + secretID string + setupServer func(t *testing.T) *httptest.Server + expectErr string + }{ + { + name: "missing credentials", + roleID: "", + secretID: "", + expectErr: "VAULT_ROLE_ID or VAULT_SECRET_ID is not set", + }, + { + name: "vault client creation failure", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + return nil, errors.New("mock client creation error") + } + return nil + }, + expectErr: "failed to create Vault client: mock client creation error", + }, + { + name: "AppRole login failure", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "login failed", http.StatusBadRequest) + })) + }, + expectErr: "failed to login with AppRole: Error making API request", + }, + { + name: "AppRole login returns nil auth", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{ "auth": null }`) + })) + }, + expectErr: "AppRole login failed: no auth info returned", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv("VAULT_ROLE_ID", tt.roleID) + os.Setenv("VAULT_SECRET_ID", tt.secretID) + + var server *httptest.Server + if tt.setupServer != nil { + server = tt.setupServer(t) + if server != nil { + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + cfg.Address = server.URL + return vault.NewClient(cfg) + } + defer server.Close() + } + } + + client, err := GetVaultClient(ctx, "http://ignored") + if err == nil || !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error to contain '%s', got: %v", tt.expectErr, err) + } + if client != nil { + t.Error("expected client to be nil on failure") + } + }) + } +} + +func TestGetVaultClient_Success(t *testing.T) { + originalNewVaultClient := NewVaultClient + defer func() { NewVaultClient = originalNewVaultClient }() + + ctx := context.Background() + + os.Setenv("VAULT_ROLE_ID", "test-role") + os.Setenv("VAULT_SECRET_ID", "test-secret") + + // Mock Vault server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/auth/approle/login") { + t.Errorf("unexpected request path: %s", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{ + "auth": { + "client_token": "mock-token" + } + }`) + })) + defer server.Close() + + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + cfg.Address = server.URL + return vault.NewClient(cfg) + } + + client, err := GetVaultClient(ctx, "http://ignored") + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + if token := client.Token(); token != "mock-token" { + t.Errorf("expected token to be 'mock-token', got: %s", token) + } +} + +type mockRegistryLookup struct{} + +func (m *mockRegistryLookup) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + Type: sub.Type, + }, + KeyID: "mock-key-id", + SigningPublicKey: "mock-signing-pubkey", + EncrPublicKey: "mock-encryption-pubkey", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now(), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +func TestNewSuccess(t *testing.T) { + tests := []struct { + name string + cfg *Config + cache definition.Cache + registry definition.RegistryLookup + mockVaultStatus int + mockVaultBody string + }{ + { + name: "valid config", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + } + + originalGetVaultClient := getVaultClient + defer func() { getVaultClient = originalGetVaultClient }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.mockVaultStatus) + fmt.Fprint(w, tt.mockVaultBody) + })) + defer vaultServer.Close() + + tt.cfg.VaultAddr = vaultServer.URL + + getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) { + cfg := vault.DefaultConfig() + cfg.Address = addr + return vault.NewClient(cfg) + } + + ctx := context.Background() + km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if km == nil { + t.Fatalf("expected KeyMgr instance, got nil") + } + if cleanup == nil { + t.Fatalf("expected cleanup function, got nil") + } + _ = cleanup() + }) + } +} + +func TestNewFailure(t *testing.T) { + tests := []struct { + name string + cfg *Config + cache definition.Cache + registry definition.RegistryLookup + mockVaultStatus int + mockVaultBody string + }{ + { + name: "nil cache", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: nil, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "nil registry", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: nil, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "invalid config", + cfg: &Config{ + VaultAddr: "", // Invalid + KVVersion: "v3", // Unsupported + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "vault client creation failure", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + } + + originalGetVaultClient := getVaultClient + defer func() { getVaultClient = originalGetVaultClient }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.mockVaultStatus) + fmt.Fprint(w, tt.mockVaultBody) + })) + defer vaultServer.Close() + + if tt.cfg != nil { + tt.cfg.VaultAddr = vaultServer.URL + } + + getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) { + if tt.name == "vault client creation failure" { + return nil, errors.New("simulated vault client creation error") + } + cfg := vault.DefaultConfig() + cfg.Address = addr + return vault.NewClient(cfg) + } + + ctx := context.Background() + km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg) + + if err == nil { + t.Error("expected error, got nil") + } + if km != nil { + t.Error("expected KeyMgr to be nil, got non-nil") + } + if cleanup != nil { + t.Error("expected cleanup to be nil, got non-nil") + } + }) + } + +} + +func TestStorePrivateKeysSuccess(t *testing.T) { + ctx := context.Background() + + keys := &model.Keyset{ + UniqueKeyID: "uuid", + SigningPublic: "signPub", + SigningPrivate: "signPriv", + EncrPublic: "encrPub", + EncrPrivate: "encrPriv", + } + + tests := []struct { + name string + kvVersion string + keyID string + keys *model.Keyset + }{ + { + name: "success kv v1", + kvVersion: "v1", + keyID: "mykeyid", + keys: keys, + }, + { + name: "success kv v2", + kvVersion: "v2", + keyID: "mykeyid", + keys: keys, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := "" + if tt.kvVersion == "v2" { + expectedPath = "/v1/secret/data/keys/" + tt.keyID + } else { + expectedPath = "/v1/secret/keys/" + tt.keyID + } + + if r.URL.Path != expectedPath { + t.Errorf("unexpected request path: got %s, want %s", r.URL.Path, expectedPath) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprintln(w, `{"data":{}}`) + })) + defer server.Close() + + config := api.DefaultConfig() + config.Address = server.URL + client, err := api.NewClient(config) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }) + } +} + +func TestStorePrivateKeysFailure(t *testing.T) { + ctx := context.Background() + + keys := &model.Keyset{ + UniqueKeyID: "uuid", + SigningPublic: "signPub", + SigningPrivate: "signPriv", + EncrPublic: "encrPub", + EncrPrivate: "encrPriv", + } + + tests := []struct { + name string + kvVersion string + keyID string + keys *model.Keyset + statusCode int // for HTTP error simulation + expectedErr string + }{ + { + name: "empty keyID", + keyID: "", + keys: keys, + expectedErr: ErrEmptyKeyID.Error(), + }, + { + name: "nil keys", + keyID: "mykeyid", + keys: nil, + expectedErr: ErrNilKeySet.Error(), + }, + { + name: "vault write error", + kvVersion: "v1", + keyID: "mykeyid", + keys: keys, + statusCode: 500, + expectedErr: "failed to store secret in Vault: Error making API request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var server *httptest.Server + if tt.statusCode != 0 { + // Setup test HTTP server to simulate Vault error + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", tt.statusCode) + })) + defer server.Close() + } + + var client *api.Client + var err error + if server != nil { + config := api.DefaultConfig() + config.Address = server.URL + client, err = api.NewClient(config) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + } else { + client = nil + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys) + + if err == nil { + t.Fatalf("expected error %q but got nil", tt.expectedErr) + } + if !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("expected error containing %q, got %v", tt.expectedErr, err) + } + }) + } +} + +func TestDeletePrivateKeys(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + wantPath string + wantErr error + }{ + { + name: "empty keyID", + kvVersion: "v1", + keyID: "", + wantErr: ErrEmptyKeyID, + }, + { + name: "v1 delete", + kvVersion: "v1", + keyID: "key123", + wantPath: "/v1/secret/private_keys/key123/data/key123", + wantErr: nil, + }, + { + name: "v2 delete", + kvVersion: "v2", + keyID: "key123", + wantPath: "/v1/secret/data/private_keys/key123/data/key123", + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // If empty keyID, no Vault calls, just check error + if tt.keyID == "" { + km := &KeyMgr{ + KvVersion: tt.kvVersion, + VaultClient: nil, + } + err := km.DeletePrivateKeys(context.Background(), tt.keyID) + if err != tt.wantErr { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + return + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + t.Errorf("Expected DELETE method, got %s", r.Method) + } + if r.URL.Path != tt.wantPath { + t.Errorf("Expected path %s, got %s", tt.wantPath, r.URL.Path) + } + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + vaultClient, err := NewVaultClient(&vault.Config{Address: ts.URL}) + if err != nil { + t.Fatalf("failed to create vault client: %v", err) + } + + km := &KeyMgr{ + KvVersion: tt.kvVersion, + VaultClient: vaultClient, + } + + err = km.DeletePrivateKeys(context.Background(), tt.keyID) + if err != tt.wantErr { + t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *httptest.Server { + t.Helper() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check that request path is expected + expectedPathV1 := fmt.Sprintf("/v1/secret/private_keys/%s", keyID) + expectedPathV2 := fmt.Sprintf("/v1/secret/data/private_keys/%s", keyID) + + if (kvVersion == "v2" && r.URL.Path != expectedPathV2) || (kvVersion != "v2" && r.URL.Path != expectedPathV1) { + http.Error(w, "not found", http.StatusNotFound) + return + } + + if !success { + // Simulate Vault error or not found + http.Error(w, `{"errors":["key not found"]}`, http.StatusNotFound) + return + } + + // Success response JSON, different for v1 and v2 + if kvVersion == "v2" { + resp := fmt.Sprintf(`{ + "data": { + "data": { + "uniqueKeyID": "%s", + "signingPublicKey": "sign-pub", + "signingPrivateKey": "sign-priv", + "encrPublicKey": "encr-pub", + "encrPrivateKey": "encr-priv" + } + } + }`, keyID) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(resp)) + } else { + resp := fmt.Sprintf(`{ + "data": { + "uniqueKeyID": "%s", + "signingPublicKey": "sign-pub", + "signingPrivateKey": "sign-priv", + "encrPublicKey": "encr-pub", + "encrPrivateKey": "encr-priv" + } + }`, keyID) + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(resp)) + } + }) + + return httptest.NewServer(handler) +} + +func TestGetKeysSuccess(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + }{ + { + name: "success with KV v2", + kvVersion: "v2", + keyID: "test-key-v2", + }, + { + name: "success with KV v1", + kvVersion: "v1", + keyID: "test-key-v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := setupMockVaultServer(t, tt.kvVersion, tt.keyID, true) + defer ts.Close() + + cfg := vault.DefaultConfig() + cfg.Address = ts.URL + + client, err := vault.NewClient(cfg) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + keys, err := km.getKeys(context.Background(), tt.keyID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if keys == nil { + t.Fatalf("expected keys but got nil") + } + if keys.UniqueKeyID != tt.keyID { + t.Errorf("expected UniqueKeyID %q, got %q", tt.keyID, keys.UniqueKeyID) + } + if keys.SigningPrivate != "sign-priv" { + t.Errorf("expected SigningPrivate 'sign-priv', got %q", keys.SigningPrivate) + } + }) + } +} + +func TestGetKeysFailure(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + success bool + }{ + { + name: "failure: vault returns 404 v2", + kvVersion: "v2", + keyID: "missing-key-v2", + success: false, + }, + { + name: "failure: vault returns 404 v1", + kvVersion: "v1", + keyID: "missing-key-v1", + success: false, + }, + { + name: "failure: empty keyID", + kvVersion: "v2", + keyID: "", + success: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ts *httptest.Server + if tt.keyID != "" { + ts = setupMockVaultServer(t, tt.kvVersion, tt.keyID, tt.success) + defer ts.Close() + } + + cfg := vault.DefaultConfig() + if ts != nil { + cfg.Address = ts.URL + } else { + // For empty keyID case or no mock server, use invalid URL to force error + cfg.Address = "http://invalid" + } + + client, err := vault.NewClient(cfg) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + keys, err := km.getKeys(context.Background(), tt.keyID) + if err == nil { + t.Fatalf("expected error but got nil") + } + if keys != nil { + t.Fatalf("expected nil keys but got %+v", keys) + } + }) + } +} + +func TestGetPublicKeysSuccess(t *testing.T) { + km := &KeyMgr{ + Cache: &mockCache{}, + Registry: &mockRegistry{ + LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + }, + }, nil + }, + }, + } + + got, err := km.getPublicKeys(context.Background(), "sub-id", "key-id") + if err != nil { + t.Fatalf("getPublicKeys() unexpected error: %v", err) + } + if got == nil { + t.Fatal("getPublicKeys() returned nil Keyset") + } + if got.SigningPublic != "mock-signing-public-key" { + t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, "mock-signing-public-key") + } + if got.EncrPublic != "mock-encryption-public-key" { + t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, "mock-encryption-public-key") + } +} + +func TestGetPublicKeysFailure(t *testing.T) { + type fields struct { + cache definition.Cache + registry definition.RegistryLookup + } + type args struct { + subscriberID string + uniqueKeyID string + } + tests := []struct { + name string + fields fields + args args + errMessage string + }{ + { + name: "failure - invalid parameters", + fields: fields{ + cache: &mockCache{}, + registry: &mockRegistry{}, + }, + args: args{ + subscriberID: "", + uniqueKeyID: "", + }, + errMessage: "invalid", + }, + { + name: "failure - registry lookup fails", + fields: fields{ + cache: &mockCache{}, + registry: &mockRegistry{ + LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, fmt.Errorf("registry down") + }, + }, + }, + args: args{ + subscriberID: "sub-id", + uniqueKeyID: "key-id", + }, + errMessage: "registry down", + }, + { + name: "failure - registry returns empty", + fields: fields{ + cache: &mockCache{}, + registry: &mockRegistry{ + LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{}, nil + }, + }, + }, + args: args{ + subscriberID: "sub-id", + uniqueKeyID: "key-id", + }, + errMessage: "no subscriber found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + km := &KeyMgr{ + Cache: tt.fields.cache, + Registry: tt.fields.registry, + } + got, err := km.getPublicKeys(context.Background(), tt.args.subscriberID, tt.args.uniqueKeyID) + if err == nil { + t.Errorf("getPublicKeys() expected error but got none, result: %v", got) + return + } + if !strings.Contains(err.Error(), tt.errMessage) { + t.Errorf("getPublicKeys() error = %v, want error message to contain %v", err.Error(), tt.errMessage) + } + }) + } +} + +func TestLookupRegistrySuccess(t *testing.T) { + km := &KeyMgr{ + Registry: &mockRegistry{ + LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + }, + KeyID: sub.KeyID, + SigningPublicKey: "signing-key", + EncrPublicKey: "encryption-key", + }, + }, nil + }, + }, + } + + got, err := km.lookupRegistry(context.Background(), "test-subscriber", "key123") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + want := &model.Keyset{ + SigningPublic: "signing-key", + EncrPublic: "encryption-key", + } + + if got.SigningPublic != want.SigningPublic { + t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, want.SigningPublic) + } + if got.EncrPublic != want.EncrPublic { + t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, want.EncrPublic) + } +} + +func TestLookupRegistryFailure(t *testing.T) { + tests := []struct { + name string + mockLookup func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) + wantErr error + }{ + { + name: "lookup error", + mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, fmt.Errorf("registry failure") + }, + wantErr: fmt.Errorf("registry failure"), + }, + { + name: "no subscriber found", + mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{}, nil + }, + wantErr: ErrSubscriberNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + km := &KeyMgr{ + Registry: &mockRegistry{ + LookupFunc: tt.mockLookup, + }, + } + got, err := km.lookupRegistry(context.Background(), "some-id", "key-id") + if err == nil { + t.Fatalf("expected error, got none") + } + if got != nil { + t.Errorf("expected nil keyset, got %v", got) + } + }) + } +} + +func TestValidateParamsSuccess(t *testing.T) { + err := validateParams("someSubscriberID", "someUniqueKeyID") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestValidateParamsFailure(t *testing.T) { + tests := []struct { + name string + subscriberID string + uniqueKeyID string + wantErr error + }{ + { + name: "empty subscriberID", + subscriberID: "", + uniqueKeyID: "validKeyID", + wantErr: ErrEmptySubscriberID, + }, + { + name: "empty uniqueKeyID", + subscriberID: "validSubscriberID", + uniqueKeyID: "", + wantErr: ErrEmptyUniqueKeyID, + }, + { + name: "both empty", + subscriberID: "", + uniqueKeyID: "", + wantErr: ErrEmptySubscriberID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateParams(tt.subscriberID, tt.uniqueKeyID) + if err == nil { + t.Fatalf("expected error %v but got nil", tt.wantErr) + } + if err != tt.wantErr { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + }) + } +}