change in the keymanager interface
This commit is contained in:
@@ -8,13 +8,11 @@ import (
|
|||||||
|
|
||||||
// KeyManager defines the interface for key management operations/methods.
|
// KeyManager defines the interface for key management operations/methods.
|
||||||
type KeyManager interface {
|
type KeyManager interface {
|
||||||
GenerateKeyPairs() (*model.Keyset, error)
|
GenerateKeyset() (*model.Keyset, error)
|
||||||
StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error
|
InsertKeyset(ctx context.Context, keyID string, keyset *model.Keyset) error
|
||||||
SigningPrivateKey(ctx context.Context, keyID string) (string, string, error)
|
Keyset(ctx context.Context, keyID string) (*model.Keyset, error)
|
||||||
EncrPrivateKey(ctx context.Context, keyID string) (string, string, error)
|
LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error)
|
||||||
SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
|
DeleteKeyset(ctx context.Context, keyID string) error
|
||||||
EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
|
|
||||||
DeletePrivateKeys(ctx context.Context, keyID string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyManagerProvider initializes a new signer instance.
|
// KeyManagerProvider initializes a new signer instance.
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/beckn/beckn-onix/pkg/log"
|
"github.com/beckn/beckn-onix/pkg/log"
|
||||||
"github.com/beckn/beckn-onix/pkg/model"
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
@@ -175,8 +174,8 @@ var (
|
|||||||
uuidGenFunc = uuid.NewRandom
|
uuidGenFunc = uuid.NewRandom
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateKeyPairs generates a new signing (Ed25519) and encryption (X25519) key pair.
|
// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair.
|
||||||
func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) {
|
func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) {
|
||||||
signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader)
|
signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to generate signing key pair: %w", err)
|
return nil, fmt.Errorf("failed to generate signing key pair: %w", err)
|
||||||
@@ -200,8 +199,8 @@ func (km *KeyMgr) GenerateKeyPairs() (*model.Keyset, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StorePrivateKeys stores the given keyset in Vault under the specified key ID.
|
// InsertKeyset stores the given keyset in Vault under the specified key ID.
|
||||||
func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error {
|
func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error {
|
||||||
if keyID == "" {
|
if keyID == "" {
|
||||||
return ErrEmptyKeyID
|
return ErrEmptyKeyID
|
||||||
}
|
}
|
||||||
@@ -233,44 +232,8 @@ func (km *KeyMgr) StorePrivateKeys(ctx context.Context, keyID string, keys *mode
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SigningPrivateKey retrieves the unique key ID and signing private key for the given key ID.
|
// DeleteKeyset deletes the private keys for the given key ID from Vault.
|
||||||
func (km *KeyMgr) SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) {
|
func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error {
|
||||||
keys, err := km.getKeys(ctx, keyID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return keys.UniqueKeyID, keys.SigningPrivate, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncrPrivateKey retrieves the unique key ID and encryption private key for the given key ID.
|
|
||||||
func (km *KeyMgr) EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) {
|
|
||||||
keys, err := km.getKeys(ctx, keyID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
return keys.UniqueKeyID, keys.EncrPrivate, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SigningPublicKey returns the signing public key for the given subscriber ID and key ID.
|
|
||||||
func (km *KeyMgr) SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) {
|
|
||||||
keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return keys.SigningPublic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// EncrPublicKey returns the encryption public key for the given subscriber ID and key ID.
|
|
||||||
func (km *KeyMgr) EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) {
|
|
||||||
keys, err := km.getPublicKeys(ctx, subscriberID, uniqueKeyID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return keys.EncrPublic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePrivateKeys deletes the private keys for the given key ID from Vault.
|
|
||||||
func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error {
|
|
||||||
if keyID == "" {
|
if keyID == "" {
|
||||||
return ErrEmptyKeyID
|
return ErrEmptyKeyID
|
||||||
}
|
}
|
||||||
@@ -283,8 +246,8 @@ func (km *KeyMgr) DeletePrivateKeys(ctx context.Context, keyID string) error {
|
|||||||
return km.VaultClient.KVv2(path).Delete(ctx, keyID)
|
return km.VaultClient.KVv2(path).Delete(ctx, keyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getKeys retrieves the full keyset from Vault for the given key ID.
|
// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry.
|
||||||
func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, error) {
|
func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) {
|
||||||
if keyID == "" {
|
if keyID == "" {
|
||||||
return nil, ErrEmptyKeyID
|
return nil, ErrEmptyKeyID
|
||||||
}
|
}
|
||||||
@@ -324,32 +287,16 @@ func (km *KeyMgr) getKeys(ctx context.Context, keyID string) (*model.Keyset, err
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPublicKeys fetches the public keys from cache or registry for the given subscriber and key ID.
|
// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID.
|
||||||
func (km *KeyMgr) getPublicKeys(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) {
|
func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) {
|
||||||
if err := validateParams(subscriberID, uniqueKeyID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID)
|
cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID)
|
||||||
cachedData, err := km.Cache.Get(ctx, cacheKey)
|
cachedData, err := km.Cache.Get(ctx, cacheKey)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
var keys model.Keyset
|
var keys model.Keyset
|
||||||
if err := json.Unmarshal([]byte(cachedData), &keys); err == nil {
|
if err := json.Unmarshal([]byte(cachedData), &keys); err == nil {
|
||||||
return &keys, nil
|
return keys.SigningPublic, keys.EncrPublic, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
publicKeys, err := km.lookupRegistry(ctx, subscriberID, uniqueKeyID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cacheValue, err := json.Marshal(publicKeys)
|
|
||||||
if err == nil {
|
|
||||||
_ = km.Cache.Set(ctx, cacheKey, string(cacheValue), time.Hour)
|
|
||||||
}
|
|
||||||
return publicKeys, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRegistry queries the registry for public keys based on subscriber ID and key ID.
|
|
||||||
func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID string) (*model.Keyset, error) {
|
|
||||||
subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{
|
subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{
|
||||||
Subscriber: model.Subscriber{
|
Subscriber: model.Subscriber{
|
||||||
SubscriberID: subscriberID,
|
SubscriberID: subscriberID,
|
||||||
@@ -357,15 +304,12 @@ func (km *KeyMgr) lookupRegistry(ctx context.Context, subscriberID, uniqueKeyID
|
|||||||
KeyID: uniqueKeyID,
|
KeyID: uniqueKeyID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to lookup registry: %w", err)
|
return "", "", fmt.Errorf("failed to lookup registry: %w", err)
|
||||||
}
|
}
|
||||||
if len(subscribers) == 0 {
|
if len(subscribers) == 0 {
|
||||||
return nil, ErrSubscriberNotFound
|
return "", "", ErrSubscriberNotFound
|
||||||
}
|
}
|
||||||
return &model.Keyset{
|
return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil
|
||||||
SigningPublic: subscribers[0].SigningPublicKey,
|
|
||||||
EncrPublic: subscribers[0].EncrPublicKey,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateParams checks that subscriberID and uniqueKeyID are not empty.
|
// validateParams checks that subscriberID and uniqueKeyID are not empty.
|
||||||
|
|||||||
@@ -50,7 +50,9 @@ func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]m
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockCache struct{}
|
type mockCache struct {
|
||||||
|
GetFunc func(ctx context.Context, key string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
|
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
@@ -193,7 +195,7 @@ func TestGenerateKeyPairs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
km := &KeyMgr{}
|
km := &KeyMgr{}
|
||||||
keyset, err := km.GenerateKeyPairs()
|
keyset, err := km.GenerateKeyset()
|
||||||
|
|
||||||
if tt.expectErr {
|
if tt.expectErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -589,7 +591,7 @@ func TestStorePrivateKeysSuccess(t *testing.T) {
|
|||||||
KvVersion: tt.kvVersion,
|
KvVersion: tt.kvVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys)
|
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expected no error, got %v", err)
|
t.Errorf("expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -667,7 +669,7 @@ func TestStorePrivateKeysFailure(t *testing.T) {
|
|||||||
KvVersion: tt.kvVersion,
|
KvVersion: tt.kvVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = km.StorePrivateKeys(ctx, tt.keyID, tt.keys)
|
err = km.InsertKeyset(ctx, tt.keyID, tt.keys)
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error %q but got nil", tt.expectedErr)
|
t.Fatalf("expected error %q but got nil", tt.expectedErr)
|
||||||
@@ -717,7 +719,7 @@ func TestDeletePrivateKeys(t *testing.T) {
|
|||||||
KvVersion: tt.kvVersion,
|
KvVersion: tt.kvVersion,
|
||||||
VaultClient: nil,
|
VaultClient: nil,
|
||||||
}
|
}
|
||||||
err := km.DeletePrivateKeys(context.Background(), tt.keyID)
|
err := km.DeleteKeyset(context.Background(), tt.keyID)
|
||||||
if err != tt.wantErr {
|
if err != tt.wantErr {
|
||||||
t.Errorf("expected error %v, got %v", tt.wantErr, err)
|
t.Errorf("expected error %v, got %v", tt.wantErr, err)
|
||||||
}
|
}
|
||||||
@@ -745,7 +747,7 @@ func TestDeletePrivateKeys(t *testing.T) {
|
|||||||
VaultClient: vaultClient,
|
VaultClient: vaultClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = km.DeletePrivateKeys(context.Background(), tt.keyID)
|
err = km.DeleteKeyset(context.Background(), tt.keyID)
|
||||||
if err != tt.wantErr {
|
if err != tt.wantErr {
|
||||||
t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr)
|
t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
@@ -809,7 +811,7 @@ func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *
|
|||||||
return httptest.NewServer(handler)
|
return httptest.NewServer(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetKeysSuccess(t *testing.T) {
|
func TestKeysetSuccess(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
kvVersion string
|
kvVersion string
|
||||||
@@ -845,7 +847,7 @@ func TestGetKeysSuccess(t *testing.T) {
|
|||||||
KvVersion: tt.kvVersion,
|
KvVersion: tt.kvVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := km.getKeys(context.Background(), tt.keyID)
|
keys, err := km.Keyset(context.Background(), tt.keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -862,7 +864,7 @@ func TestGetKeysSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetKeysFailure(t *testing.T) {
|
func TestKeysetFailure(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
kvVersion string
|
kvVersion string
|
||||||
@@ -915,7 +917,7 @@ func TestGetKeysFailure(t *testing.T) {
|
|||||||
KvVersion: tt.kvVersion,
|
KvVersion: tt.kvVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := km.getKeys(context.Background(), tt.keyID)
|
keys, err := km.Keyset(context.Background(), tt.keyID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected error but got nil")
|
t.Fatalf("expected error but got nil")
|
||||||
}
|
}
|
||||||
@@ -926,195 +928,6 @@ func TestGetKeysFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetPublicKeysSuccess(t *testing.T) {
|
|
||||||
km := &KeyMgr{
|
|
||||||
Cache: &mockCache{},
|
|
||||||
Registry: &mockRegistry{
|
|
||||||
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return []model.Subscription{
|
|
||||||
{
|
|
||||||
Subscriber: model.Subscriber{
|
|
||||||
SubscriberID: sub.SubscriberID,
|
|
||||||
},
|
|
||||||
KeyID: sub.KeyID,
|
|
||||||
SigningPublicKey: "mock-signing-public-key",
|
|
||||||
EncrPublicKey: "mock-encryption-public-key",
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := km.getPublicKeys(context.Background(), "sub-id", "key-id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getPublicKeys() unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
if got == nil {
|
|
||||||
t.Fatal("getPublicKeys() returned nil Keyset")
|
|
||||||
}
|
|
||||||
if got.SigningPublic != "mock-signing-public-key" {
|
|
||||||
t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, "mock-signing-public-key")
|
|
||||||
}
|
|
||||||
if got.EncrPublic != "mock-encryption-public-key" {
|
|
||||||
t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, "mock-encryption-public-key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPublicKeysFailure(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
cache definition.Cache
|
|
||||||
registry definition.RegistryLookup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
subscriberID string
|
|
||||||
uniqueKeyID string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
errMessage string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "failure - invalid parameters",
|
|
||||||
fields: fields{
|
|
||||||
cache: &mockCache{},
|
|
||||||
registry: &mockRegistry{},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
subscriberID: "",
|
|
||||||
uniqueKeyID: "",
|
|
||||||
},
|
|
||||||
errMessage: "invalid",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure - registry lookup fails",
|
|
||||||
fields: fields{
|
|
||||||
cache: &mockCache{},
|
|
||||||
registry: &mockRegistry{
|
|
||||||
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return nil, fmt.Errorf("registry down")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
subscriberID: "sub-id",
|
|
||||||
uniqueKeyID: "key-id",
|
|
||||||
},
|
|
||||||
errMessage: "registry down",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "failure - registry returns empty",
|
|
||||||
fields: fields{
|
|
||||||
cache: &mockCache{},
|
|
||||||
registry: &mockRegistry{
|
|
||||||
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return []model.Subscription{}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
args: args{
|
|
||||||
subscriberID: "sub-id",
|
|
||||||
uniqueKeyID: "key-id",
|
|
||||||
},
|
|
||||||
errMessage: "no subscriber found",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
km := &KeyMgr{
|
|
||||||
Cache: tt.fields.cache,
|
|
||||||
Registry: tt.fields.registry,
|
|
||||||
}
|
|
||||||
got, err := km.getPublicKeys(context.Background(), tt.args.subscriberID, tt.args.uniqueKeyID)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getPublicKeys() expected error but got none, result: %v", got)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tt.errMessage) {
|
|
||||||
t.Errorf("getPublicKeys() error = %v, want error message to contain %v", err.Error(), tt.errMessage)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLookupRegistrySuccess(t *testing.T) {
|
|
||||||
km := &KeyMgr{
|
|
||||||
Registry: &mockRegistry{
|
|
||||||
LookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return []model.Subscription{
|
|
||||||
{
|
|
||||||
Subscriber: model.Subscriber{
|
|
||||||
SubscriberID: sub.SubscriberID,
|
|
||||||
},
|
|
||||||
KeyID: sub.KeyID,
|
|
||||||
SigningPublicKey: "signing-key",
|
|
||||||
EncrPublicKey: "encryption-key",
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
got, err := km.lookupRegistry(context.Background(), "test-subscriber", "key123")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
want := &model.Keyset{
|
|
||||||
SigningPublic: "signing-key",
|
|
||||||
EncrPublic: "encryption-key",
|
|
||||||
}
|
|
||||||
|
|
||||||
if got.SigningPublic != want.SigningPublic {
|
|
||||||
t.Errorf("SigningPublic = %v, want %v", got.SigningPublic, want.SigningPublic)
|
|
||||||
}
|
|
||||||
if got.EncrPublic != want.EncrPublic {
|
|
||||||
t.Errorf("EncrPublic = %v, want %v", got.EncrPublic, want.EncrPublic)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLookupRegistryFailure(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mockLookup func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "lookup error",
|
|
||||||
mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return nil, fmt.Errorf("registry failure")
|
|
||||||
},
|
|
||||||
wantErr: fmt.Errorf("registry failure"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no subscriber found",
|
|
||||||
mockLookup: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
|
||||||
return []model.Subscription{}, nil
|
|
||||||
},
|
|
||||||
wantErr: ErrSubscriberNotFound,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
km := &KeyMgr{
|
|
||||||
Registry: &mockRegistry{
|
|
||||||
LookupFunc: tt.mockLookup,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
got, err := km.lookupRegistry(context.Background(), "some-id", "key-id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error, got none")
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
t.Errorf("expected nil keyset, got %v", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateParamsSuccess(t *testing.T) {
|
func TestValidateParamsSuccess(t *testing.T) {
|
||||||
err := validateParams("someSubscriberID", "someUniqueKeyID")
|
err := validateParams("someSubscriberID", "someUniqueKeyID")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1161,3 +974,126 @@ func TestValidateParamsFailure(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLookupNPKeysSuccess(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cacheGetFunc func(ctx context.Context, key string) (string, error)
|
||||||
|
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
||||||
|
expectedSigningPub string
|
||||||
|
expectedEncrPub string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Cache hit with valid keys",
|
||||||
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
||||||
|
return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil
|
||||||
|
},
|
||||||
|
registryLookupFunc: nil,
|
||||||
|
expectedSigningPub: "mock-signing-public-key",
|
||||||
|
expectedEncrPub: "mock-encryption-public-key",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Cache miss and registry success",
|
||||||
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
||||||
|
return []model.Subscription{
|
||||||
|
{
|
||||||
|
Subscriber: model.Subscriber{
|
||||||
|
SubscriberID: sub.SubscriberID,
|
||||||
|
},
|
||||||
|
KeyID: sub.KeyID,
|
||||||
|
SigningPublicKey: "mock-signing-public-key",
|
||||||
|
EncrPublicKey: "mock-encryption-public-key",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
expectedSigningPub: "mock-signing-public-key",
|
||||||
|
expectedEncrPub: "mock-encryption-public-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Set up the KeyMgr with mocks
|
||||||
|
km := &KeyMgr{
|
||||||
|
Cache: &mockCache{
|
||||||
|
GetFunc: tt.cacheGetFunc,
|
||||||
|
},
|
||||||
|
Registry: &mockRegistry{
|
||||||
|
LookupFunc: tt.registryLookupFunc,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the method
|
||||||
|
signingPublic, encrPublic, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
|
||||||
|
|
||||||
|
// Validate no errors in success cases
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LookupNPKeys() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate returned public keys
|
||||||
|
if signingPublic != tt.expectedSigningPub {
|
||||||
|
t.Errorf("SigningPublic = %v, want %v", signingPublic, tt.expectedSigningPub)
|
||||||
|
}
|
||||||
|
if encrPublic != tt.expectedEncrPub {
|
||||||
|
t.Errorf("EncrPublic = %v, want %v", encrPublic, tt.expectedEncrPub)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupNPKeysFailure(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cacheGetFunc func(ctx context.Context, key string) (string, error)
|
||||||
|
registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
||||||
|
expectedError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Cache miss and registry failure",
|
||||||
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
||||||
|
return nil, fmt.Errorf("registry down")
|
||||||
|
},
|
||||||
|
expectedError: "registry down",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Cache miss and registry returns no subscriber",
|
||||||
|
cacheGetFunc: func(ctx context.Context, key string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
},
|
||||||
|
registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
expectedError: "no subscriber found with given credentials",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Set up the KeyMgr with mocks
|
||||||
|
km := &KeyMgr{
|
||||||
|
Cache: &mockCache{
|
||||||
|
GetFunc: tt.cacheGetFunc,
|
||||||
|
},
|
||||||
|
Registry: &mockRegistry{
|
||||||
|
LookupFunc: tt.registryLookupFunc,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected an error but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), tt.expectedError) {
|
||||||
|
t.Errorf("expected error to contain %v, got %v", tt.expectedError, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user