change in the keymanager interface

This commit is contained in:
MohitKatare-protean
2025-05-22 12:53:54 +05:30
parent 0eb0cc572f
commit ea872338f6
3 changed files with 154 additions and 276 deletions

View File

@@ -8,13 +8,11 @@ import (
// KeyManager defines the interface for key management operations/methods.
type KeyManager interface {
GenerateKeyPairs() (*model.Keyset, error)
StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error
SigningPrivateKey(ctx context.Context, keyID string) (string, string, error)
EncrPrivateKey(ctx context.Context, keyID string) (string, string, error)
SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
DeletePrivateKeys(ctx context.Context, keyID string) error
GenerateKeyset() (*model.Keyset, error)
InsertKeyset(ctx context.Context, keyID string, keyset *model.Keyset) error
Keyset(ctx context.Context, keyID string) (*model.Keyset, error)
LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error)
DeleteKeyset(ctx context.Context, keyID string) error
}
// KeyManagerProvider initializes a new signer instance.

View File

@@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"os"
"time"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/beckn/beckn-onix/pkg/model"
@@ -175,8 +174,8 @@ var (
uuidGenFunc = uuid.NewRandom
)
// GenerateKeyPairs generates a new signing (Ed25519) and encryption (X25519) key pair.
func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) {
// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair.
func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) {
signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate signing key pair: %w", err)
@@ -200,8 +199,8 @@ func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) {
}, nil
}
// StorePrivateKeys stores the given keyset in Vault under the specified key ID.
func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error {
// InsertKeyset stores the given keyset in Vault under the specified key ID.
func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error {
if keyID == "" {
return ErrEmptyKeyID
}
@@ -233,44 +232,8 @@ func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *mode
return nil
}
// SigningPrivateKey retrieves the unique key ID and signing private key for the given key ID.
func (km *KeyMgr) SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) {
keys, err := km.getKeys(ctx, keyID)
if err != nil {
return "", "", err
}
return keys.UniqueKeyID, keys.SigningPrivate, nil
}
// EncrPrivateKey retrieves the unique key ID and encryption private key for the given key ID.
func (km *KeyMgr) EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) {
keys, err := km.getKeys(ctx, keyID)
if err != nil {
return "", "", err
}
return keys.UniqueKeyID, keys.EncrPrivate, nil
}
// SigningPublicKey returns the signing public key for the given subscriber ID and key ID.
func (km *KeyMgr) SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) {
keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID)
if err != nil {
return "", err
}
return keys.SigningPublic, nil
}
// EncrPublicKey returns the encryption public key for the given subscriber ID and key ID.
func (km *KeyMgr) EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) {
keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID)
if err != nil {
return "", err
}
return keys.EncrPublic, nil
}
// DeletePrivateKeys deletes the private keys for the given key ID from Vault.
func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error {
// DeleteKeyset deletes the private keys for the given key ID from Vault.
func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error {
if keyID == "" {
return ErrEmptyKeyID
}
@@ -283,8 +246,8 @@ func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error {
return km.VaultClient.KVv2(path).Delete(ctx, keyID)
}
// getKeys retrieves the full keyset from Vault for the given key ID.
func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, error) {
// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry.
func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) {
if keyID == "" {
return nil, ErrEmptyKeyID
}
@@ -324,32 +287,16 @@ func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, err
}, nil
}
// getPublicKeys fetches the public keys from cache or registry for the given subscriber and key ID.
func (km *KeyMgr) getPublicKeys(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) {
if err := validateParams(subscriberID, uniqueKeyID); err != nil {
return nil, err
}
// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID.
func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) {
cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID)
cachedData, err := km.Cache.Get(ctx, cacheKey)
if err == nil {
var keys model.Keyset
if err := json.Unmarshal([]byte(cachedData), &keys); err == nil {
return &keys, nil
return keys.SigningPublic, keys.EncrPublic, nil
}
}
publicKeys, err := km.lookupRegistry(ctx, subscriberID, uniqueKeyID)
if err != nil {
return nil, err
}
cacheValue, err := json.Marshal(publicKeys)
if err == nil {
_ = km.Cache.Set(ctx, cacheKey, string(cacheValue), time.Hour)
}
return publicKeys, nil
}
// lookupRegistry queries the registry for public keys based on subscriber ID and key ID.
func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) {
subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{
Subscriber: model.Subscriber{
SubscriberID: subscriberID,
@@ -357,15 +304,12 @@ func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID
KeyID: uniqueKeyID,
})
if err != nil {
return nil, fmt.Errorf("failed to lookup registry: %w", err)
return "", "", fmt.Errorf("failed to lookup registry: %w", err)
}
if len(subscribers) == 0 {
return nil, ErrSubscriberNotFound
return "", "", ErrSubscriberNotFound
}
return &model.Keyset{
SigningPublic: subscribers[0].SigningPublicKey,
EncrPublic: subscribers[0].EncrPublicKey,
}, nil
return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil
}
// validateParams checks that subscriberID and uniqueKeyID are not empty.

View File

@@ -50,7 +50,9 @@ func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]m
}, nil
}
type mockCache struct{}
type mockCache struct {
GetFunc func(ctx context.Context, key string) (string, error)
}
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
return "", nil
@@ -193,7 +195,7 @@ func TestGenerateKeyPairs(t *testing.T) {
}
km := &KeyMgr{}
keyset, err := km.GenerateKeyPairs()
keyset, err := km.GenerateKeyset()
if tt.expectErr {
if err == nil {
@@ -589,7 +591,7 @@ func TestStorePrivateKeysSuccess(t *testing.T) {
KvVersion: tt.kvVersion,
}
err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys)
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
if err != nil {
t.Errorf("expected no error, got %v", err)
}
@@ -667,7 +669,7 @@ func TestStorePrivateKeysFailure(t *testing.T) {
KvVersion: tt.kvVersion,
}
err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys)
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
if err == nil {
t.Fatalf("expected error %q but got nil", tt.expectedErr)
@@ -717,7 +719,7 @@ func TestDeletePrivateKeys(t *testing.T) {
KvVersion: tt.kvVersion,
VaultClient: nil,
}
err := km.DeletePrivateKeys(context.Background(), tt.keyID)
err := km.DeleteKeyset(context.Background(), tt.keyID)
if err != tt.wantErr {
t.Errorf("expected error %v, got %v", tt.wantErr, err)
}
@@ -745,7 +747,7 @@ func TestDeletePrivateKeys(t *testing.T) {
VaultClient: vaultClient,
}
err = km.DeletePrivateKeys(context.Background(), tt.keyID)
err = km.DeleteKeyset(context.Background(), tt.keyID)
if err != tt.wantErr {
t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr)
}
@@ -809,7 +811,7 @@ func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *
return httptest.NewServer(handler)
}
func TestGetKeysSuccess(t *testing.T) {
func TestKeysetSuccess(t *testing.T) {
tests := []struct {
name string
kvVersion string
@@ -845,7 +847,7 @@ func TestGetKeysSuccess(t *testing.T) {
KvVersion: tt.kvVersion,
}
keys, err := km.getKeys(context.Background(), tt.keyID)
keys, err := km.Keyset(context.Background(), tt.keyID)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -862,7 +864,7 @@ func TestGetKeysSuccess(t *testing.T) {
}
}
func TestGetKeysFailure(t *testing.T) {
func TestKeysetFailure(t *testing.T) {
tests := []struct {
name string
kvVersion string
@@ -915,7 +917,7 @@ func TestGetKeysFailure(t *testing.T) {
KvVersion: tt.kvVersion,
}
keys, err := km.getKeys(context.Background(), tt.keyID)
keys, err := km.Keyset(context.Background(), tt.keyID)
if err == nil {
t.Fatalf("expected error but got nil")
}
@@ -926,195 +928,6 @@ func TestGetKeysFailure(t *testing.T) {
}
}
func TestGetPublicKeysSuccess(t *testing.T) {
km := &KeyMgr{
Cache: &mockCache{},
Registry: &mockRegistry{
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return []model.Subscription{
{
Subscriber: model.Subscriber{
SubscriberID: sub.SubscriberID,
},
KeyID: sub.KeyID,
SigningPublicKey: "mock-signing-public-key",
EncrPublicKey: "mock-encryption-public-key",
},
}, nil
},
},
}
got, err := km.getPublicKeys(context.Background(), "sub-id", "key-id")
if err != nil {
t.Fatalf("getPublicKeys() unexpected error: %v", err)
}
if got == nil {
t.Fatal("getPublicKeys() returned nil Keyset")
}
if got.SigningPublic != "mock-signing-public-key" {
t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, "mock-signing-public-key")
}
if got.EncrPublic != "mock-encryption-public-key" {
t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, "mock-encryption-public-key")
}
}
func TestGetPublicKeysFailure(t *testing.T) {
type fields struct {
cache definition.Cache
registry definition.RegistryLookup
}
type args struct {
subscriberID string
uniqueKeyID string
}
tests := []struct {
name string
fields fields
args args
errMessage string
}{
{
name: "failure - invalid parameters",
fields: fields{
cache: &mockCache{},
registry: &mockRegistry{},
},
args: args{
subscriberID: "",
uniqueKeyID: "",
},
errMessage: "invalid",
},
{
name: "failure - registry lookup fails",
fields: fields{
cache: &mockCache{},
registry: &mockRegistry{
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return nil, fmt.Errorf("registry down")
},
},
},
args: args{
subscriberID: "sub-id",
uniqueKeyID: "key-id",
},
errMessage: "registry down",
},
{
name: "failure - registry returns empty",
fields: fields{
cache: &mockCache{},
registry: &mockRegistry{
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return []model.Subscription{}, nil
},
},
},
args: args{
subscriberID: "sub-id",
uniqueKeyID: "key-id",
},
errMessage: "no subscriber found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
km := &KeyMgr{
Cache: tt.fields.cache,
Registry: tt.fields.registry,
}
got, err := km.getPublicKeys(context.Background(), tt.args.subscriberID, tt.args.uniqueKeyID)
if err == nil {
t.Errorf("getPublicKeys() expected error but got none, result: %v", got)
return
}
if !strings.Contains(err.Error(), tt.errMessage) {
t.Errorf("getPublicKeys() error = %v, want error message to contain %v", err.Error(), tt.errMessage)
}
})
}
}
func TestLookupRegistrySuccess(t *testing.T) {
km := &KeyMgr{
Registry: &mockRegistry{
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return []model.Subscription{
{
Subscriber: model.Subscriber{
SubscriberID: sub.SubscriberID,
},
KeyID: sub.KeyID,
SigningPublicKey: "signing-key",
EncrPublicKey: "encryption-key",
},
}, nil
},
},
}
got, err := km.lookupRegistry(context.Background(), "test-subscriber", "key123")
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
want := &model.Keyset{
SigningPublic: "signing-key",
EncrPublic: "encryption-key",
}
if got.SigningPublic != want.SigningPublic {
t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, want.SigningPublic)
}
if got.EncrPublic != want.EncrPublic {
t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, want.EncrPublic)
}
}
func TestLookupRegistryFailure(t *testing.T) {
tests := []struct {
name string
mockLookup func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
wantErr error
}{
{
name: "lookup error",
mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return nil, fmt.Errorf("registry failure")
},
wantErr: fmt.Errorf("registry failure"),
},
{
name: "no subscriber found",
mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return []model.Subscription{}, nil
},
wantErr: ErrSubscriberNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
km := &KeyMgr{
Registry: &mockRegistry{
LookupFunc: tt.mockLookup,
},
}
got, err := km.lookupRegistry(context.Background(), "some-id", "key-id")
if err == nil {
t.Fatalf("expected error, got none")
}
if got != nil {
t.Errorf("expected nil keyset, got %v", got)
}
})
}
}
func TestValidateParamsSuccess(t *testing.T) {
err := validateParams("someSubscriberID", "someUniqueKeyID")
if err != nil {
@@ -1161,3 +974,126 @@ func TestValidateParamsFailure(t *testing.T) {
})
}
}
func TestLookupNPKeysSuccess(t *testing.T) {
tests := []struct {
name string
cacheGetFunc func(ctx context.Context, key string) (string, error)
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
expectedSigningPub string
expectedEncrPub string
}{
{
name: "Cache hit with valid keys",
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil
},
registryLookupFunc: nil,
expectedSigningPub: "mock-signing-public-key",
expectedEncrPub: "mock-encryption-public-key",
},
{
name: "Cache miss and registry success",
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
return "", nil
},
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return []model.Subscription{
{
Subscriber: model.Subscriber{
SubscriberID: sub.SubscriberID,
},
KeyID: sub.KeyID,
SigningPublicKey: "mock-signing-public-key",
EncrPublicKey: "mock-encryption-public-key",
},
}, nil
},
expectedSigningPub: "mock-signing-public-key",
expectedEncrPub: "mock-encryption-public-key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up the KeyMgr with mocks
km := &KeyMgr{
Cache: &mockCache{
GetFunc: tt.cacheGetFunc,
},
Registry: &mockRegistry{
LookupFunc: tt.registryLookupFunc,
},
}
// Call the method
signingPublic, encrPublic, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
// Validate no errors in success cases
if err != nil {
t.Fatalf("LookupNPKeys() unexpected error: %v", err)
}
// Validate returned public keys
if signingPublic != tt.expectedSigningPub {
t.Errorf("SigningPublic = %v, want %v", signingPublic, tt.expectedSigningPub)
}
if encrPublic != tt.expectedEncrPub {
t.Errorf("EncrPublic = %v, want %v", encrPublic, tt.expectedEncrPub)
}
})
}
}
func TestLookupNPKeysFailure(t *testing.T) {
tests := []struct {
name string
cacheGetFunc func(ctx context.Context, key string) (string, error)
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
expectedError string
}{
{
name: "Cache miss and registry failure",
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
return "", nil
},
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return nil, fmt.Errorf("registry down")
},
expectedError: "registry down",
},
{
name: "Cache miss and registry returns no subscriber",
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
return "", nil
},
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
return nil, nil
},
expectedError: "no subscriber found with given credentials",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up the KeyMgr with mocks
km := &KeyMgr{
Cache: &mockCache{
GetFunc: tt.cacheGetFunc,
},
Registry: &mockRegistry{
LookupFunc: tt.registryLookupFunc,
},
}
_, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
if err == nil {
t.Fatalf("expected an error but got none")
}
if !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error to contain %v, got %v", tt.expectedError, err.Error())
}
})
}
}