diff --git a/pkg/plugin/definition/keymanager.go b/pkg/plugin/definition/keymanager.go index f2c0e2f..749ad0a 100644 --- a/pkg/plugin/definition/keymanager.go +++ b/pkg/plugin/definition/keymanager.go @@ -8,13 +8,11 @@ import ( // KeyManager defines the interface for key management operations/methods. type KeyManager interface { - GenerateKeyPairs() (*model.Keyset, error) - StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error - SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) - EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) - SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) - EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) - DeletePrivateKeys(ctx context.Context, keyID string) error + GenerateKeyset() (*model.Keyset, error) + InsertKeyset(ctx context.Context, keyID string, keyset *model.Keyset) error + Keyset(ctx context.Context, keyID string) (*model.Keyset, error) + LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) + DeleteKeyset(ctx context.Context, keyID string) error } // KeyManagerProvider initializes a new signer instance. diff --git a/pkg/plugin/implementation/keymanager/keymanager.go b/pkg/plugin/implementation/keymanager/keymanager.go index c50079d..c79b8c1 100644 --- a/pkg/plugin/implementation/keymanager/keymanager.go +++ b/pkg/plugin/implementation/keymanager/keymanager.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "os" - "time" "github.com/beckn/beckn-onix/pkg/log" "github.com/beckn/beckn-onix/pkg/model" @@ -175,8 +174,8 @@ var ( uuidGenFunc = uuid.NewRandom ) -// GenerateKeyPairs generates a new signing (Ed25519) and encryption (X25519) key pair. -func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) { +// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair. +func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) { signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader) if err != nil { return nil, fmt.Errorf("failed to generate signing key pair: %w", err) @@ -200,8 +199,8 @@ func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) { }, nil } -// StorePrivateKeys stores the given keyset in Vault under the specified key ID. -func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error { +// InsertKeyset stores the given keyset in Vault under the specified key ID. +func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error { if keyID == "" { return ErrEmptyKeyID } @@ -233,44 +232,8 @@ func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *mode return nil } -// SigningPrivateKey retrieves the unique key ID and signing private key for the given key ID. -func (km *KeyMgr) SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) { - keys, err := km.getKeys(ctx, keyID) - if err != nil { - return "", "", err - } - return keys.UniqueKeyID, keys.SigningPrivate, nil -} - -// EncrPrivateKey retrieves the unique key ID and encryption private key for the given key ID. -func (km *KeyMgr) EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) { - keys, err := km.getKeys(ctx, keyID) - if err != nil { - return "", "", err - } - return keys.UniqueKeyID, keys.EncrPrivate, nil -} - -// SigningPublicKey returns the signing public key for the given subscriber ID and key ID. -func (km *KeyMgr) SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) { - keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID) - if err != nil { - return "", err - } - return keys.SigningPublic, nil -} - -// EncrPublicKey returns the encryption public key for the given subscriber ID and key ID. -func (km *KeyMgr) EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) { - keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID) - if err != nil { - return "", err - } - return keys.EncrPublic, nil -} - -// DeletePrivateKeys deletes the private keys for the given key ID from Vault. -func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error { +// DeleteKeyset deletes the private keys for the given key ID from Vault. +func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error { if keyID == "" { return ErrEmptyKeyID } @@ -283,8 +246,8 @@ func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error { return km.VaultClient.KVv2(path).Delete(ctx, keyID) } -// getKeys retrieves the full keyset from Vault for the given key ID. -func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, error) { +// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry. +func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) { if keyID == "" { return nil, ErrEmptyKeyID } @@ -324,32 +287,16 @@ func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, err }, nil } -// getPublicKeys fetches the public keys from cache or registry for the given subscriber and key ID. -func (km *KeyMgr) getPublicKeys(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) { - if err := validateParams(subscriberID, uniqueKeyID); err != nil { - return nil, err - } +// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID. +func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) { cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID) cachedData, err := km.Cache.Get(ctx, cacheKey) if err == nil { var keys model.Keyset if err := json.Unmarshal([]byte(cachedData), &keys); err == nil { - return &keys, nil + return keys.SigningPublic, keys.EncrPublic, nil } } - publicKeys, err := km.lookupRegistry(ctx, subscriberID, uniqueKeyID) - if err != nil { - return nil, err - } - cacheValue, err := json.Marshal(publicKeys) - if err == nil { - _ = km.Cache.Set(ctx, cacheKey, string(cacheValue), time.Hour) - } - return publicKeys, nil -} - -// lookupRegistry queries the registry for public keys based on subscriber ID and key ID. -func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) { subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{ Subscriber: model.Subscriber{ SubscriberID: subscriberID, @@ -357,15 +304,12 @@ func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID KeyID: uniqueKeyID, }) if err != nil { - return nil, fmt.Errorf("failed to lookup registry: %w", err) + return "", "", fmt.Errorf("failed to lookup registry: %w", err) } if len(subscribers) == 0 { - return nil, ErrSubscriberNotFound + return "", "", ErrSubscriberNotFound } - return &model.Keyset{ - SigningPublic: subscribers[0].SigningPublicKey, - EncrPublic: subscribers[0].EncrPublicKey, - }, nil + return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil } // validateParams checks that subscriberID and uniqueKeyID are not empty. diff --git a/pkg/plugin/implementation/keymanager/keymanager_test.go b/pkg/plugin/implementation/keymanager/keymanager_test.go index c66b2fe..47d0fa2 100644 --- a/pkg/plugin/implementation/keymanager/keymanager_test.go +++ b/pkg/plugin/implementation/keymanager/keymanager_test.go @@ -50,7 +50,9 @@ func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]m }, nil } -type mockCache struct{} +type mockCache struct { + GetFunc func(ctx context.Context, key string) (string, error) +} func (m *mockCache) Get(ctx context.Context, key string) (string, error) { return "", nil @@ -193,7 +195,7 @@ func TestGenerateKeyPairs(t *testing.T) { } km := &KeyMgr{} - keyset, err := km.GenerateKeyPairs() + keyset, err := km.GenerateKeyset() if tt.expectErr { if err == nil { @@ -589,7 +591,7 @@ func TestStorePrivateKeysSuccess(t *testing.T) { KvVersion: tt.kvVersion, } - err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys) + err = km.InsertKeyset(ctx, tt.keyID, tt.keys) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -667,7 +669,7 @@ func TestStorePrivateKeysFailure(t *testing.T) { KvVersion: tt.kvVersion, } - err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys) + err = km.InsertKeyset(ctx, tt.keyID, tt.keys) if err == nil { t.Fatalf("expected error %q but got nil", tt.expectedErr) @@ -717,7 +719,7 @@ func TestDeletePrivateKeys(t *testing.T) { KvVersion: tt.kvVersion, VaultClient: nil, } - err := km.DeletePrivateKeys(context.Background(), tt.keyID) + err := km.DeleteKeyset(context.Background(), tt.keyID) if err != tt.wantErr { t.Errorf("expected error %v, got %v", tt.wantErr, err) } @@ -745,7 +747,7 @@ func TestDeletePrivateKeys(t *testing.T) { VaultClient: vaultClient, } - err = km.DeletePrivateKeys(context.Background(), tt.keyID) + err = km.DeleteKeyset(context.Background(), tt.keyID) if err != tt.wantErr { t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr) } @@ -809,7 +811,7 @@ func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) * return httptest.NewServer(handler) } -func TestGetKeysSuccess(t *testing.T) { +func TestKeysetSuccess(t *testing.T) { tests := []struct { name string kvVersion string @@ -845,7 +847,7 @@ func TestGetKeysSuccess(t *testing.T) { KvVersion: tt.kvVersion, } - keys, err := km.getKeys(context.Background(), tt.keyID) + keys, err := km.Keyset(context.Background(), tt.keyID) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -862,7 +864,7 @@ func TestGetKeysSuccess(t *testing.T) { } } -func TestGetKeysFailure(t *testing.T) { +func TestKeysetFailure(t *testing.T) { tests := []struct { name string kvVersion string @@ -915,7 +917,7 @@ func TestGetKeysFailure(t *testing.T) { KvVersion: tt.kvVersion, } - keys, err := km.getKeys(context.Background(), tt.keyID) + keys, err := km.Keyset(context.Background(), tt.keyID) if err == nil { t.Fatalf("expected error but got nil") } @@ -926,195 +928,6 @@ func TestGetKeysFailure(t *testing.T) { } } -func TestGetPublicKeysSuccess(t *testing.T) { - km := &KeyMgr{ - Cache: &mockCache{}, - Registry: &mockRegistry{ - LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return []model.Subscription{ - { - Subscriber: model.Subscriber{ - SubscriberID: sub.SubscriberID, - }, - KeyID: sub.KeyID, - SigningPublicKey: "mock-signing-public-key", - EncrPublicKey: "mock-encryption-public-key", - }, - }, nil - }, - }, - } - - got, err := km.getPublicKeys(context.Background(), "sub-id", "key-id") - if err != nil { - t.Fatalf("getPublicKeys() unexpected error: %v", err) - } - if got == nil { - t.Fatal("getPublicKeys() returned nil Keyset") - } - if got.SigningPublic != "mock-signing-public-key" { - t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, "mock-signing-public-key") - } - if got.EncrPublic != "mock-encryption-public-key" { - t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, "mock-encryption-public-key") - } -} - -func TestGetPublicKeysFailure(t *testing.T) { - type fields struct { - cache definition.Cache - registry definition.RegistryLookup - } - type args struct { - subscriberID string - uniqueKeyID string - } - tests := []struct { - name string - fields fields - args args - errMessage string - }{ - { - name: "failure - invalid parameters", - fields: fields{ - cache: &mockCache{}, - registry: &mockRegistry{}, - }, - args: args{ - subscriberID: "", - uniqueKeyID: "", - }, - errMessage: "invalid", - }, - { - name: "failure - registry lookup fails", - fields: fields{ - cache: &mockCache{}, - registry: &mockRegistry{ - LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return nil, fmt.Errorf("registry down") - }, - }, - }, - args: args{ - subscriberID: "sub-id", - uniqueKeyID: "key-id", - }, - errMessage: "registry down", - }, - { - name: "failure - registry returns empty", - fields: fields{ - cache: &mockCache{}, - registry: &mockRegistry{ - LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return []model.Subscription{}, nil - }, - }, - }, - args: args{ - subscriberID: "sub-id", - uniqueKeyID: "key-id", - }, - errMessage: "no subscriber found", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - km := &KeyMgr{ - Cache: tt.fields.cache, - Registry: tt.fields.registry, - } - got, err := km.getPublicKeys(context.Background(), tt.args.subscriberID, tt.args.uniqueKeyID) - if err == nil { - t.Errorf("getPublicKeys() expected error but got none, result: %v", got) - return - } - if !strings.Contains(err.Error(), tt.errMessage) { - t.Errorf("getPublicKeys() error = %v, want error message to contain %v", err.Error(), tt.errMessage) - } - }) - } -} - -func TestLookupRegistrySuccess(t *testing.T) { - km := &KeyMgr{ - Registry: &mockRegistry{ - LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return []model.Subscription{ - { - Subscriber: model.Subscriber{ - SubscriberID: sub.SubscriberID, - }, - KeyID: sub.KeyID, - SigningPublicKey: "signing-key", - EncrPublicKey: "encryption-key", - }, - }, nil - }, - }, - } - - got, err := km.lookupRegistry(context.Background(), "test-subscriber", "key123") - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - want := &model.Keyset{ - SigningPublic: "signing-key", - EncrPublic: "encryption-key", - } - - if got.SigningPublic != want.SigningPublic { - t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, want.SigningPublic) - } - if got.EncrPublic != want.EncrPublic { - t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, want.EncrPublic) - } -} - -func TestLookupRegistryFailure(t *testing.T) { - tests := []struct { - name string - mockLookup func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) - wantErr error - }{ - { - name: "lookup error", - mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return nil, fmt.Errorf("registry failure") - }, - wantErr: fmt.Errorf("registry failure"), - }, - { - name: "no subscriber found", - mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { - return []model.Subscription{}, nil - }, - wantErr: ErrSubscriberNotFound, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - km := &KeyMgr{ - Registry: &mockRegistry{ - LookupFunc: tt.mockLookup, - }, - } - got, err := km.lookupRegistry(context.Background(), "some-id", "key-id") - if err == nil { - t.Fatalf("expected error, got none") - } - if got != nil { - t.Errorf("expected nil keyset, got %v", got) - } - }) - } -} - func TestValidateParamsSuccess(t *testing.T) { err := validateParams("someSubscriberID", "someUniqueKeyID") if err != nil { @@ -1161,3 +974,126 @@ func TestValidateParamsFailure(t *testing.T) { }) } } + +func TestLookupNPKeysSuccess(t *testing.T) { + tests := []struct { + name string + cacheGetFunc func(ctx context.Context, key string) (string, error) + registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) + expectedSigningPub string + expectedEncrPub string + }{ + { + name: "Cache hit with valid keys", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil + }, + registryLookupFunc: nil, + expectedSigningPub: "mock-signing-public-key", + expectedEncrPub: "mock-encryption-public-key", + }, + { + name: "Cache miss and registry success", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + }, + }, nil + }, + expectedSigningPub: "mock-signing-public-key", + expectedEncrPub: "mock-encryption-public-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the KeyMgr with mocks + km := &KeyMgr{ + Cache: &mockCache{ + GetFunc: tt.cacheGetFunc, + }, + Registry: &mockRegistry{ + LookupFunc: tt.registryLookupFunc, + }, + } + + // Call the method + signingPublic, encrPublic, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id") + + // Validate no errors in success cases + if err != nil { + t.Fatalf("LookupNPKeys() unexpected error: %v", err) + } + + // Validate returned public keys + if signingPublic != tt.expectedSigningPub { + t.Errorf("SigningPublic = %v, want %v", signingPublic, tt.expectedSigningPub) + } + if encrPublic != tt.expectedEncrPub { + t.Errorf("EncrPublic = %v, want %v", encrPublic, tt.expectedEncrPub) + } + }) + } +} + +func TestLookupNPKeysFailure(t *testing.T) { + tests := []struct { + name string + cacheGetFunc func(ctx context.Context, key string) (string, error) + registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) + expectedError string + }{ + { + name: "Cache miss and registry failure", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, fmt.Errorf("registry down") + }, + expectedError: "registry down", + }, + { + name: "Cache miss and registry returns no subscriber", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, nil + }, + expectedError: "no subscriber found with given credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the KeyMgr with mocks + km := &KeyMgr{ + Cache: &mockCache{ + GetFunc: tt.cacheGetFunc, + }, + Registry: &mockRegistry{ + LookupFunc: tt.registryLookupFunc, + }, + } + _, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id") + if err == nil { + t.Fatalf("expected an error but got none") + } + + if !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error to contain %v, got %v", tt.expectedError, err.Error()) + } + }) + } +}