Merge branch 'beckn-onix-v1.0-develop' into feature/redis-plugin
This commit is contained in:
34
pkg/plugin/implementation/keymanager/cmd/plugin.go
Normal file
34
pkg/plugin/implementation/keymanager/cmd/plugin.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/log"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/definition"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager"
|
||||
)
|
||||
|
||||
// keyManagerProvider implements the plugin provider for the KeyManager plugin.
|
||||
type keyManagerProvider struct{}
|
||||
|
||||
// newKeyManagerFunc is a function type that creates a new KeyManager instance.
|
||||
var newKeyManagerFunc = keymanager.New
|
||||
|
||||
// New creates and initializes a new KeyManager instance using the provided cache, registry lookup, and configuration.
|
||||
func (k *keyManagerProvider) New(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg map[string]string) (definition.KeyManager, func() error, error) {
|
||||
config := &keymanager.Config{
|
||||
VaultAddr: cfg["vaultAddr"],
|
||||
KVVersion: cfg["kvVersion"],
|
||||
}
|
||||
log.Debugf(ctx, "Keymanager config mapped: %+v", cfg)
|
||||
km, cleanup, err := newKeyManagerFunc(ctx, cache, registry, config)
|
||||
if err != nil {
|
||||
log.Error(ctx, err, "Failed to initialize KeyManager")
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Debugf(ctx, "KeyManager instance created successfully")
|
||||
return km, cleanup, nil
|
||||
}
|
||||
|
||||
// Provider is the exported instance of keyManagerProvider used for plugin registration.
|
||||
var Provider = keyManagerProvider{}
|
||||
127
pkg/plugin/implementation/keymanager/cmd/plugin_test.go
Normal file
127
pkg/plugin/implementation/keymanager/cmd/plugin_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/model"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/definition"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager"
|
||||
)
|
||||
|
||||
type mockRegistry struct {
|
||||
LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
|
||||
if m.LookupFunc != nil {
|
||||
return m.LookupFunc(ctx, sub)
|
||||
}
|
||||
return []model.Subscription{
|
||||
{
|
||||
Subscriber: model.Subscriber{
|
||||
SubscriberID: sub.SubscriberID,
|
||||
URL: "https://mock.registry/subscriber",
|
||||
Type: "BPP",
|
||||
Domain: "retail",
|
||||
},
|
||||
KeyID: sub.KeyID,
|
||||
SigningPublicKey: "mock-signing-public-key",
|
||||
EncrPublicKey: "mock-encryption-public-key",
|
||||
ValidFrom: time.Now().Add(-time.Hour),
|
||||
ValidUntil: time.Now().Add(time.Hour),
|
||||
Status: "SUBSCRIBED",
|
||||
Created: time.Now().Add(-2 * time.Hour),
|
||||
Updated: time.Now(),
|
||||
Nonce: "mock-nonce",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockCache struct{}
|
||||
|
||||
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockCache) Clear(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCache) Delete(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewSuccess(t *testing.T) {
|
||||
// Setup dummy implementations and variables
|
||||
ctx := context.Background()
|
||||
cache := &mockCache{}
|
||||
registry := &mockRegistry{}
|
||||
cfg := map[string]string{
|
||||
"vaultAddr": "http://dummy-vault",
|
||||
"kvVersion": "2",
|
||||
}
|
||||
|
||||
cleanupCalled := false
|
||||
fakeCleanup := func() error {
|
||||
cleanupCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) {
|
||||
// return a mock struct pointer of *keymanager.KeyMgr or a stub instance
|
||||
return &keymanager.KeyMgr{}, fakeCleanup, nil
|
||||
}
|
||||
|
||||
// Create provider and call New
|
||||
provider := &keyManagerProvider{}
|
||||
km, cleanup, err := provider.New(ctx, cache, registry, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
if km == nil {
|
||||
t.Fatal("Expected non-nil KeyManager instance")
|
||||
}
|
||||
if cleanup == nil {
|
||||
t.Fatal("Expected non-nil cleanup function")
|
||||
}
|
||||
|
||||
// Call cleanup and check if it behaves correctly
|
||||
if err := cleanup(); err != nil {
|
||||
t.Fatalf("Expected no error from cleanup, got %v", err)
|
||||
}
|
||||
if !cleanupCalled {
|
||||
t.Error("Expected cleanup function to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFailure(t *testing.T) {
|
||||
// Setup dummy variables
|
||||
ctx := context.Background()
|
||||
cache := &mockCache{}
|
||||
registry := &mockRegistry{}
|
||||
cfg := map[string]string{
|
||||
"vaultAddr": "http://dummy-vault",
|
||||
"kvVersion": "2",
|
||||
}
|
||||
|
||||
newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) {
|
||||
return nil, nil, fmt.Errorf("some error")
|
||||
}
|
||||
|
||||
provider := &keyManagerProvider{}
|
||||
km, cleanup, err := provider.New(ctx, cache, registry, cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
if km != nil {
|
||||
t.Error("Expected nil KeyManager on error")
|
||||
}
|
||||
if cleanup != nil {
|
||||
t.Error("Expected nil cleanup function on error")
|
||||
}
|
||||
}
|
||||
328
pkg/plugin/implementation/keymanager/keymanager.go
Normal file
328
pkg/plugin/implementation/keymanager/keymanager.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package keymanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdh"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/log"
|
||||
"github.com/beckn/beckn-onix/pkg/model"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/definition"
|
||||
"github.com/google/uuid"
|
||||
vault "github.com/hashicorp/vault/api"
|
||||
)
|
||||
|
||||
// Config holds configuration parameters for connecting to Vault.
|
||||
type Config struct {
|
||||
VaultAddr string
|
||||
KVVersion string
|
||||
}
|
||||
|
||||
// KeyMgr provides methods for managing cryptographic keys using Vault.
|
||||
type KeyMgr struct {
|
||||
VaultClient *vault.Client
|
||||
Registry definition.RegistryLookup
|
||||
Cache definition.Cache
|
||||
KvVersion string
|
||||
SecretPath string
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrEmptyKeyID indicates that the provided key ID is empty.
|
||||
ErrEmptyKeyID = errors.New("invalid request: keyID cannot be empty")
|
||||
|
||||
// ErrNilKeySet indicates that the provided keyset is nil.
|
||||
ErrNilKeySet = errors.New("keyset cannot be nil")
|
||||
|
||||
// ErrEmptySubscriberID indicates that the provided subscriber ID is empty.
|
||||
ErrEmptySubscriberID = errors.New("invalid request: subscriberID cannot be empty")
|
||||
|
||||
// ErrEmptyUniqueKeyID indicates that the provided unique key ID is empty.
|
||||
ErrEmptyUniqueKeyID = errors.New("invalid request: uniqueKeyID cannot be empty")
|
||||
|
||||
// ErrSubscriberNotFound indicates that no subscriber was found with the provided credentials.
|
||||
ErrSubscriberNotFound = errors.New("no subscriber found with given credentials")
|
||||
|
||||
// ErrNilCache indicates that the cache implementation is nil.
|
||||
ErrNilCache = errors.New("cache implementation cannot be nil")
|
||||
|
||||
// ErrNilRegistryLookup indicates that the registry lookup implementation is nil.
|
||||
ErrNilRegistryLookup = errors.New("registry lookup implementation cannot be nil")
|
||||
)
|
||||
|
||||
// ValidateCfg validates the Vault configuration and sets default KV version if missing.
|
||||
func ValidateCfg(cfg *Config) error {
|
||||
if cfg.VaultAddr == "" {
|
||||
return errors.New("invalid config: VaultAddr cannot be empty")
|
||||
}
|
||||
kvVersion := strings.ToLower(cfg.KVVersion)
|
||||
if kvVersion == "" {
|
||||
kvVersion = "v1"
|
||||
} else if kvVersion != "v1" && kvVersion != "v2" {
|
||||
return fmt.Errorf("invalid KVVersion: must be 'v1' or 'v2'")
|
||||
}
|
||||
cfg.KVVersion = kvVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// getVaultClient is a function that creates a new Vault client.
|
||||
// This is exported for testing purposes.
|
||||
var getVaultClient = GetVaultClient
|
||||
|
||||
// New creates a new KeyMgr instance with the provided configuration, cache, and registry lookup.
|
||||
func New(ctx context.Context, cache definition.Cache, registryLookup definition.RegistryLookup, cfg *Config) (*KeyMgr, func() error, error) {
|
||||
log.Info(ctx, "Initializing KeyManager plugin")
|
||||
// Validate configuration.
|
||||
if err := ValidateCfg(cfg); err != nil {
|
||||
log.Error(ctx, err, "Invalid configuration for KeyManager")
|
||||
return nil, nil, err
|
||||
}
|
||||
// Check if cache implementation is provided.
|
||||
if cache == nil {
|
||||
log.Error(ctx, ErrNilCache, "Cache is nil in KeyManager initialization")
|
||||
return nil, nil, ErrNilCache
|
||||
}
|
||||
|
||||
// Check if registry lookup implementation is provided.
|
||||
if registryLookup == nil {
|
||||
log.Error(ctx, ErrNilRegistryLookup, "RegistryLookup is nil in KeyManager initialization")
|
||||
return nil, nil, ErrNilRegistryLookup
|
||||
}
|
||||
|
||||
// Initialize Vault client.
|
||||
log.Debugf(ctx, "Creating Vault client with address: %s", cfg.VaultAddr)
|
||||
vaultClient, err := getVaultClient(ctx, cfg.VaultAddr)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, err, "Failed to create Vault client at address: %s", cfg.VaultAddr)
|
||||
return nil, nil, fmt.Errorf("failed to create vault client: %w", err)
|
||||
}
|
||||
|
||||
log.Info(ctx, "Successfully created Vault client")
|
||||
|
||||
// Create KeyManager instance.
|
||||
km := &KeyMgr{
|
||||
VaultClient: vaultClient,
|
||||
Registry: registryLookup,
|
||||
Cache: cache,
|
||||
KvVersion: cfg.KVVersion,
|
||||
}
|
||||
|
||||
// Cleanup function to release KeyManager resources.
|
||||
cleanup := func() error {
|
||||
log.Info(ctx, "Cleaning up KeyManager resources")
|
||||
km.VaultClient = nil
|
||||
km.Cache = nil
|
||||
km.Registry = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info(ctx, "KeyManager plugin initialized successfully")
|
||||
return km, cleanup, nil
|
||||
}
|
||||
|
||||
// NewVaultClient creates a new Vault client instance.
|
||||
// This function is exported for testing purposes.
|
||||
var NewVaultClient = vault.NewClient
|
||||
|
||||
// GetVaultClient creates and authenticates a Vault client using AppRole.
|
||||
func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error) {
|
||||
roleID := os.Getenv("VAULT_ROLE_ID")
|
||||
secretID := os.Getenv("VAULT_SECRET_ID")
|
||||
|
||||
if roleID == "" || secretID == "" {
|
||||
log.Error(ctx, fmt.Errorf("missing credentials"), "VAULT_ROLE_ID or VAULT_SECRET_ID is not set")
|
||||
return nil, fmt.Errorf("VAULT_ROLE_ID or VAULT_SECRET_ID is not set")
|
||||
}
|
||||
|
||||
config := vault.DefaultConfig()
|
||||
config.Address = vaultAddr
|
||||
|
||||
client, err := NewVaultClient(config)
|
||||
if err != nil {
|
||||
log.Error(ctx, err, "failed to create Vault client")
|
||||
return nil, fmt.Errorf("failed to create Vault client: %w", err)
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"role_id": roleID,
|
||||
"secret_id": secretID,
|
||||
}
|
||||
|
||||
log.Info(ctx, "Logging into Vault with AppRole")
|
||||
resp, err := client.Logical().Write("auth/approle/login", data)
|
||||
if err != nil {
|
||||
log.Error(ctx, err, "failed to login with AppRole")
|
||||
return nil, fmt.Errorf("failed to login with AppRole: %w", err)
|
||||
}
|
||||
if resp == nil || resp.Auth == nil {
|
||||
log.Error(ctx, nil, "AppRole login failed: no auth info returned")
|
||||
return nil, errors.New("AppRole login failed: no auth info returned")
|
||||
}
|
||||
|
||||
log.Info(ctx, "Vault login successful")
|
||||
client.SetToken(resp.Auth.ClientToken)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ed25519KeyGenFunc = ed25519.GenerateKey
|
||||
x25519KeyGenFunc = ecdh.X25519().GenerateKey
|
||||
uuidGenFunc = uuid.NewRandom
|
||||
)
|
||||
|
||||
// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair.
|
||||
func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) {
|
||||
signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate signing key pair: %w", err)
|
||||
}
|
||||
|
||||
encrPrivateKey, err := x25519KeyGenFunc(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate encryption key pair: %w", err)
|
||||
}
|
||||
encrPublicKey := encrPrivateKey.PublicKey().Bytes()
|
||||
uuid, err := uuidGenFunc()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate unique key id uuid: %w", err)
|
||||
}
|
||||
return &model.Keyset{
|
||||
UniqueKeyID: uuid.String(),
|
||||
SigningPrivate: encodeBase64(signingPrivate.Seed()),
|
||||
SigningPublic: encodeBase64(signingPublic),
|
||||
EncrPrivate: encodeBase64(encrPrivateKey.Bytes()),
|
||||
EncrPublic: encodeBase64(encrPublicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getSecretPath constructs the Vault secret path for storing keys based on the KV version.
|
||||
func (km *KeyMgr) getSecretPath(keyID string) string {
|
||||
if km.KvVersion == "v2" {
|
||||
return fmt.Sprintf("secret/data/keys/%s", keyID)
|
||||
}
|
||||
return fmt.Sprintf("secret/keys/%s", keyID)
|
||||
}
|
||||
|
||||
// InsertKeyset stores the given keyset in Vault under the specified key ID.
|
||||
func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error {
|
||||
if keyID == "" {
|
||||
return ErrEmptyKeyID
|
||||
}
|
||||
if keys == nil {
|
||||
return ErrNilKeySet
|
||||
}
|
||||
|
||||
keyData := map[string]interface{}{
|
||||
"uniqueKeyID": keys.UniqueKeyID,
|
||||
"signingPublicKey": keys.SigningPublic,
|
||||
"signingPrivateKey": keys.SigningPrivate,
|
||||
"encrPublicKey": keys.EncrPublic,
|
||||
"encrPrivateKey": keys.EncrPrivate,
|
||||
}
|
||||
path := km.getSecretPath(keyID)
|
||||
var payload map[string]interface{}
|
||||
if km.KvVersion == "v2" {
|
||||
payload = map[string]interface{}{"data": keyData}
|
||||
} else {
|
||||
payload = keyData
|
||||
}
|
||||
|
||||
_, err := km.VaultClient.Logical().Write(path, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store secret in Vault at path %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteKeyset deletes the private keys for the given key ID from Vault.
|
||||
func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error {
|
||||
if keyID == "" {
|
||||
return ErrEmptyKeyID
|
||||
}
|
||||
path := km.getSecretPath(keyID)
|
||||
return km.VaultClient.KVv2(path).Delete(ctx, keyID)
|
||||
}
|
||||
|
||||
// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry.
|
||||
func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) {
|
||||
if keyID == "" {
|
||||
return nil, ErrEmptyKeyID
|
||||
}
|
||||
|
||||
path := km.getSecretPath(keyID)
|
||||
|
||||
secret, err := km.VaultClient.Logical().Read(path)
|
||||
if err != nil || secret == nil {
|
||||
return nil, fmt.Errorf("failed to read secret from Vault: %w", err)
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if km.KvVersion == "v2" {
|
||||
dataRaw, ok := secret.Data["data"]
|
||||
if !ok {
|
||||
return nil, errors.New("missing 'data' in secret response")
|
||||
}
|
||||
data, ok = dataRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.New("invalid 'data' format in Vault response")
|
||||
}
|
||||
} else {
|
||||
data = secret.Data
|
||||
}
|
||||
|
||||
return &model.Keyset{
|
||||
UniqueKeyID: data["uniqueKeyID"].(string),
|
||||
SigningPublic: data["signingPublicKey"].(string),
|
||||
SigningPrivate: data["signingPrivateKey"].(string),
|
||||
EncrPublic: data["encrPublicKey"].(string),
|
||||
EncrPrivate: data["encrPrivateKey"].(string),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID.
|
||||
func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) {
|
||||
cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID)
|
||||
cachedData, err := km.Cache.Get(ctx, cacheKey)
|
||||
if err == nil {
|
||||
var keys model.Keyset
|
||||
if err := json.Unmarshal([]byte(cachedData), &keys); err == nil {
|
||||
return keys.SigningPublic, keys.EncrPublic, nil
|
||||
}
|
||||
}
|
||||
subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{
|
||||
Subscriber: model.Subscriber{
|
||||
SubscriberID: subscriberID,
|
||||
},
|
||||
KeyID: uniqueKeyID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to lookup registry: %w", err)
|
||||
}
|
||||
if len(subscribers) == 0 {
|
||||
return "", "", ErrSubscriberNotFound
|
||||
}
|
||||
return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil
|
||||
}
|
||||
|
||||
// validateParams checks that subscriberID and uniqueKeyID are not empty.
|
||||
func validateParams(subscriberID, uniqueKeyID string) error {
|
||||
if subscriberID == "" {
|
||||
return ErrEmptySubscriberID
|
||||
}
|
||||
if uniqueKeyID == "" {
|
||||
return ErrEmptyUniqueKeyID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeBase64 returns the base64-encoded string of the given data.
|
||||
func encodeBase64(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
1114
pkg/plugin/implementation/keymanager/keymanager_test.go
Normal file
1114
pkg/plugin/implementation/keymanager/keymanager_test.go
Normal file
File diff suppressed because it is too large
Load Diff
37
pkg/plugin/implementation/publisher/cmd/plugin.go
Normal file
37
pkg/plugin/implementation/publisher/cmd/plugin.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/log"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/definition"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher"
|
||||
)
|
||||
|
||||
// publisherProvider implements the PublisherProvider interface.
|
||||
// It is responsible for creating a new Publisher instance.
|
||||
type publisherProvider struct{}
|
||||
|
||||
// New creates a new Publisher instance based on the provided configuration.
|
||||
func (p *publisherProvider) New(ctx context.Context, config map[string]string) (definition.Publisher, func() error, error) {
|
||||
cfg := &publisher.Config{
|
||||
Addr: config["addr"],
|
||||
Exchange: config["exchange"],
|
||||
RoutingKey: config["routing_key"],
|
||||
Durable: config["durable"] == "true",
|
||||
UseTLS: config["use_tls"] == "true",
|
||||
}
|
||||
log.Debugf(ctx, "Publisher config mapped: %+v", cfg)
|
||||
|
||||
pub, cleanup, err := publisher.New(cfg)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, err, "Failed to create publisher instance")
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
log.Infof(ctx, "Publisher instance created successfully")
|
||||
return pub, cleanup, nil
|
||||
}
|
||||
|
||||
// Provider is the instance of publisherProvider that implements the PublisherProvider interface.
|
||||
var Provider = publisherProvider{}
|
||||
106
pkg/plugin/implementation/publisher/cmd/plugin_test.go
Normal file
106
pkg/plugin/implementation/publisher/cmd/plugin_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher"
|
||||
"github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
type mockChannel struct{}
|
||||
|
||||
func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockChannel) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPublisherProvider_New_Success(t *testing.T) {
|
||||
// Save original dialFunc and channelFunc
|
||||
originalDialFunc := publisher.DialFunc
|
||||
originalChannelFunc := publisher.ChannelFunc
|
||||
defer func() {
|
||||
publisher.DialFunc = originalDialFunc
|
||||
publisher.ChannelFunc = originalChannelFunc
|
||||
}()
|
||||
|
||||
// Override mocks
|
||||
publisher.DialFunc = func(url string) (*amqp091.Connection, error) {
|
||||
return nil, nil
|
||||
}
|
||||
publisher.ChannelFunc = func(conn *amqp091.Connection) (publisher.Channel, error) {
|
||||
return &mockChannel{}, nil
|
||||
}
|
||||
|
||||
t.Setenv("RABBITMQ_USERNAME", "guest")
|
||||
t.Setenv("RABBITMQ_PASSWORD", "guest")
|
||||
|
||||
config := map[string]string{
|
||||
"addr": "localhost",
|
||||
"exchange": "test-exchange",
|
||||
"routing_key": "test.key",
|
||||
"durable": "true",
|
||||
"use_tls": "false",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pub, cleanup, err := Provider.New(ctx, config)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Provider.New returned error: %v", err)
|
||||
}
|
||||
if pub == nil {
|
||||
t.Fatal("Expected non-nil publisher")
|
||||
}
|
||||
if cleanup == nil {
|
||||
t.Fatal("Expected non-nil cleanup function")
|
||||
}
|
||||
|
||||
if err := cleanup(); err != nil {
|
||||
t.Errorf("Cleanup returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublisherProvider_New_Failure(t *testing.T) {
|
||||
// Save and restore dialFunc
|
||||
originalDialFunc := publisher.DialFunc
|
||||
defer func() { publisher.DialFunc = originalDialFunc }()
|
||||
|
||||
// Simulate dial failure
|
||||
publisher.DialFunc = func(url string) (*amqp091.Connection, error) {
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
|
||||
t.Setenv("RABBITMQ_USERNAME", "guest")
|
||||
t.Setenv("RABBITMQ_PASSWORD", "guest")
|
||||
|
||||
config := map[string]string{
|
||||
"addr": "localhost",
|
||||
"exchange": "test-exchange",
|
||||
"routing_key": "test.key",
|
||||
"durable": "true",
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
pub, cleanup, err := Provider.New(ctx, config)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected error from Provider.New but got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "dial failed") {
|
||||
t.Errorf("Expected 'dial failed' error, got: %v", err)
|
||||
}
|
||||
if pub != nil {
|
||||
t.Errorf("Expected nil publisher, got: %v", pub)
|
||||
}
|
||||
if cleanup != nil {
|
||||
t.Error("Expected nil cleanup, got non-nil")
|
||||
}
|
||||
}
|
||||
196
pkg/plugin/implementation/publisher/publisher.go
Normal file
196
pkg/plugin/implementation/publisher/publisher.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package publisher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/log"
|
||||
"github.com/beckn/beckn-onix/pkg/model"
|
||||
"github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
// Config holds the configuration required to establish a connection with RabbitMQ.
|
||||
type Config struct {
|
||||
Addr string
|
||||
Exchange string
|
||||
RoutingKey string
|
||||
Durable bool
|
||||
UseTLS bool
|
||||
}
|
||||
|
||||
// Channel defines the interface for publishing messages to RabbitMQ.
|
||||
type Channel interface {
|
||||
PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error
|
||||
ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Publisher manages the RabbitMQ connection and channel to publish messages.
|
||||
type Publisher struct {
|
||||
Conn *amqp091.Connection
|
||||
Channel Channel
|
||||
Config *Config
|
||||
}
|
||||
|
||||
// Error variables representing different failure scenarios.
|
||||
var (
|
||||
ErrEmptyConfig = errors.New("empty config")
|
||||
ErrAddrMissing = errors.New("missing required field 'Addr'")
|
||||
ErrExchangeMissing = errors.New("missing required field 'Exchange'")
|
||||
ErrCredentialMissing = errors.New("missing RabbitMQ credentials in environment")
|
||||
ErrConnectionFailed = errors.New("failed to connect to RabbitMQ")
|
||||
ErrChannelFailed = errors.New("failed to open channel")
|
||||
ErrExchangeDeclare = errors.New("failed to declare exchange")
|
||||
)
|
||||
|
||||
// Validate checks whether the provided Config is valid for connecting to RabbitMQ.
|
||||
func Validate(cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return model.NewBadReqErr(fmt.Errorf("config is nil"))
|
||||
}
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return model.NewBadReqErr(fmt.Errorf("missing config.Addr"))
|
||||
}
|
||||
if strings.TrimSpace(cfg.Exchange) == "" {
|
||||
return model.NewBadReqErr(fmt.Errorf("missing config.Exchange"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnURL constructs the RabbitMQ connection URL using the config and environment credentials.
|
||||
func GetConnURL(cfg *Config) (string, error) {
|
||||
user := os.Getenv("RABBITMQ_USERNAME")
|
||||
pass := os.Getenv("RABBITMQ_PASSWORD")
|
||||
if user == "" || pass == "" {
|
||||
return "", model.NewBadReqErr(fmt.Errorf("missing RabbitMQ credentials in environment"))
|
||||
}
|
||||
parts := strings.SplitN(strings.TrimSpace(cfg.Addr), "/", 2)
|
||||
hostPort := parts[0]
|
||||
vhost := "/"
|
||||
if len(parts) > 1 {
|
||||
vhost = parts[1]
|
||||
}
|
||||
|
||||
if !strings.Contains(hostPort, ":") {
|
||||
if cfg.UseTLS {
|
||||
hostPort += ":5671"
|
||||
} else {
|
||||
hostPort += ":5672"
|
||||
}
|
||||
}
|
||||
|
||||
encodedUser := url.QueryEscape(user)
|
||||
encodedPass := url.QueryEscape(pass)
|
||||
encodedVHost := url.QueryEscape(vhost)
|
||||
protocol := "amqp"
|
||||
if cfg.UseTLS {
|
||||
protocol = "amqps"
|
||||
}
|
||||
|
||||
connURL := fmt.Sprintf("%s://%s:%s@%s/%s", protocol, encodedUser, encodedPass, hostPort, encodedVHost)
|
||||
log.Debugf(context.Background(), "Generated RabbitMQ connection details: protocol=%s, hostPort=%s, vhost=%s", protocol, hostPort, vhost)
|
||||
|
||||
return connURL, nil
|
||||
}
|
||||
|
||||
// Publish sends a message to the configured RabbitMQ exchange with the specified routing key.
|
||||
// If routingKey is empty, the default routing key from Config is used.
|
||||
func (p *Publisher) Publish(ctx context.Context, routingKey string, msg []byte) error {
|
||||
if routingKey == "" {
|
||||
routingKey = p.Config.RoutingKey
|
||||
}
|
||||
log.Debugf(ctx, "Attempting to publish message. Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
|
||||
err := p.Channel.PublishWithContext(
|
||||
ctx,
|
||||
p.Config.Exchange,
|
||||
routingKey,
|
||||
false,
|
||||
false,
|
||||
amqp091.Publishing{
|
||||
ContentType: "application/json",
|
||||
Body: msg,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(ctx, err, "Publish failed for Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
|
||||
return model.NewBadReqErr(fmt.Errorf("publish message failed: %w", err))
|
||||
}
|
||||
|
||||
log.Infof(ctx, "Message published successfully to Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DialFunc is a function variable used to establish a connection to RabbitMQ.
|
||||
var DialFunc = amqp091.Dial
|
||||
|
||||
// ChannelFunc is a function variable used to open a channel on the given RabbitMQ connection.
|
||||
var ChannelFunc = func(conn *amqp091.Connection) (Channel, error) {
|
||||
return conn.Channel()
|
||||
}
|
||||
|
||||
// New initializes a new Publisher with the given config, opens a connection,
|
||||
// channel, and declares the exchange. Returns the publisher and a cleanup function.
|
||||
func New(cfg *Config) (*Publisher, func() error, error) {
|
||||
// Step 1: Validate config
|
||||
if err := Validate(cfg); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Step 2: Build connection URL
|
||||
connURL, err := GetConnURL(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err)
|
||||
}
|
||||
|
||||
// Step 3: Dial connection
|
||||
conn, err := DialFunc(connURL)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err)
|
||||
}
|
||||
|
||||
// Step 4: Open channel
|
||||
ch, err := ChannelFunc(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrChannelFailed, err)
|
||||
}
|
||||
|
||||
// Step 5: Declare exchange
|
||||
if err := ch.ExchangeDeclare(
|
||||
cfg.Exchange,
|
||||
"topic",
|
||||
cfg.Durable,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
); err != nil {
|
||||
ch.Close()
|
||||
conn.Close()
|
||||
return nil, nil, fmt.Errorf("%w: %v", ErrExchangeDeclare, err)
|
||||
}
|
||||
|
||||
// Step 6: Construct publisher
|
||||
pub := &Publisher{
|
||||
Conn: conn,
|
||||
Channel: ch,
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
cleanup := func() error {
|
||||
if ch != nil {
|
||||
_ = ch.Close()
|
||||
}
|
||||
if conn != nil {
|
||||
return conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return pub, cleanup, nil
|
||||
}
|
||||
362
pkg/plugin/implementation/publisher/publisher_test.go
Normal file
362
pkg/plugin/implementation/publisher/publisher_test.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package publisher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
func TestGetConnURLSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
}{
|
||||
{
|
||||
name: "Valid config with connection address",
|
||||
config: &Config{
|
||||
Addr: "localhost:5672",
|
||||
UseTLS: false,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "Valid config with vhost",
|
||||
config: &Config{
|
||||
Addr: "localhost:5672/myvhost",
|
||||
UseTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Addr with leading and trailing spaces",
|
||||
config: &Config{
|
||||
Addr: " localhost:5672/myvhost ",
|
||||
UseTLS: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Set valid credentials
|
||||
t.Setenv("RABBITMQ_USERNAME", "guest")
|
||||
t.Setenv("RABBITMQ_PASSWORD", "guest")
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url, err := GetConnURL(tt.config)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if url == "" {
|
||||
t.Error("expected non-empty URL, got empty string")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConnURLFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
config *Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Missing credentials",
|
||||
username: "",
|
||||
password: "",
|
||||
config: &Config{Addr: "localhost:5672"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Missing config address",
|
||||
username: "guest",
|
||||
password: "guest",
|
||||
config: &Config{}, // this won't error unless Validate() is called separately
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.username != "" {
|
||||
t.Setenv("RABBITMQ_USERNAME", tt.username)
|
||||
}
|
||||
|
||||
if tt.password != "" {
|
||||
t.Setenv("RABBITMQ_PASSWORD", tt.password)
|
||||
}
|
||||
|
||||
url, err := GetConnURL(tt.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("unexpected error. gotErr = %v, wantErr = %v", err != nil, tt.wantErr)
|
||||
}
|
||||
|
||||
if err == nil && url == "" {
|
||||
t.Errorf("expected non-empty URL, got empty string")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
}{
|
||||
{
|
||||
name: "Valid config with Addr and Exchange",
|
||||
config: &Config{
|
||||
Addr: "localhost:5672",
|
||||
Exchange: "ex",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := Validate(tt.config); err != nil {
|
||||
t.Errorf("expected no error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *Config
|
||||
expectedErrr string
|
||||
}{
|
||||
{
|
||||
name: "Nil config",
|
||||
config: nil,
|
||||
expectedErrr: "config is nil",
|
||||
},
|
||||
{
|
||||
name: "Missing Addr",
|
||||
config: &Config{Exchange: "ex"},
|
||||
expectedErrr: "missing config.Addr",
|
||||
},
|
||||
{
|
||||
name: "Missing Exchange",
|
||||
config: &Config{Addr: "localhost:5672"},
|
||||
expectedErrr: "missing config.Exchange",
|
||||
},
|
||||
{
|
||||
name: "Empty Addr and Exchange",
|
||||
config: &Config{Addr: " ", Exchange: " "},
|
||||
expectedErrr: "missing config.Addr",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := Validate(tt.config)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid config, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.expectedErrr) {
|
||||
t.Errorf("expected error to contain %q, got: %v", tt.expectedErrr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockChannelForPublish struct {
|
||||
published bool
|
||||
exchange string
|
||||
key string
|
||||
body []byte
|
||||
fail bool
|
||||
}
|
||||
|
||||
func (m *mockChannelForPublish) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
|
||||
if m.fail {
|
||||
return fmt.Errorf("simulated publish failure")
|
||||
}
|
||||
m.published = true
|
||||
m.exchange = exchange
|
||||
m.key = key
|
||||
m.body = msg.Body
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannelForPublish) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChannelForPublish) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestPublishSuccess(t *testing.T) {
|
||||
mockCh := &mockChannelForPublish{}
|
||||
|
||||
p := &Publisher{
|
||||
Channel: mockCh,
|
||||
Config: &Config{
|
||||
Exchange: "mock.exchange",
|
||||
RoutingKey: "mock.key",
|
||||
},
|
||||
}
|
||||
|
||||
err := p.Publish(context.Background(), "", []byte(`{"test": true}`))
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if !mockCh.published {
|
||||
t.Error("expected message to be published, but it wasn't")
|
||||
}
|
||||
|
||||
if mockCh.exchange != "mock.exchange" || mockCh.key != "mock.key" {
|
||||
t.Errorf("unexpected exchange or key. got (%s, %s)", mockCh.exchange, mockCh.key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishFailure(t *testing.T) {
|
||||
mockCh := &mockChannelForPublish{fail: true}
|
||||
|
||||
p := &Publisher{
|
||||
Channel: mockCh,
|
||||
Config: &Config{
|
||||
Exchange: "mock.exchange",
|
||||
RoutingKey: "mock.key",
|
||||
},
|
||||
}
|
||||
|
||||
err := p.Publish(context.Background(), "", []byte(`{"test": true}`))
|
||||
if err == nil {
|
||||
t.Error("expected error from failed publish, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
type mockChannel struct{}
|
||||
|
||||
func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockChannel) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewPublisherSucess(t *testing.T) {
|
||||
originalDialFunc := DialFunc
|
||||
originalChannelFunc := ChannelFunc
|
||||
defer func() {
|
||||
DialFunc = originalDialFunc
|
||||
ChannelFunc = originalChannelFunc
|
||||
}()
|
||||
|
||||
// mockedConn := &mockConnection{}
|
||||
|
||||
DialFunc = func(url string) (*amqp091.Connection, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ChannelFunc = func(conn *amqp091.Connection) (Channel, error) {
|
||||
return &mockChannel{}, nil
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Addr: "localhost",
|
||||
Exchange: "test-ex",
|
||||
Durable: true,
|
||||
RoutingKey: "test.key",
|
||||
}
|
||||
|
||||
t.Setenv("RABBITMQ_USERNAME", "user")
|
||||
t.Setenv("RABBITMQ_PASSWORD", "pass")
|
||||
|
||||
pub, cleanup, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
if pub == nil {
|
||||
t.Fatal("Publisher should not be nil")
|
||||
}
|
||||
if cleanup == nil {
|
||||
t.Fatal("Cleanup should not be nil")
|
||||
}
|
||||
if err := cleanup(); err != nil {
|
||||
t.Errorf("Cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPublisherFailures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *Config
|
||||
dialFunc func(url string) (*amqp091.Connection, error) // Mocked dial function
|
||||
envVars map[string]string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ValidateFailure",
|
||||
cfg: &Config{}, // invalid config
|
||||
expectedError: "missing config.Addr",
|
||||
},
|
||||
{
|
||||
name: "GetConnURLFailure",
|
||||
cfg: &Config{
|
||||
Addr: "localhost",
|
||||
Exchange: "test-ex",
|
||||
Durable: true,
|
||||
RoutingKey: "test.key",
|
||||
},
|
||||
envVars: map[string]string{
|
||||
"RABBITMQ_USERNAME": "",
|
||||
"RABBITMQ_PASSWORD": "",
|
||||
},
|
||||
expectedError: "missing RabbitMQ credentials in environment",
|
||||
},
|
||||
{
|
||||
name: "ConnectionFailure",
|
||||
cfg: &Config{
|
||||
Addr: "localhost",
|
||||
Exchange: "test-ex",
|
||||
Durable: true,
|
||||
RoutingKey: "test.key",
|
||||
},
|
||||
dialFunc: func(url string) (*amqp091.Connection, error) {
|
||||
return nil, fmt.Errorf("simulated connection failure")
|
||||
},
|
||||
envVars: map[string]string{
|
||||
"RABBITMQ_USERNAME": "user",
|
||||
"RABBITMQ_PASSWORD": "pass",
|
||||
},
|
||||
expectedError: "failed to connect to RabbitMQ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set environment variables
|
||||
for key, value := range tt.envVars {
|
||||
t.Setenv(key, value)
|
||||
}
|
||||
|
||||
// Mock dialFunc if needed
|
||||
originalDialFunc := DialFunc
|
||||
if tt.dialFunc != nil {
|
||||
DialFunc = tt.dialFunc
|
||||
defer func() {
|
||||
DialFunc = originalDialFunc
|
||||
}()
|
||||
}
|
||||
|
||||
_, _, err := New(tt.cfg)
|
||||
|
||||
if err == nil || (tt.expectedError != "" && !strings.Contains(err.Error(), tt.expectedError)) {
|
||||
t.Errorf("Test %s failed: expected error containing %v, got: %v", tt.name, tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user