Files
onix/pkg/plugin/implementation/keymanager/keymanager_test.go

1115 lines
27 KiB
Go

package keymanager
import (
"context"
"crypto/ecdh"
"crypto/ed25519"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/beckn-one/beckn-onix/pkg/model"
"github.com/beckn-one/beckn-onix/pkg/plugin/definition"
"github.com/google/uuid"
"github.com/hashicorp/vault/api"
vault "github.com/hashicorp/vault/api"
)
// mockRegistry is a registry test double. When LookupFunc is set it answers
// every Lookup call; otherwise Lookup falls back to a canned subscription.
type mockRegistry struct {
	LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
}

// Lookup delegates to LookupFunc when one is configured. Without it, a single
// "SUBSCRIBED" subscription is returned that echoes the SubscriberID and KeyID
// of sub and carries fixed mock public keys.
func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
	if m.LookupFunc == nil {
		fallback := model.Subscription{
			Subscriber: model.Subscriber{
				SubscriberID: sub.SubscriberID,
				URL:          "https://mock.registry/subscriber",
				Type:         "BPP",
				Domain:       "retail",
			},
			KeyID:            sub.KeyID,
			SigningPublicKey: "mock-signing-public-key",
			EncrPublicKey:    "mock-encryption-public-key",
			ValidFrom:        time.Now().Add(-time.Hour),
			ValidUntil:       time.Now().Add(time.Hour),
			Status:           "SUBSCRIBED",
			Created:          time.Now().Add(-2 * time.Hour),
			Updated:          time.Now(),
			Nonce:            "mock-nonce",
		}
		return []model.Subscription{fallback}, nil
	}
	return m.LookupFunc(ctx, sub)
}
// mockCache is a cache test double. GetFunc, when set, supplies the result of
// Get; every other operation is a no-op that reports success.
type mockCache struct {
	GetFunc func(ctx context.Context, key string) (string, error)
}

// Get returns whatever GetFunc returns when it is configured, so tests can
// simulate cache hits, misses and failures. With no GetFunc it reports an
// empty value and no error.
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
	if m.GetFunc != nil {
		return m.GetFunc(ctx, key)
	}
	return "", nil
}

// Set discards the value and always succeeds.
func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error {
	return nil
}

// Clear does nothing and always succeeds.
func (m *mockCache) Clear(ctx context.Context) error {
	return nil
}

// Delete does nothing and always succeeds.
func (m *mockCache) Delete(ctx context.Context, key string) error {
	return nil
}
// TestValidateCfgSuccess checks that ValidateCfg accepts well-formed configs
// and that KVVersion holds the expected value afterwards, including "v1" when
// the field was left empty.
func TestValidateCfgSuccess(t *testing.T) {
	const addr = "http://localhost:8200"

	cases := []struct {
		name   string
		cfg    *Config
		wantKV string
	}{
		{name: "valid config with v1", cfg: &Config{VaultAddr: addr, KVVersion: "v1"}, wantKV: "v1"},
		{name: "valid config with v2", cfg: &Config{VaultAddr: addr, KVVersion: "v2"}, wantKV: "v2"},
		{name: "default KV version applied", cfg: &Config{VaultAddr: addr}, wantKV: "v1"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if err := ValidateCfg(tc.cfg); err != nil {
				t.Errorf("expected no error, got %v", err)
			}
			// ValidateCfg may have filled in KVVersion on the config it was given.
			if got := tc.cfg.KVVersion; got != tc.wantKV {
				t.Errorf("expected KVVersion %s, got %s", tc.wantKV, got)
			}
		})
	}
}
// TestValidateCfgFailure checks that ValidateCfg rejects a config with no
// Vault address and a config with an unsupported KV version.
func TestValidateCfgFailure(t *testing.T) {
	cases := []struct {
		name string
		cfg  *Config
	}{
		{name: "missing Vault address", cfg: &Config{VaultAddr: "", KVVersion: "v1"}},
		{name: "invalid KV version", cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v3"}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if err := ValidateCfg(tc.cfg); err == nil {
				t.Errorf("expected error, got nil")
			}
		})
	}
}
// TestGenerateKeyPairs exercises KeyMgr.GenerateKeyset with the real key and
// UUID generators and with each generator replaced, one at a time, by a stub
// that fails. The package-level generator hooks are restored on exit.
func TestGenerateKeyPairs(t *testing.T) {
	savedEd25519, savedX25519, savedUUID := ed25519KeyGenFunc, x25519KeyGenFunc, uuidGenFunc
	defer func() {
		ed25519KeyGenFunc, x25519KeyGenFunc, uuidGenFunc = savedEd25519, savedX25519, savedUUID
	}()

	cases := []struct {
		name           string
		mockEd25519Err error
		mockX25519Err  error
		mockUUIDErr    error
		expectErr      bool
	}{
		{name: "success case", expectErr: false},
		{name: "ed25519 key generation failure", mockEd25519Err: errors.New("mock ed25519 failure"), expectErr: true},
		{name: "x25519 key generation failure", mockX25519Err: errors.New("mock x25519 failure"), expectErr: true},
		{name: "UUID generation failure", mockUUIDErr: errors.New("mock uuid failure"), expectErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Begin with the real generators, then swap in a failing stub for
			// whichever one this case wants to break.
			ed25519KeyGenFunc = ed25519.GenerateKey
			if tc.mockEd25519Err != nil {
				ed25519KeyGenFunc = func(io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) {
					return nil, nil, tc.mockEd25519Err
				}
			}

			x25519KeyGenFunc = ecdh.X25519().GenerateKey
			if tc.mockX25519Err != nil {
				x25519KeyGenFunc = func(io.Reader) (*ecdh.PrivateKey, error) {
					return nil, tc.mockX25519Err
				}
			}

			uuidGenFunc = uuid.NewRandom
			if tc.mockUUIDErr != nil {
				uuidGenFunc = func() (uuid.UUID, error) {
					return uuid.Nil, tc.mockUUIDErr
				}
			}

			mgr := &KeyMgr{}
			keyset, err := mgr.GenerateKeyset()

			if tc.expectErr {
				if err == nil {
					t.Errorf("expected error, got nil")
				}
				if keyset != nil {
					t.Errorf("expected nil keyset, got non-nil")
				}
				return
			}

			if err != nil {
				t.Errorf("unexpected error: %v", err)
			}
			if keyset == nil {
				t.Fatal("expected keyset, got nil")
			}
			allKeysSet := keyset.SigningPrivate != "" &&
				keyset.SigningPublic != "" &&
				keyset.EncrPrivate != "" &&
				keyset.EncrPublic != ""
			if !allKeysSet {
				t.Error("expected all keys to be populated and base64-encoded")
			}
			if keyset.UniqueKeyID == "" {
				t.Error("expected UniqueKeyID to be non-empty")
			}
		})
	}
}
// TestGetVaultClient_Failures covers the error paths of GetVaultClient:
// missing AppRole credentials, a failing client constructor, a rejected
// AppRole login, and a login response that carries no auth block.
//
// Each subtest sets VAULT_ROLE_ID / VAULT_SECRET_ID through t.Setenv, so the
// process environment is restored when the subtest ends, and each subtest
// starts from the original NewVaultClient so a stub installed by one case can
// never leak into another.
func TestGetVaultClient_Failures(t *testing.T) {
	originalNewVaultClient := NewVaultClient
	defer func() { NewVaultClient = originalNewVaultClient }()

	ctx := context.Background()

	tests := []struct {
		name     string
		roleID   string
		secretID string
		// setupServer optionally prepares the case. It may replace
		// NewVaultClient itself and return nil, or return a server that the
		// Vault client is then pointed at.
		setupServer func(t *testing.T) *httptest.Server
		expectErr   string
	}{
		{
			name:      "missing credentials",
			roleID:    "",
			secretID:  "",
			expectErr: "VAULT_ROLE_ID or VAULT_SECRET_ID is not set",
		},
		{
			name:     "vault client creation failure",
			roleID:   "test-role",
			secretID: "test-secret",
			setupServer: func(t *testing.T) *httptest.Server {
				NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
					return nil, errors.New("mock client creation error")
				}
				return nil
			},
			expectErr: "failed to create Vault client: mock client creation error",
		},
		{
			name:     "AppRole login failure",
			roleID:   "test-role",
			secretID: "test-secret",
			setupServer: func(t *testing.T) *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					http.Error(w, "login failed", http.StatusBadRequest)
				}))
			},
			expectErr: "failed to login with AppRole: Error making API request",
		},
		{
			name:     "AppRole login returns nil auth",
			roleID:   "test-role",
			secretID: "test-secret",
			setupServer: func(t *testing.T) *httptest.Server {
				return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					// Headers must be set before WriteHeader or they are not sent.
					w.Header().Set("Content-Type", "application/json")
					w.WriteHeader(http.StatusOK)
					if _, err := io.WriteString(w, `{ "auth": null }`); err != nil {
						// The handler runs on the server's goroutine, where
						// only the non-fatal reporting methods may be used.
						t.Errorf("failed to write response: %v", err)
					}
				}))
			},
			expectErr: "AppRole login failed: no auth info returned",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			NewVaultClient = originalNewVaultClient
			t.Setenv("VAULT_ROLE_ID", tt.roleID)
			t.Setenv("VAULT_SECRET_ID", tt.secretID)

			if tt.setupServer != nil {
				if server := tt.setupServer(t); server != nil {
					defer server.Close()
					// Redirect the client to the mock server regardless of the
					// address GetVaultClient was asked to use.
					NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
						cfg.Address = server.URL
						return vault.NewClient(cfg)
					}
				}
			}

			client, err := GetVaultClient(ctx, "http://ignored")
			if err == nil || !strings.Contains(err.Error(), tt.expectErr) {
				t.Errorf("expected error to contain '%s', got: %v", tt.expectErr, err)
			}
			if client != nil {
				t.Error("expected client to be nil on failure")
			}
		})
	}
}
// TestGetVaultClient_Success verifies that GetVaultClient performs an AppRole
// login against the configured server and returns a client carrying the token
// from the login response.
//
// The AppRole environment variables are put back to their previous state when
// the test finishes so later tests do not inherit the credentials.
func TestGetVaultClient_Success(t *testing.T) {
	originalNewVaultClient := NewVaultClient
	defer func() { NewVaultClient = originalNewVaultClient }()

	ctx := context.Background()

	// setEnv sets key for the duration of the test and registers a cleanup
	// that restores the prior value, or unsets key if it was not set before.
	setEnv := func(key, value string) {
		previous, wasSet := os.LookupEnv(key)
		if err := os.Setenv(key, value); err != nil {
			t.Fatalf("failed to set %s: %v", key, err)
		}
		t.Cleanup(func() {
			// Best-effort restore: a failure here cannot be acted on during cleanup.
			if wasSet {
				_ = os.Setenv(key, previous)
				return
			}
			_ = os.Unsetenv(key)
		})
	}
	setEnv("VAULT_ROLE_ID", "test-role")
	setEnv("VAULT_SECRET_ID", "test-secret")

	// Mock Vault server that only expects the AppRole login request.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !strings.HasSuffix(r.URL.Path, "/auth/approle/login") {
			t.Errorf("unexpected request path: %s", r.URL.Path)
		}
		// Headers must be set before WriteHeader or they are not sent.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		if _, err := io.WriteString(w, `{
			"auth": {
				"client_token": "mock-token"
			}
		}`); err != nil {
			// The handler runs on the server's goroutine, where only the
			// non-fatal reporting methods may be used.
			t.Errorf("failed to write response: %v", err)
		}
	}))
	defer server.Close()

	// Redirect the client to the mock server regardless of the address
	// GetVaultClient was asked to use.
	NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
		cfg.Address = server.URL
		return vault.NewClient(cfg)
	}

	client, err := GetVaultClient(ctx, "http://ignored")
	if err != nil {
		t.Fatalf("expected success, got error: %v", err)
	}
	if client == nil {
		t.Fatal("expected non-nil client")
	}
	if token := client.Token(); token != "mock-token" {
		t.Errorf("expected token to be 'mock-token', got: %s", token)
	}
}
// mockRegistryLookup is a stateless registry stub whose Lookup always succeeds.
type mockRegistryLookup struct{}

// Lookup returns one "SUBSCRIBED" subscription that copies SubscriberID and
// Type from sub and carries fixed mock key material.
func (m *mockRegistryLookup) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
	entry := model.Subscription{
		Subscriber: model.Subscriber{
			SubscriberID: sub.SubscriberID,
			Type:         sub.Type,
		},
		KeyID:            "mock-key-id",
		SigningPublicKey: "mock-signing-pubkey",
		EncrPublicKey:    "mock-encryption-pubkey",
		ValidFrom:        time.Now().Add(-time.Hour),
		ValidUntil:       time.Now().Add(time.Hour),
		Status:           "SUBSCRIBED",
		Created:          time.Now(),
		Updated:          time.Now(),
		Nonce:            "mock-nonce",
	}
	return []model.Subscription{entry}, nil
}
func TestNewSuccess(t *testing.T) {
tests := []struct {
name string
cfg *Config
cache definition.Cache
registry definition.RegistryLookup
mockVaultStatus int
mockVaultBody string
}{
{
name: "valid config",
cfg: &Config{
VaultAddr: "http://dummy",
KVVersion: "v2",
},
cache: &mockCache{},
registry: &mockRegistryLookup{},
mockVaultStatus: http.StatusOK,
mockVaultBody: `{}`,
},
}
originalGetVaultClient := getVaultClient
defer func() { getVaultClient = originalGetVaultClient }()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.mockVaultStatus)
fmt.Fprint(w, tt.mockVaultBody)
}))
defer vaultServer.Close()
tt.cfg.VaultAddr = vaultServer.URL
getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) {
cfg := vault.DefaultConfig()
cfg.Address = addr
return vault.NewClient(cfg)
}
ctx := context.Background()
km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if km == nil {
t.Fatalf("expected KeyMgr instance, got nil")
}
if cleanup == nil {
t.Fatalf("expected cleanup function, got nil")
}
_ = cleanup()
})
}
}
// TestNewFailure verifies that New returns an error, a nil KeyMgr and a nil
// cleanup function when a dependency is missing, the config is invalid, or the
// Vault client cannot be created.
//
// A non-empty VaultAddr in a case is a placeholder that is replaced with the
// test server URL. An empty VaultAddr is left untouched so that the "invalid
// config" case really presents a missing address to New. Client-creation
// failure is injected through the clientErr field rather than by matching on
// the subtest name.
func TestNewFailure(t *testing.T) {
	tests := []struct {
		name            string
		cfg             *Config
		cache           definition.Cache
		registry        definition.RegistryLookup
		mockVaultStatus int
		mockVaultBody   string
		// clientErr, when non-nil, is returned by the stubbed getVaultClient.
		clientErr error
	}{
		{
			name: "nil cache",
			cfg: &Config{
				VaultAddr: "http://dummy",
				KVVersion: "v2",
			},
			cache:           nil,
			registry:        &mockRegistryLookup{},
			mockVaultStatus: http.StatusOK,
			mockVaultBody:   `{}`,
		},
		{
			name: "nil registry",
			cfg: &Config{
				VaultAddr: "http://dummy",
				KVVersion: "v2",
			},
			cache:           &mockCache{},
			registry:        nil,
			mockVaultStatus: http.StatusOK,
			mockVaultBody:   `{}`,
		},
		{
			name: "invalid config",
			cfg: &Config{
				VaultAddr: "",   // Invalid
				KVVersion: "v3", // Unsupported
			},
			cache:           &mockCache{},
			registry:        &mockRegistryLookup{},
			mockVaultStatus: http.StatusOK,
			mockVaultBody:   `{}`,
		},
		{
			name: "vault client creation failure",
			cfg: &Config{
				VaultAddr: "http://dummy",
				KVVersion: "v2",
			},
			cache:           &mockCache{},
			registry:        &mockRegistryLookup{},
			mockVaultStatus: http.StatusOK,
			mockVaultBody:   `{}`,
			clientErr:       errors.New("simulated vault client creation error"),
		},
	}

	originalGetVaultClient := getVaultClient
	defer func() { getVaultClient = originalGetVaultClient }()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(tt.mockVaultStatus)
				fmt.Fprint(w, tt.mockVaultBody)
			}))
			defer vaultServer.Close()

			// Only substitute placeholder addresses; keep a deliberately empty
			// address so the invalid input is what New actually sees.
			if tt.cfg != nil && tt.cfg.VaultAddr != "" {
				tt.cfg.VaultAddr = vaultServer.URL
			}

			getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) {
				if tt.clientErr != nil {
					return nil, tt.clientErr
				}
				cfg := vault.DefaultConfig()
				cfg.Address = addr
				return vault.NewClient(cfg)
			}

			ctx := context.Background()
			km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg)
			if err == nil {
				t.Error("expected error, got nil")
			}
			if km != nil {
				t.Error("expected KeyMgr to be nil, got non-nil")
			}
			if cleanup != nil {
				t.Error("expected cleanup to be nil, got non-nil")
			}
		})
	}
}
// TestStorePrivateKeysSuccess checks that KeyMgr.InsertKeyset succeeds for KV
// v1 and v2 and that the request reaches the mock Vault server on the path
// expected for that KV version.
func TestStorePrivateKeysSuccess(t *testing.T) {
	ctx := context.Background()
	fixture := &model.Keyset{
		UniqueKeyID:    "uuid",
		SigningPublic:  "signPub",
		SigningPrivate: "signPriv",
		EncrPublic:     "encrPub",
		EncrPrivate:    "encrPriv",
	}

	cases := []struct {
		name      string
		kvVersion string
		keyID     string
		keys      *model.Keyset
	}{
		{name: "success kv v1", kvVersion: "v1", keyID: "mykeyid", keys: fixture},
		{name: "success kv v2", kvVersion: "v2", keyID: "mykeyid", keys: fixture},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// KV v2 requests carry an extra "data" path segment.
			wantPath := "/v1/secret/keys/" + tc.keyID
			if tc.kvVersion == "v2" {
				wantPath = "/v1/secret/data/keys/" + tc.keyID
			}

			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if got := r.URL.Path; got != wantPath {
					t.Errorf("unexpected request path: got %s, want %s", got, wantPath)
				}
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				fmt.Fprintln(w, `{"data":{}}`)
			}))
			defer srv.Close()

			clientCfg := api.DefaultConfig()
			clientCfg.Address = srv.URL
			client, err := api.NewClient(clientCfg)
			if err != nil {
				t.Fatalf("failed to create Vault client: %v", err)
			}

			km := &KeyMgr{VaultClient: client, KvVersion: tc.kvVersion}
			if err := km.InsertKeyset(ctx, tc.keyID, tc.keys); err != nil {
				t.Errorf("expected no error, got %v", err)
			}
		})
	}
}
// TestStorePrivateKeysFailure checks that KeyMgr.InsertKeyset reports an error
// for an empty key ID, a nil keyset, and a Vault write that answers with an
// HTTP error status. Cases without a statusCode run with a nil Vault client.
func TestStorePrivateKeysFailure(t *testing.T) {
	ctx := context.Background()
	fixture := &model.Keyset{
		UniqueKeyID:    "uuid",
		SigningPublic:  "signPub",
		SigningPrivate: "signPriv",
		EncrPublic:     "encrPub",
		EncrPrivate:    "encrPriv",
	}

	cases := []struct {
		name        string
		kvVersion   string
		keyID       string
		keys        *model.Keyset
		statusCode  int // non-zero: mock Vault replies with this HTTP status
		expectedErr string
	}{
		{
			name:        "empty keyID",
			keyID:       "",
			keys:        fixture,
			expectedErr: ErrEmptyKeyID.Error(),
		},
		{
			name:        "nil keys",
			keyID:       "mykeyid",
			keys:        nil,
			expectedErr: ErrNilKeySet.Error(),
		},
		{
			name:        "vault write error",
			kvVersion:   "v1",
			keyID:       "mykeyid",
			keys:        fixture,
			statusCode:  500,
			expectedErr: "failed to store secret in Vault at path secret/keys/mykeyid: Error making API request.",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var client *api.Client
			if tc.statusCode != 0 {
				// Stand up a server that fails every request with the chosen status.
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
					http.Error(w, "internal error", tc.statusCode)
				}))
				defer srv.Close()

				clientCfg := api.DefaultConfig()
				clientCfg.Address = srv.URL
				created, err := api.NewClient(clientCfg)
				if err != nil {
					t.Fatalf("failed to create Vault client: %v", err)
				}
				client = created
			}

			km := &KeyMgr{VaultClient: client, KvVersion: tc.kvVersion}
			err := km.InsertKeyset(ctx, tc.keyID, tc.keys)
			if err == nil {
				t.Fatalf("expected error %q but got nil", tc.expectedErr)
			}
			if msg := err.Error(); !strings.Contains(msg, tc.expectedErr) {
				t.Errorf("expected error containing %q, got %v", tc.expectedErr, err)
			}
		})
	}
}
// TestDeletePrivateKeys checks KeyMgr.DeleteKeyset: an empty key ID must yield
// ErrEmptyKeyID, and a non-empty key ID must issue an HTTP DELETE to the path
// listed for the KV version under test.
//
// Errors are matched with errors.Is so the test keeps passing if DeleteKeyset
// starts wrapping its sentinel errors.
func TestDeletePrivateKeys(t *testing.T) {
	tests := []struct {
		name      string
		kvVersion string
		keyID     string
		wantPath  string
		wantErr   error
	}{
		{
			name:      "empty keyID",
			kvVersion: "v1",
			keyID:     "",
			wantErr:   ErrEmptyKeyID,
		},
		{
			name:      "v1 delete",
			kvVersion: "v1",
			keyID:     "key123",
			wantPath:  "/v1/secret/keys/key123/data/key123",
			wantErr:   nil,
		},
		{
			name:      "v2 delete",
			kvVersion: "v2",
			keyID:     "key123",
			wantPath:  "/v1/secret/data/keys/key123/data/key123",
			wantErr:   nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			km := &KeyMgr{KvVersion: tt.kvVersion}

			// The empty-keyID case runs without a Vault client or server; only
			// cases with a key ID get a mock server to receive the DELETE.
			if tt.keyID != "" {
				ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					if r.Method != http.MethodDelete {
						t.Errorf("Expected DELETE method, got %s", r.Method)
					}
					if r.URL.Path != tt.wantPath {
						t.Errorf("Expected path %s, got %s", tt.wantPath, r.URL.Path)
					}
					w.WriteHeader(http.StatusNoContent)
				}))
				defer ts.Close()

				vaultClient, err := NewVaultClient(&vault.Config{Address: ts.URL})
				if err != nil {
					t.Fatalf("failed to create vault client: %v", err)
				}
				km.VaultClient = vaultClient
			}

			err := km.DeleteKeyset(context.Background(), tt.keyID)
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("DeleteKeyset() error = %v, want %v", err, tt.wantErr)
			}
		})
	}
}
// setupMockVaultServer starts an httptest server that imitates a Vault KV read
// for a single key. The caller must Close the returned server.
//
// Only one path is served: /v1/secret/data/keys/<keyID> when kvVersion is
// "v2", and /v1/secret/keys/<keyID> for any other kvVersion. Every other path
// gets a 404. When success is false the served path also answers 404 with a
// Vault-style error body. When success is true the response is a JSON secret
// whose key fields sit under "data" (v1) or under "data"."data" alongside a
// "metadata" object (v2).
//
// Write failures are reported with t.Errorf because the handler runs on the
// server's goroutine, where the fatal test methods must not be called.
func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *httptest.Server {
	t.Helper()

	isV2 := kvVersion == "v2"

	wantPath := "/v1/secret/keys/" + keyID
	if isV2 {
		wantPath = "/v1/secret/data/keys/" + keyID
	}

	// The key material is identical for both KV versions; only its nesting differs.
	payload := fmt.Sprintf(`{
		"uniqueKeyID": "%s",
		"signingPublicKey": "sign-pub",
		"signingPrivateKey": "sign-priv",
		"encrPublicKey": "encr-pub",
		"encrPrivateKey": "encr-priv"
	}`, keyID)
	if isV2 {
		payload = fmt.Sprintf(`{
		"data": %s,
		"metadata": {
			"created_time": "2025-05-28T00:00:00Z",
			"deletion_time": "",
			"destroyed": false,
			"version": 1
		}
	}`, payload)
	}

	body := fmt.Sprintf(`{
	"request_id": "req-1234",
	"lease_id": "",
	"renewable": false,
	"lease_duration": 0,
	"data": %s,
	"warnings": null,
	"auth": null
}`, payload)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != wantPath {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		if !success {
			http.Error(w, `{"errors":["key not found"]}`, http.StatusNotFound)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if _, err := w.Write([]byte(body)); err != nil {
			t.Errorf("failed to write response: %v", err)
		}
	})
	return httptest.NewServer(handler)
}
// TestKeysetSuccess checks that KeyMgr.Keyset reads a keyset from the mock
// Vault server for both KV versions and returns the key ID and signing private
// key that the server sent.
func TestKeysetSuccess(t *testing.T) {
	cases := []struct {
		name      string
		kvVersion string
		keyID     string
	}{
		{name: "success with KV v2", kvVersion: "v2", keyID: "test-key-v2"},
		{name: "success with KV v1", kvVersion: "v1", keyID: "test-key-v1"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			srv := setupMockVaultServer(t, tc.kvVersion, tc.keyID, true)
			defer srv.Close()

			clientCfg := vault.DefaultConfig()
			clientCfg.Address = srv.URL
			client, err := vault.NewClient(clientCfg)
			if err != nil {
				t.Fatalf("failed to create Vault client: %v", err)
			}

			km := &KeyMgr{VaultClient: client, KvVersion: tc.kvVersion}
			got, err := km.Keyset(context.Background(), tc.keyID)
			switch {
			case err != nil:
				t.Fatalf("unexpected error: %v", err)
			case got == nil:
				t.Fatalf("expected keys but got nil")
			}

			if got.UniqueKeyID != tc.keyID {
				t.Errorf("expected UniqueKeyID %q, got %q", tc.keyID, got.UniqueKeyID)
			}
			if got.SigningPrivate != "sign-priv" {
				t.Errorf("expected SigningPrivate 'sign-priv', got %q", got.SigningPrivate)
			}
		})
	}
}
func TestKeysetFailure(t *testing.T) {
tests := []struct {
name string
kvVersion string
keyID string
success bool
}{
{
name: "failure: vault returns 404 v2",
kvVersion: "v2",
keyID: "missing-key-v2",
success: false,
},
{
name: "failure: vault returns 404 v1",
kvVersion: "v1",
keyID: "missing-key-v1",
success: false,
},
{
name: "failure: empty keyID",
kvVersion: "v2",
keyID: "",
success: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ts *httptest.Server
if tt.keyID != "" {
ts = setupMockVaultServer(t, tt.kvVersion, tt.keyID, tt.success)
defer ts.Close()
}
cfg := vault.DefaultConfig()
if ts != nil {
cfg.Address = ts.URL
} else {
// For empty keyID case or no mock server, use invalid URL to force error
cfg.Address = "http://invalid"
}
client, err := vault.NewClient(cfg)
if err != nil {
t.Fatalf("failed to create Vault client: %v", err)
}
km := &KeyMgr{
VaultClient: client,
KvVersion: tt.kvVersion,
}
keys, err := km.Keyset(context.Background(), tt.keyID)
if err == nil {
t.Fatalf("expected error but got nil")
}
if keys != nil {
t.Fatalf("expected nil keys but got %+v", keys)
}
})
}
}
// TestValidateParamsSuccess checks that validateParams accepts a non-empty
// subscriber ID together with a non-empty unique key ID.
func TestValidateParamsSuccess(t *testing.T) {
	if err := validateParams("someSubscriberID", "someUniqueKeyID"); err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
}
// TestValidateParamsFailure checks which sentinel error validateParams returns
// for an empty subscriber ID, an empty unique key ID, and both being empty
// (where the subscriber ID error is the one expected).
func TestValidateParamsFailure(t *testing.T) {
	cases := []struct {
		name         string
		subscriberID string
		uniqueKeyID  string
		wantErr      error
	}{
		{name: "empty subscriberID", subscriberID: "", uniqueKeyID: "validKeyID", wantErr: ErrEmptySubscriberID},
		{name: "empty uniqueKeyID", subscriberID: "validSubscriberID", uniqueKeyID: "", wantErr: ErrEmptyUniqueKeyID},
		{name: "both empty", subscriberID: "", uniqueKeyID: "", wantErr: ErrEmptySubscriberID},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			switch err := validateParams(tc.subscriberID, tc.uniqueKeyID); {
			case err == nil:
				t.Fatalf("expected error %v but got nil", tc.wantErr)
			case err != tc.wantErr:
				t.Errorf("expected error %v, got %v", tc.wantErr, err)
			}
		})
	}
}
// TestLookupNPKeysSuccess checks that KeyMgr.LookupNPKeys returns the expected
// signing and encryption public keys. One case configures the cache mock's
// GetFunc with a JSON keyset and leaves the registry on its default response;
// the other configures an empty cache result and a registry LookupFunc that
// supplies the keys.
func TestLookupNPKeysSuccess(t *testing.T) {
	cases := []struct {
		name               string
		cacheGetFunc       func(ctx context.Context, key string) (string, error)
		registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
		expectedSigningPub string
		expectedEncrPub    string
	}{
		{
			name: "Cache hit with valid keys",
			cacheGetFunc: func(context.Context, string) (string, error) {
				return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil
			},
			registryLookupFunc: nil,
			expectedSigningPub: "mock-signing-public-key",
			expectedEncrPub:    "mock-encryption-public-key",
		},
		{
			name: "Cache miss and registry success",
			cacheGetFunc: func(context.Context, string) (string, error) {
				return "", nil
			},
			registryLookupFunc: func(_ context.Context, sub *model.Subscription) ([]model.Subscription, error) {
				found := model.Subscription{
					Subscriber:       model.Subscriber{SubscriberID: sub.SubscriberID},
					KeyID:            sub.KeyID,
					SigningPublicKey: "mock-signing-public-key",
					EncrPublicKey:    "mock-encryption-public-key",
				}
				return []model.Subscription{found}, nil
			},
			expectedSigningPub: "mock-signing-public-key",
			expectedEncrPub:    "mock-encryption-public-key",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Wire the per-case hooks into the mocks.
			cache := &mockCache{GetFunc: tc.cacheGetFunc}
			registry := &mockRegistry{LookupFunc: tc.registryLookupFunc}
			km := &KeyMgr{Cache: cache, Registry: registry}

			gotSigning, gotEncr, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
			if err != nil {
				t.Fatalf("LookupNPKeys() unexpected error: %v", err)
			}

			if gotSigning != tc.expectedSigningPub {
				t.Errorf("SigningPublic = %v, want %v", gotSigning, tc.expectedSigningPub)
			}
			if gotEncr != tc.expectedEncrPub {
				t.Errorf("EncrPublic = %v, want %v", gotEncr, tc.expectedEncrPub)
			}
		})
	}
}
// TestLookupNPKeysFailure checks that KeyMgr.LookupNPKeys returns an error
// containing the expected text when the registry lookup fails and when it
// returns no subscriptions.
func TestLookupNPKeysFailure(t *testing.T) {
	// emptyCache is the GetFunc shared by both cases: no value, no error.
	emptyCache := func(context.Context, string) (string, error) {
		return "", nil
	}

	cases := []struct {
		name               string
		cacheGetFunc       func(ctx context.Context, key string) (string, error)
		registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
		expectedError      string
	}{
		{
			name:         "Cache miss and registry failure",
			cacheGetFunc: emptyCache,
			registryLookupFunc: func(context.Context, *model.Subscription) ([]model.Subscription, error) {
				return nil, fmt.Errorf("registry down")
			},
			expectedError: "registry down",
		},
		{
			name:         "Cache miss and registry returns no subscriber",
			cacheGetFunc: emptyCache,
			registryLookupFunc: func(context.Context, *model.Subscription) ([]model.Subscription, error) {
				return nil, nil
			},
			expectedError: "no subscriber found with given credentials",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Wire the per-case hooks into the mocks.
			cache := &mockCache{GetFunc: tc.cacheGetFunc}
			registry := &mockRegistry{LookupFunc: tc.registryLookupFunc}
			km := &KeyMgr{Cache: cache, Registry: registry}

			_, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
			if err == nil {
				t.Fatalf("expected an error but got none")
			}
			if msg := err.Error(); !strings.Contains(msg, tc.expectedError) {
				t.Errorf("expected error to contain %v, got %v", tc.expectedError, msg)
			}
		})
	}
}