1115 lines
27 KiB
Go
1115 lines
27 KiB
Go
package keymanager
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdh"
|
|
"crypto/ed25519"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/beckn-one/beckn-onix/pkg/model"
|
|
"github.com/beckn-one/beckn-onix/pkg/plugin/definition"
|
|
"github.com/google/uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
vault "github.com/hashicorp/vault/api"
|
|
)
|
|
|
|
type mockRegistry struct {
|
|
LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
|
}
|
|
|
|
func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
if m.LookupFunc != nil {
|
|
return m.LookupFunc(ctx, sub)
|
|
}
|
|
return []model.Subscription{
|
|
{
|
|
Subscriber: model.Subscriber{
|
|
SubscriberID: sub.SubscriberID,
|
|
URL: "https://mock.registry/subscriber",
|
|
Type: "BPP",
|
|
Domain: "retail",
|
|
},
|
|
KeyID: sub.KeyID,
|
|
SigningPublicKey: "mock-signing-public-key",
|
|
EncrPublicKey: "mock-encryption-public-key",
|
|
ValidFrom: time.Now().Add(-time.Hour),
|
|
ValidUntil: time.Now().Add(time.Hour),
|
|
Status: "SUBSCRIBED",
|
|
Created: time.Now().Add(-2 * time.Hour),
|
|
Updated: time.Now(),
|
|
Nonce: "mock-nonce",
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// mockCache is a cache double used by the tests. Only Get is configurable;
// Set, Clear and Delete are no-ops that always succeed.
type mockCache struct {
	// GetFunc, when non-nil, supplies the result of Get. When nil, Get
	// returns an empty value and no error.
	GetFunc func(ctx context.Context, key string) (string, error)
}

// Get delegates to GetFunc when it is set, so tests can script cache hits,
// misses and errors. Without GetFunc it returns "", nil.
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
	if m.GetFunc != nil {
		return m.GetFunc(ctx, key)
	}
	return "", nil
}

// Set is a no-op that always succeeds.
func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error {
	return nil
}

// Clear is a no-op that always succeeds.
func (m *mockCache) Clear(ctx context.Context) error {
	return nil
}

// Delete is a no-op that always succeeds.
func (m *mockCache) Delete(ctx context.Context, key string) error {
	return nil
}
|
|
|
|
func TestValidateCfgSuccess(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *Config
|
|
wantKV string
|
|
}{
|
|
{
|
|
name: "valid config with v1",
|
|
cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v1"},
|
|
wantKV: "v1",
|
|
},
|
|
{
|
|
name: "valid config with v2",
|
|
cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v2"},
|
|
wantKV: "v2",
|
|
},
|
|
{
|
|
name: "default KV version applied",
|
|
cfg: &Config{VaultAddr: "http://localhost:8200"},
|
|
wantKV: "v1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := ValidateCfg(tt.cfg)
|
|
if err != nil {
|
|
t.Errorf("expected no error, got %v", err)
|
|
}
|
|
if tt.cfg.KVVersion != tt.wantKV {
|
|
t.Errorf("expected KVVersion %s, got %s", tt.wantKV, tt.cfg.KVVersion)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateCfgFailure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *Config
|
|
}{
|
|
{
|
|
name: "missing Vault address",
|
|
cfg: &Config{VaultAddr: "", KVVersion: "v1"},
|
|
},
|
|
{
|
|
name: "invalid KV version",
|
|
cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v3"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := ValidateCfg(tt.cfg)
|
|
if err == nil {
|
|
t.Errorf("expected error, got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateKeyPairs(t *testing.T) {
|
|
originalEd25519 := ed25519KeyGenFunc
|
|
originalX25519 := x25519KeyGenFunc
|
|
originalUUID := uuidGenFunc
|
|
|
|
defer func() {
|
|
ed25519KeyGenFunc = originalEd25519
|
|
x25519KeyGenFunc = originalX25519
|
|
uuidGenFunc = originalUUID
|
|
}()
|
|
|
|
tests := []struct {
|
|
name string
|
|
mockEd25519Err error
|
|
mockX25519Err error
|
|
mockUUIDErr error
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "success case",
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "ed25519 key generation failure",
|
|
mockEd25519Err: errors.New("mock ed25519 failure"),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "x25519 key generation failure",
|
|
mockX25519Err: errors.New("mock x25519 failure"),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "UUID generation failure",
|
|
mockUUIDErr: errors.New("mock uuid failure"),
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.mockEd25519Err != nil {
|
|
ed25519KeyGenFunc = func(_ io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) {
|
|
return nil, nil, tt.mockEd25519Err
|
|
}
|
|
} else {
|
|
ed25519KeyGenFunc = ed25519.GenerateKey
|
|
}
|
|
|
|
if tt.mockX25519Err != nil {
|
|
x25519KeyGenFunc = func(_ io.Reader) (*ecdh.PrivateKey, error) {
|
|
return nil, tt.mockX25519Err
|
|
}
|
|
} else {
|
|
x25519KeyGenFunc = ecdh.X25519().GenerateKey
|
|
}
|
|
|
|
if tt.mockUUIDErr != nil {
|
|
uuidGenFunc = func() (uuid.UUID, error) {
|
|
return uuid.Nil, tt.mockUUIDErr
|
|
}
|
|
} else {
|
|
uuidGenFunc = uuid.NewRandom
|
|
}
|
|
|
|
km := &KeyMgr{}
|
|
keyset, err := km.GenerateKeyset()
|
|
|
|
if tt.expectErr {
|
|
if err == nil {
|
|
t.Errorf("expected error, got nil")
|
|
}
|
|
if keyset != nil {
|
|
t.Errorf("expected nil keyset, got non-nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
if keyset == nil {
|
|
t.Fatal("expected keyset, got nil")
|
|
}
|
|
if keyset.SigningPrivate == "" || keyset.SigningPublic == "" || keyset.EncrPrivate == "" || keyset.EncrPublic == "" {
|
|
t.Error("expected all keys to be populated and base64-encoded")
|
|
}
|
|
if keyset.UniqueKeyID == "" {
|
|
t.Error("expected UniqueKeyID to be non-empty")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetVaultClient_Failures(t *testing.T) {
|
|
originalNewVaultClient := NewVaultClient
|
|
defer func() { NewVaultClient = originalNewVaultClient }()
|
|
|
|
ctx := context.Background()
|
|
|
|
tests := []struct {
|
|
name string
|
|
roleID string
|
|
secretID string
|
|
setupServer func(t *testing.T) *httptest.Server
|
|
expectErr string
|
|
}{
|
|
{
|
|
name: "missing credentials",
|
|
roleID: "",
|
|
secretID: "",
|
|
expectErr: "VAULT_ROLE_ID or VAULT_SECRET_ID is not set",
|
|
},
|
|
{
|
|
name: "vault client creation failure",
|
|
roleID: "test-role",
|
|
secretID: "test-secret",
|
|
setupServer: func(t *testing.T) *httptest.Server {
|
|
NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
|
|
return nil, errors.New("mock client creation error")
|
|
}
|
|
return nil
|
|
},
|
|
expectErr: "failed to create Vault client: mock client creation error",
|
|
},
|
|
{
|
|
name: "AppRole login failure",
|
|
roleID: "test-role",
|
|
secretID: "test-secret",
|
|
setupServer: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "login failed", http.StatusBadRequest)
|
|
}))
|
|
},
|
|
expectErr: "failed to login with AppRole: Error making API request",
|
|
},
|
|
{
|
|
name: "AppRole login returns nil auth",
|
|
roleID: "test-role",
|
|
secretID: "test-secret",
|
|
setupServer: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if _, err := io.WriteString(w, `{ "auth": null }`); err != nil {
|
|
t.Fatalf("failed to write response: %v", err)
|
|
}
|
|
}))
|
|
},
|
|
expectErr: "AppRole login failed: no auth info returned",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
os.Setenv("VAULT_ROLE_ID", tt.roleID)
|
|
os.Setenv("VAULT_SECRET_ID", tt.secretID)
|
|
|
|
var server *httptest.Server
|
|
if tt.setupServer != nil {
|
|
server = tt.setupServer(t)
|
|
if server != nil {
|
|
NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
|
|
cfg.Address = server.URL
|
|
return vault.NewClient(cfg)
|
|
}
|
|
defer server.Close()
|
|
}
|
|
}
|
|
|
|
client, err := GetVaultClient(ctx, "http://ignored")
|
|
if err == nil || !strings.Contains(err.Error(), tt.expectErr) {
|
|
t.Errorf("expected error to contain '%s', got: %v", tt.expectErr, err)
|
|
}
|
|
if client != nil {
|
|
t.Error("expected client to be nil on failure")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetVaultClient_Success(t *testing.T) {
|
|
originalNewVaultClient := NewVaultClient
|
|
defer func() { NewVaultClient = originalNewVaultClient }()
|
|
|
|
ctx := context.Background()
|
|
|
|
os.Setenv("VAULT_ROLE_ID", "test-role")
|
|
os.Setenv("VAULT_SECRET_ID", "test-secret")
|
|
|
|
// Mock Vault server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !strings.HasSuffix(r.URL.Path, "/auth/approle/login") {
|
|
t.Errorf("unexpected request path: %s", r.URL.Path)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if _, err := io.WriteString(w, `{
|
|
"auth": {
|
|
"client_token": "mock-token"
|
|
}
|
|
}`); err != nil {
|
|
t.Fatalf("failed to write response: %v", err)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) {
|
|
cfg.Address = server.URL
|
|
return vault.NewClient(cfg)
|
|
}
|
|
|
|
client, err := GetVaultClient(ctx, "http://ignored")
|
|
if err != nil {
|
|
t.Fatalf("expected success, got error: %v", err)
|
|
}
|
|
if client == nil {
|
|
t.Fatal("expected non-nil client")
|
|
}
|
|
if token := client.Token(); token != "mock-token" {
|
|
t.Errorf("expected token to be 'mock-token', got: %s", token)
|
|
}
|
|
}
|
|
|
|
type mockRegistryLookup struct{}
|
|
|
|
func (m *mockRegistryLookup) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
return []model.Subscription{
|
|
{
|
|
Subscriber: model.Subscriber{
|
|
SubscriberID: sub.SubscriberID,
|
|
Type: sub.Type,
|
|
},
|
|
KeyID: "mock-key-id",
|
|
SigningPublicKey: "mock-signing-pubkey",
|
|
EncrPublicKey: "mock-encryption-pubkey",
|
|
ValidFrom: time.Now().Add(-time.Hour),
|
|
ValidUntil: time.Now().Add(time.Hour),
|
|
Status: "SUBSCRIBED",
|
|
Created: time.Now(),
|
|
Updated: time.Now(),
|
|
Nonce: "mock-nonce",
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func TestNewSuccess(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *Config
|
|
cache definition.Cache
|
|
registry definition.RegistryLookup
|
|
mockVaultStatus int
|
|
mockVaultBody string
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
cfg: &Config{
|
|
VaultAddr: "http://dummy",
|
|
KVVersion: "v2",
|
|
},
|
|
cache: &mockCache{},
|
|
registry: &mockRegistryLookup{},
|
|
mockVaultStatus: http.StatusOK,
|
|
mockVaultBody: `{}`,
|
|
},
|
|
}
|
|
|
|
originalGetVaultClient := getVaultClient
|
|
defer func() { getVaultClient = originalGetVaultClient }()
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(tt.mockVaultStatus)
|
|
fmt.Fprint(w, tt.mockVaultBody)
|
|
}))
|
|
defer vaultServer.Close()
|
|
|
|
tt.cfg.VaultAddr = vaultServer.URL
|
|
|
|
getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) {
|
|
cfg := vault.DefaultConfig()
|
|
cfg.Address = addr
|
|
return vault.NewClient(cfg)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg)
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if km == nil {
|
|
t.Fatalf("expected KeyMgr instance, got nil")
|
|
}
|
|
if cleanup == nil {
|
|
t.Fatalf("expected cleanup function, got nil")
|
|
}
|
|
_ = cleanup()
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewFailure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *Config
|
|
cache definition.Cache
|
|
registry definition.RegistryLookup
|
|
mockVaultStatus int
|
|
mockVaultBody string
|
|
}{
|
|
{
|
|
name: "nil cache",
|
|
cfg: &Config{
|
|
VaultAddr: "http://dummy",
|
|
KVVersion: "v2",
|
|
},
|
|
cache: nil,
|
|
registry: &mockRegistryLookup{},
|
|
mockVaultStatus: http.StatusOK,
|
|
mockVaultBody: `{}`,
|
|
},
|
|
{
|
|
name: "nil registry",
|
|
cfg: &Config{
|
|
VaultAddr: "http://dummy",
|
|
KVVersion: "v2",
|
|
},
|
|
cache: &mockCache{},
|
|
registry: nil,
|
|
mockVaultStatus: http.StatusOK,
|
|
mockVaultBody: `{}`,
|
|
},
|
|
{
|
|
name: "invalid config",
|
|
cfg: &Config{
|
|
VaultAddr: "", // Invalid
|
|
KVVersion: "v3", // Unsupported
|
|
},
|
|
cache: &mockCache{},
|
|
registry: &mockRegistryLookup{},
|
|
mockVaultStatus: http.StatusOK,
|
|
mockVaultBody: `{}`,
|
|
},
|
|
{
|
|
name: "vault client creation failure",
|
|
cfg: &Config{
|
|
VaultAddr: "http://dummy",
|
|
KVVersion: "v2",
|
|
},
|
|
cache: &mockCache{},
|
|
registry: &mockRegistryLookup{},
|
|
mockVaultStatus: http.StatusOK,
|
|
mockVaultBody: `{}`,
|
|
},
|
|
}
|
|
|
|
originalGetVaultClient := getVaultClient
|
|
defer func() { getVaultClient = originalGetVaultClient }()
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(tt.mockVaultStatus)
|
|
fmt.Fprint(w, tt.mockVaultBody)
|
|
}))
|
|
defer vaultServer.Close()
|
|
|
|
if tt.cfg != nil {
|
|
tt.cfg.VaultAddr = vaultServer.URL
|
|
}
|
|
|
|
getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) {
|
|
if tt.name == "vault client creation failure" {
|
|
return nil, errors.New("simulated vault client creation error")
|
|
}
|
|
cfg := vault.DefaultConfig()
|
|
cfg.Address = addr
|
|
return vault.NewClient(cfg)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg)
|
|
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
if km != nil {
|
|
t.Error("expected KeyMgr to be nil, got non-nil")
|
|
}
|
|
if cleanup != nil {
|
|
t.Error("expected cleanup to be nil, got non-nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestStorePrivateKeysSuccess(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
keys := &model.Keyset{
|
|
UniqueKeyID: "uuid",
|
|
SigningPublic: "signPub",
|
|
SigningPrivate: "signPriv",
|
|
EncrPublic: "encrPub",
|
|
EncrPrivate: "encrPriv",
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
kvVersion string
|
|
keyID string
|
|
keys *model.Keyset
|
|
}{
|
|
{
|
|
name: "success kv v1",
|
|
kvVersion: "v1",
|
|
keyID: "mykeyid",
|
|
keys: keys,
|
|
},
|
|
{
|
|
name: "success kv v2",
|
|
kvVersion: "v2",
|
|
keyID: "mykeyid",
|
|
keys: keys,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
expectedPath := ""
|
|
if tt.kvVersion == "v2" {
|
|
expectedPath = "/v1/secret/data/keys/" + tt.keyID
|
|
} else {
|
|
expectedPath = "/v1/secret/keys/" + tt.keyID
|
|
}
|
|
|
|
if r.URL.Path != expectedPath {
|
|
t.Errorf("unexpected request path: got %s, want %s", r.URL.Path, expectedPath)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(200)
|
|
fmt.Fprintln(w, `{"data":{}}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
config := api.DefaultConfig()
|
|
config.Address = server.URL
|
|
client, err := api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("failed to create Vault client: %v", err)
|
|
}
|
|
|
|
km := &KeyMgr{
|
|
VaultClient: client,
|
|
KvVersion: tt.kvVersion,
|
|
}
|
|
|
|
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
|
|
if err != nil {
|
|
t.Errorf("expected no error, got %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStorePrivateKeysFailure(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
keys := &model.Keyset{
|
|
UniqueKeyID: "uuid",
|
|
SigningPublic: "signPub",
|
|
SigningPrivate: "signPriv",
|
|
EncrPublic: "encrPub",
|
|
EncrPrivate: "encrPriv",
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
kvVersion string
|
|
keyID string
|
|
keys *model.Keyset
|
|
statusCode int // for HTTP error simulation
|
|
expectedErr string
|
|
}{
|
|
{
|
|
name: "empty keyID",
|
|
keyID: "",
|
|
keys: keys,
|
|
expectedErr: ErrEmptyKeyID.Error(),
|
|
},
|
|
{
|
|
name: "nil keys",
|
|
keyID: "mykeyid",
|
|
keys: nil,
|
|
expectedErr: ErrNilKeySet.Error(),
|
|
},
|
|
{
|
|
name: "vault write error",
|
|
kvVersion: "v1",
|
|
keyID: "mykeyid",
|
|
keys: keys,
|
|
statusCode: 500,
|
|
expectedErr: "failed to store secret in Vault at path secret/keys/mykeyid: Error making API request.",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var server *httptest.Server
|
|
if tt.statusCode != 0 {
|
|
// Setup test HTTP server to simulate Vault error
|
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
http.Error(w, "internal error", tt.statusCode)
|
|
}))
|
|
defer server.Close()
|
|
}
|
|
|
|
var client *api.Client
|
|
var err error
|
|
if server != nil {
|
|
config := api.DefaultConfig()
|
|
config.Address = server.URL
|
|
client, err = api.NewClient(config)
|
|
if err != nil {
|
|
t.Fatalf("failed to create Vault client: %v", err)
|
|
}
|
|
} else {
|
|
client = nil
|
|
}
|
|
|
|
km := &KeyMgr{
|
|
VaultClient: client,
|
|
KvVersion: tt.kvVersion,
|
|
}
|
|
|
|
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
|
|
|
|
if err == nil {
|
|
t.Fatalf("expected error %q but got nil", tt.expectedErr)
|
|
}
|
|
if !strings.Contains(err.Error(), tt.expectedErr) {
|
|
t.Errorf("expected error containing %q, got %v", tt.expectedErr, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDeletePrivateKeys(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
kvVersion string
|
|
keyID string
|
|
wantPath string
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "empty keyID",
|
|
kvVersion: "v1",
|
|
keyID: "",
|
|
wantErr: ErrEmptyKeyID,
|
|
},
|
|
{
|
|
name: "v1 delete",
|
|
kvVersion: "v1",
|
|
keyID: "key123",
|
|
wantPath: "/v1/secret/keys/key123/data/key123",
|
|
wantErr: nil,
|
|
},
|
|
{
|
|
name: "v2 delete",
|
|
kvVersion: "v2",
|
|
keyID: "key123",
|
|
wantPath: "/v1/secret/data/keys/key123/data/key123",
|
|
wantErr: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// If empty keyID, no Vault calls, just check error
|
|
if tt.keyID == "" {
|
|
km := &KeyMgr{
|
|
KvVersion: tt.kvVersion,
|
|
VaultClient: nil,
|
|
}
|
|
err := km.DeleteKeyset(context.Background(), tt.keyID)
|
|
if err != tt.wantErr {
|
|
t.Errorf("expected error %v, got %v", tt.wantErr, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodDelete {
|
|
t.Errorf("Expected DELETE method, got %s", r.Method)
|
|
}
|
|
if r.URL.Path != tt.wantPath {
|
|
t.Errorf("Expected path %s, got %s", tt.wantPath, r.URL.Path)
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
vaultClient, err := NewVaultClient(&vault.Config{Address: ts.URL})
|
|
if err != nil {
|
|
t.Fatalf("failed to create vault client: %v", err)
|
|
}
|
|
|
|
km := &KeyMgr{
|
|
KvVersion: tt.kvVersion,
|
|
VaultClient: vaultClient,
|
|
}
|
|
|
|
err = km.DeleteKeyset(context.Background(), tt.keyID)
|
|
if err != tt.wantErr {
|
|
t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// setupMockVaultServer starts an httptest server that mimics a Vault read of
// the keyset stored for keyID.
//
// For kvVersion "v2" it serves /v1/secret/data/keys/<keyID> and nests the
// secret under data.data alongside version metadata. For any other version it
// serves /v1/secret/keys/<keyID> with the secret directly under data.
//
// Requests to any other path get 404. When success is false, the expected path
// also answers 404. The caller must Close the returned server.
func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *httptest.Server {
	t.Helper()

	expectedPath := fmt.Sprintf("/v1/secret/keys/%s", keyID)
	if kvVersion == "v2" {
		expectedPath = fmt.Sprintf("/v1/secret/data/keys/%s", keyID)
	}

	keyData := fmt.Sprintf(`{
		"uniqueKeyID": "%s",
		"signingPublicKey": "sign-pub",
		"signingPrivateKey": "sign-priv",
		"encrPublicKey": "encr-pub",
		"encrPrivateKey": "encr-priv"
	}`, keyID)

	data := keyData
	if kvVersion == "v2" {
		data = fmt.Sprintf(`{
		"data": %s,
		"metadata": {
			"created_time": "2025-05-28T00:00:00Z",
			"deletion_time": "",
			"destroyed": false,
			"version": 1
		}
	}`, keyData)
	}

	// The response body does not depend on the request, so build it once.
	resp := fmt.Sprintf(`{
	"request_id": "req-1234",
	"lease_id": "",
	"renewable": false,
	"lease_duration": 0,
	"data": %s,
	"warnings": null,
	"auth": null
}`, data)

	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != expectedPath {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}

		if !success {
			http.Error(w, `{"errors":["key not found"]}`, http.StatusNotFound)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		// Handlers run on the server's goroutine, where t.Fatalf must not be
		// called; report write failures with t.Errorf instead.
		if _, err := w.Write([]byte(resp)); err != nil {
			t.Errorf("failed to write response: %v", err)
		}
	}))
}
|
|
|
|
func TestKeysetSuccess(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
kvVersion string
|
|
keyID string
|
|
}{
|
|
{
|
|
name: "success with KV v2",
|
|
kvVersion: "v2",
|
|
keyID: "test-key-v2",
|
|
},
|
|
{
|
|
name: "success with KV v1",
|
|
kvVersion: "v1",
|
|
keyID: "test-key-v1",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ts := setupMockVaultServer(t, tt.kvVersion, tt.keyID, true)
|
|
defer ts.Close()
|
|
|
|
cfg := vault.DefaultConfig()
|
|
cfg.Address = ts.URL
|
|
|
|
client, err := vault.NewClient(cfg)
|
|
if err != nil {
|
|
t.Fatalf("failed to create Vault client: %v", err)
|
|
}
|
|
|
|
km := &KeyMgr{
|
|
VaultClient: client,
|
|
KvVersion: tt.kvVersion,
|
|
}
|
|
|
|
keys, err := km.Keyset(context.Background(), tt.keyID)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if keys == nil {
|
|
t.Fatalf("expected keys but got nil")
|
|
}
|
|
if keys.UniqueKeyID != tt.keyID {
|
|
t.Errorf("expected UniqueKeyID %q, got %q", tt.keyID, keys.UniqueKeyID)
|
|
}
|
|
if keys.SigningPrivate != "sign-priv" {
|
|
t.Errorf("expected SigningPrivate 'sign-priv', got %q", keys.SigningPrivate)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKeysetFailure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
kvVersion string
|
|
keyID string
|
|
success bool
|
|
}{
|
|
{
|
|
name: "failure: vault returns 404 v2",
|
|
kvVersion: "v2",
|
|
keyID: "missing-key-v2",
|
|
success: false,
|
|
},
|
|
{
|
|
name: "failure: vault returns 404 v1",
|
|
kvVersion: "v1",
|
|
keyID: "missing-key-v1",
|
|
success: false,
|
|
},
|
|
{
|
|
name: "failure: empty keyID",
|
|
kvVersion: "v2",
|
|
keyID: "",
|
|
success: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var ts *httptest.Server
|
|
if tt.keyID != "" {
|
|
ts = setupMockVaultServer(t, tt.kvVersion, tt.keyID, tt.success)
|
|
defer ts.Close()
|
|
}
|
|
|
|
cfg := vault.DefaultConfig()
|
|
if ts != nil {
|
|
cfg.Address = ts.URL
|
|
} else {
|
|
// For empty keyID case or no mock server, use invalid URL to force error
|
|
cfg.Address = "http://invalid"
|
|
}
|
|
|
|
client, err := vault.NewClient(cfg)
|
|
if err != nil {
|
|
t.Fatalf("failed to create Vault client: %v", err)
|
|
}
|
|
|
|
km := &KeyMgr{
|
|
VaultClient: client,
|
|
KvVersion: tt.kvVersion,
|
|
}
|
|
|
|
keys, err := km.Keyset(context.Background(), tt.keyID)
|
|
if err == nil {
|
|
t.Fatalf("expected error but got nil")
|
|
}
|
|
if keys != nil {
|
|
t.Fatalf("expected nil keys but got %+v", keys)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidateParamsSuccess(t *testing.T) {
|
|
err := validateParams("someSubscriberID", "someUniqueKeyID")
|
|
if err != nil {
|
|
t.Fatalf("expected no error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestValidateParamsFailure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
subscriberID string
|
|
uniqueKeyID string
|
|
wantErr error
|
|
}{
|
|
{
|
|
name: "empty subscriberID",
|
|
subscriberID: "",
|
|
uniqueKeyID: "validKeyID",
|
|
wantErr: ErrEmptySubscriberID,
|
|
},
|
|
{
|
|
name: "empty uniqueKeyID",
|
|
subscriberID: "validSubscriberID",
|
|
uniqueKeyID: "",
|
|
wantErr: ErrEmptyUniqueKeyID,
|
|
},
|
|
{
|
|
name: "both empty",
|
|
subscriberID: "",
|
|
uniqueKeyID: "",
|
|
wantErr: ErrEmptySubscriberID,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateParams(tt.subscriberID, tt.uniqueKeyID)
|
|
if err == nil {
|
|
t.Fatalf("expected error %v but got nil", tt.wantErr)
|
|
}
|
|
if err != tt.wantErr {
|
|
t.Errorf("expected error %v, got %v", tt.wantErr, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLookupNPKeysSuccess(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cacheGetFunc func(ctx context.Context, key string) (string, error)
|
|
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
|
expectedSigningPub string
|
|
expectedEncrPub string
|
|
}{
|
|
{
|
|
name: "Cache hit with valid keys",
|
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
|
return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil
|
|
},
|
|
registryLookupFunc: nil,
|
|
expectedSigningPub: "mock-signing-public-key",
|
|
expectedEncrPub: "mock-encryption-public-key",
|
|
},
|
|
{
|
|
name: "Cache miss and registry success",
|
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
|
|
|
return "", nil
|
|
},
|
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
return []model.Subscription{
|
|
{
|
|
Subscriber: model.Subscriber{
|
|
SubscriberID: sub.SubscriberID,
|
|
},
|
|
KeyID: sub.KeyID,
|
|
SigningPublicKey: "mock-signing-public-key",
|
|
EncrPublicKey: "mock-encryption-public-key",
|
|
},
|
|
}, nil
|
|
},
|
|
expectedSigningPub: "mock-signing-public-key",
|
|
expectedEncrPub: "mock-encryption-public-key",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Set up the KeyMgr with mocks
|
|
km := &KeyMgr{
|
|
Cache: &mockCache{
|
|
GetFunc: tt.cacheGetFunc,
|
|
},
|
|
Registry: &mockRegistry{
|
|
LookupFunc: tt.registryLookupFunc,
|
|
},
|
|
}
|
|
|
|
// Call the method
|
|
signingPublic, encrPublic, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
|
|
|
|
// Validate no errors in success cases
|
|
if err != nil {
|
|
t.Fatalf("LookupNPKeys() unexpected error: %v", err)
|
|
}
|
|
|
|
// Validate returned public keys
|
|
if signingPublic != tt.expectedSigningPub {
|
|
t.Errorf("SigningPublic = %v, want %v", signingPublic, tt.expectedSigningPub)
|
|
}
|
|
if encrPublic != tt.expectedEncrPub {
|
|
t.Errorf("EncrPublic = %v, want %v", encrPublic, tt.expectedEncrPub)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLookupNPKeysFailure(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cacheGetFunc func(ctx context.Context, key string) (string, error)
|
|
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "Cache miss and registry failure",
|
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
|
return "", nil
|
|
},
|
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
return nil, fmt.Errorf("registry down")
|
|
},
|
|
expectedError: "registry down",
|
|
},
|
|
{
|
|
name: "Cache miss and registry returns no subscriber",
|
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
|
return "", nil
|
|
},
|
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
return nil, nil
|
|
},
|
|
expectedError: "no subscriber found with given credentials",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Set up the KeyMgr with mocks
|
|
km := &KeyMgr{
|
|
Cache: &mockCache{
|
|
GetFunc: tt.cacheGetFunc,
|
|
},
|
|
Registry: &mockRegistry{
|
|
LookupFunc: tt.registryLookupFunc,
|
|
},
|
|
}
|
|
_, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
|
|
if err == nil {
|
|
t.Fatalf("expected an error but got none")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), tt.expectedError) {
|
|
t.Errorf("expected error to contain %v, got %v", tt.expectedError, err.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|