Merge branch 'beckn-onix-v1.0-develop' into feature/redis-plugin

This commit is contained in:
Tanya Madaan
2025-06-16 12:38:45 +05:30
committed by GitHub
20 changed files with 2747 additions and 31 deletions

View File

@@ -8,13 +8,11 @@ import (
// KeyManager defines the interface for key management operations/methods.
type KeyManager interface {
GenerateKeyPairs() (*model.Keyset, error)
StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error
SigningPrivateKey(ctx context.Context, keyID string) (string, string, error)
EncrPrivateKey(ctx context.Context, keyID string) (string, string, error)
SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
DeletePrivateKeys(ctx context.Context, keyID string) error
GenerateKeyset() (*model.Keyset, error)
InsertKeyset(ctx context.Context, keyID string, keyset *model.Keyset) error
Keyset(ctx context.Context, keyID string) (*model.Keyset, error)
LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (signingPublicKey string, encrPublicKey string, err error)
DeleteKeyset(ctx context.Context, keyID string) error
}
// KeyManagerProvider initializes a new signer instance.

View File

@@ -8,6 +8,7 @@ type Publisher interface {
Publish(context.Context, string, []byte) error
}
// PublisherProvider is the interface for creating new Publisher instances.
type PublisherProvider interface {
// New initializes a new publisher instance with the given configuration.
New(ctx context.Context, config map[string]string) (Publisher, func() error, error)

View File

@@ -0,0 +1,34 @@
package main
import (
"context"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager"
)
// keyManagerProvider implements the plugin provider for the KeyManager plugin.
type keyManagerProvider struct{}
// newKeyManagerFunc is a function type that creates a new KeyManager instance.
var newKeyManagerFunc = keymanager.New
// New creates and initializes a new KeyManager instance using the provided cache, registry lookup, and configuration.
func (k *keyManagerProvider) New(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg map[string]string) (definition.KeyManager, func() error, error) {
config := &keymanager.Config{
VaultAddr: cfg["vaultAddr"],
KVVersion: cfg["kvVersion"],
}
log.Debugf(ctx, "Keymanager config mapped: %+v", cfg)
km, cleanup, err := newKeyManagerFunc(ctx, cache, registry, config)
if err != nil {
log.Error(ctx, err, "Failed to initialize KeyManager")
return nil, nil, err
}
log.Debugf(ctx, "KeyManager instance created successfully")
return km, cleanup, nil
}
// Provider is the exported instance of keyManagerProvider used for plugin registration.
var Provider = keyManagerProvider{}

View File

@@ -0,0 +1,127 @@
package main
import (
"context"
"fmt"
"testing"
"time"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager"
)
type mockRegistry struct {
LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error)
}
func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) {
if m.LookupFunc != nil {
return m.LookupFunc(ctx, sub)
}
return []model.Subscription{
{
Subscriber: model.Subscriber{
SubscriberID: sub.SubscriberID,
URL: "https://mock.registry/subscriber",
Type: "BPP",
Domain: "retail",
},
KeyID: sub.KeyID,
SigningPublicKey: "mock-signing-public-key",
EncrPublicKey: "mock-encryption-public-key",
ValidFrom: time.Now().Add(-time.Hour),
ValidUntil: time.Now().Add(time.Hour),
Status: "SUBSCRIBED",
Created: time.Now().Add(-2 * time.Hour),
Updated: time.Now(),
Nonce: "mock-nonce",
},
}, nil
}
type mockCache struct{}
func (m *mockCache) Get(ctx context.Context, key string) (string, error) {
return "", nil
}
func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error {
return nil
}
func (m *mockCache) Clear(ctx context.Context) error {
return nil
}
func (m *mockCache) Delete(ctx context.Context, key string) error {
return nil
}
func TestNewSuccess(t *testing.T) {
// Setup dummy implementations and variables
ctx := context.Background()
cache := &mockCache{}
registry := &mockRegistry{}
cfg := map[string]string{
"vaultAddr": "http://dummy-vault",
"kvVersion": "2",
}
cleanupCalled := false
fakeCleanup := func() error {
cleanupCalled = true
return nil
}
newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) {
// return a mock struct pointer of *keymanager.KeyMgr or a stub instance
return &keymanager.KeyMgr{}, fakeCleanup, nil
}
// Create provider and call New
provider := &keyManagerProvider{}
km, cleanup, err := provider.New(ctx, cache, registry, cfg)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if km == nil {
t.Fatal("Expected non-nil KeyManager instance")
}
if cleanup == nil {
t.Fatal("Expected non-nil cleanup function")
}
// Call cleanup and check if it behaves correctly
if err := cleanup(); err != nil {
t.Fatalf("Expected no error from cleanup, got %v", err)
}
if !cleanupCalled {
t.Error("Expected cleanup function to be called")
}
}
func TestNewFailure(t *testing.T) {
// Setup dummy variables
ctx := context.Background()
cache := &mockCache{}
registry := &mockRegistry{}
cfg := map[string]string{
"vaultAddr": "http://dummy-vault",
"kvVersion": "2",
}
newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) {
return nil, nil, fmt.Errorf("some error")
}
provider := &keyManagerProvider{}
km, cleanup, err := provider.New(ctx, cache, registry, cfg)
if err == nil {
t.Fatal("Expected error, got nil")
}
if km != nil {
t.Error("Expected nil KeyManager on error")
}
if cleanup != nil {
t.Error("Expected nil cleanup function on error")
}
}

View File

@@ -0,0 +1,328 @@
package keymanager
import (
"context"
"crypto/ecdh"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/google/uuid"
vault "github.com/hashicorp/vault/api"
)
// Config holds configuration parameters for connecting to Vault.
type Config struct {
VaultAddr string
KVVersion string
}
// KeyMgr provides methods for managing cryptographic keys using Vault.
type KeyMgr struct {
VaultClient *vault.Client
Registry definition.RegistryLookup
Cache definition.Cache
KvVersion string
SecretPath string
}
var (
// ErrEmptyKeyID indicates that the provided key ID is empty.
ErrEmptyKeyID = errors.New("invalid request: keyID cannot be empty")
// ErrNilKeySet indicates that the provided keyset is nil.
ErrNilKeySet = errors.New("keyset cannot be nil")
// ErrEmptySubscriberID indicates that the provided subscriber ID is empty.
ErrEmptySubscriberID = errors.New("invalid request: subscriberID cannot be empty")
// ErrEmptyUniqueKeyID indicates that the provided unique key ID is empty.
ErrEmptyUniqueKeyID = errors.New("invalid request: uniqueKeyID cannot be empty")
// ErrSubscriberNotFound indicates that no subscriber was found with the provided credentials.
ErrSubscriberNotFound = errors.New("no subscriber found with given credentials")
// ErrNilCache indicates that the cache implementation is nil.
ErrNilCache = errors.New("cache implementation cannot be nil")
// ErrNilRegistryLookup indicates that the registry lookup implementation is nil.
ErrNilRegistryLookup = errors.New("registry lookup implementation cannot be nil")
)
// ValidateCfg validates the Vault configuration and sets default KV version if missing.
func ValidateCfg(cfg *Config) error {
if cfg.VaultAddr == "" {
return errors.New("invalid config: VaultAddr cannot be empty")
}
kvVersion := strings.ToLower(cfg.KVVersion)
if kvVersion == "" {
kvVersion = "v1"
} else if kvVersion != "v1" && kvVersion != "v2" {
return fmt.Errorf("invalid KVVersion: must be 'v1' or 'v2'")
}
cfg.KVVersion = kvVersion
return nil
}
// getVaultClient is a function that creates a new Vault client.
// This is exported for testing purposes.
var getVaultClient = GetVaultClient
// New creates a new KeyMgr instance with the provided configuration, cache, and registry lookup.
func New(ctx context.Context, cache definition.Cache, registryLookup definition.RegistryLookup, cfg *Config) (*KeyMgr, func() error, error) {
log.Info(ctx, "Initializing KeyManager plugin")
// Validate configuration.
if err := ValidateCfg(cfg); err != nil {
log.Error(ctx, err, "Invalid configuration for KeyManager")
return nil, nil, err
}
// Check if cache implementation is provided.
if cache == nil {
log.Error(ctx, ErrNilCache, "Cache is nil in KeyManager initialization")
return nil, nil, ErrNilCache
}
// Check if registry lookup implementation is provided.
if registryLookup == nil {
log.Error(ctx, ErrNilRegistryLookup, "RegistryLookup is nil in KeyManager initialization")
return nil, nil, ErrNilRegistryLookup
}
// Initialize Vault client.
log.Debugf(ctx, "Creating Vault client with address: %s", cfg.VaultAddr)
vaultClient, err := getVaultClient(ctx, cfg.VaultAddr)
if err != nil {
log.Errorf(ctx, err, "Failed to create Vault client at address: %s", cfg.VaultAddr)
return nil, nil, fmt.Errorf("failed to create vault client: %w", err)
}
log.Info(ctx, "Successfully created Vault client")
// Create KeyManager instance.
km := &KeyMgr{
VaultClient: vaultClient,
Registry: registryLookup,
Cache: cache,
KvVersion: cfg.KVVersion,
}
// Cleanup function to release KeyManager resources.
cleanup := func() error {
log.Info(ctx, "Cleaning up KeyManager resources")
km.VaultClient = nil
km.Cache = nil
km.Registry = nil
return nil
}
log.Info(ctx, "KeyManager plugin initialized successfully")
return km, cleanup, nil
}
// NewVaultClient creates a new Vault client instance.
// This function is exported for testing purposes.
var NewVaultClient = vault.NewClient
// GetVaultClient creates and authenticates a Vault client using AppRole.
func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error) {
roleID := os.Getenv("VAULT_ROLE_ID")
secretID := os.Getenv("VAULT_SECRET_ID")
if roleID == "" || secretID == "" {
log.Error(ctx, fmt.Errorf("missing credentials"), "VAULT_ROLE_ID or VAULT_SECRET_ID is not set")
return nil, fmt.Errorf("VAULT_ROLE_ID or VAULT_SECRET_ID is not set")
}
config := vault.DefaultConfig()
config.Address = vaultAddr
client, err := NewVaultClient(config)
if err != nil {
log.Error(ctx, err, "failed to create Vault client")
return nil, fmt.Errorf("failed to create Vault client: %w", err)
}
data := map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
}
log.Info(ctx, "Logging into Vault with AppRole")
resp, err := client.Logical().Write("auth/approle/login", data)
if err != nil {
log.Error(ctx, err, "failed to login with AppRole")
return nil, fmt.Errorf("failed to login with AppRole: %w", err)
}
if resp == nil || resp.Auth == nil {
log.Error(ctx, nil, "AppRole login failed: no auth info returned")
return nil, errors.New("AppRole login failed: no auth info returned")
}
log.Info(ctx, "Vault login successful")
client.SetToken(resp.Auth.ClientToken)
return client, nil
}
var (
ed25519KeyGenFunc = ed25519.GenerateKey
x25519KeyGenFunc = ecdh.X25519().GenerateKey
uuidGenFunc = uuid.NewRandom
)
// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair.
func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) {
signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate signing key pair: %w", err)
}
encrPrivateKey, err := x25519KeyGenFunc(rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate encryption key pair: %w", err)
}
encrPublicKey := encrPrivateKey.PublicKey().Bytes()
uuid, err := uuidGenFunc()
if err != nil {
return nil, fmt.Errorf("failed to generate unique key id uuid: %w", err)
}
return &model.Keyset{
UniqueKeyID: uuid.String(),
SigningPrivate: encodeBase64(signingPrivate.Seed()),
SigningPublic: encodeBase64(signingPublic),
EncrPrivate: encodeBase64(encrPrivateKey.Bytes()),
EncrPublic: encodeBase64(encrPublicKey),
}, nil
}
// getSecretPath constructs the Vault secret path for storing keys based on the KV version.
func (km *KeyMgr) getSecretPath(keyID string) string {
if km.KvVersion == "v2" {
return fmt.Sprintf("secret/data/keys/%s", keyID)
}
return fmt.Sprintf("secret/keys/%s", keyID)
}
// InsertKeyset stores the given keyset in Vault under the specified key ID.
func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error {
if keyID == "" {
return ErrEmptyKeyID
}
if keys == nil {
return ErrNilKeySet
}
keyData := map[string]interface{}{
"uniqueKeyID": keys.UniqueKeyID,
"signingPublicKey": keys.SigningPublic,
"signingPrivateKey": keys.SigningPrivate,
"encrPublicKey": keys.EncrPublic,
"encrPrivateKey": keys.EncrPrivate,
}
path := km.getSecretPath(keyID)
var payload map[string]interface{}
if km.KvVersion == "v2" {
payload = map[string]interface{}{"data": keyData}
} else {
payload = keyData
}
_, err := km.VaultClient.Logical().Write(path, payload)
if err != nil {
return fmt.Errorf("failed to store secret in Vault at path %s: %w", path, err)
}
return nil
}
// DeleteKeyset deletes the private keys for the given key ID from Vault.
func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error {
if keyID == "" {
return ErrEmptyKeyID
}
path := km.getSecretPath(keyID)
return km.VaultClient.KVv2(path).Delete(ctx, keyID)
}
// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry.
func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) {
if keyID == "" {
return nil, ErrEmptyKeyID
}
path := km.getSecretPath(keyID)
secret, err := km.VaultClient.Logical().Read(path)
if err != nil || secret == nil {
return nil, fmt.Errorf("failed to read secret from Vault: %w", err)
}
var data map[string]interface{}
if km.KvVersion == "v2" {
dataRaw, ok := secret.Data["data"]
if !ok {
return nil, errors.New("missing 'data' in secret response")
}
data, ok = dataRaw.(map[string]interface{})
if !ok {
return nil, errors.New("invalid 'data' format in Vault response")
}
} else {
data = secret.Data
}
return &model.Keyset{
UniqueKeyID: data["uniqueKeyID"].(string),
SigningPublic: data["signingPublicKey"].(string),
SigningPrivate: data["signingPrivateKey"].(string),
EncrPublic: data["encrPublicKey"].(string),
EncrPrivate: data["encrPrivateKey"].(string),
}, nil
}
// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID.
func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) {
cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID)
cachedData, err := km.Cache.Get(ctx, cacheKey)
if err == nil {
var keys model.Keyset
if err := json.Unmarshal([]byte(cachedData), &keys); err == nil {
return keys.SigningPublic, keys.EncrPublic, nil
}
}
subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{
Subscriber: model.Subscriber{
SubscriberID: subscriberID,
},
KeyID: uniqueKeyID,
})
if err != nil {
return "", "", fmt.Errorf("failed to lookup registry: %w", err)
}
if len(subscribers) == 0 {
return "", "", ErrSubscriberNotFound
}
return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil
}
// validateParams checks that subscriberID and uniqueKeyID are not empty.
func validateParams(subscriberID, uniqueKeyID string) error {
if subscriberID == "" {
return ErrEmptySubscriberID
}
if uniqueKeyID == "" {
return ErrEmptyUniqueKeyID
}
return nil
}
// encodeBase64 returns the base64-encoded string of the given data.
func encodeBase64(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,37 @@
package main
import (
"context"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher"
)
// publisherProvider implements the PublisherProvider interface.
// It is responsible for creating a new Publisher instance.
type publisherProvider struct{}
// New creates a new Publisher instance based on the provided configuration.
func (p *publisherProvider) New(ctx context.Context, config map[string]string) (definition.Publisher, func() error, error) {
cfg := &publisher.Config{
Addr: config["addr"],
Exchange: config["exchange"],
RoutingKey: config["routing_key"],
Durable: config["durable"] == "true",
UseTLS: config["use_tls"] == "true",
}
log.Debugf(ctx, "Publisher config mapped: %+v", cfg)
pub, cleanup, err := publisher.New(cfg)
if err != nil {
log.Errorf(ctx, err, "Failed to create publisher instance")
return nil, nil, err
}
log.Infof(ctx, "Publisher instance created successfully")
return pub, cleanup, nil
}
// Provider is the instance of publisherProvider that implements the PublisherProvider interface.
var Provider = publisherProvider{}

View File

@@ -0,0 +1,106 @@
package main
import (
"context"
"errors"
"strings"
"testing"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher"
"github.com/rabbitmq/amqp091-go"
)
type mockChannel struct{}
func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
return nil
}
func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
return nil
}
func (m *mockChannel) Close() error {
return nil
}
func TestPublisherProvider_New_Success(t *testing.T) {
// Save original dialFunc and channelFunc
originalDialFunc := publisher.DialFunc
originalChannelFunc := publisher.ChannelFunc
defer func() {
publisher.DialFunc = originalDialFunc
publisher.ChannelFunc = originalChannelFunc
}()
// Override mocks
publisher.DialFunc = func(url string) (*amqp091.Connection, error) {
return nil, nil
}
publisher.ChannelFunc = func(conn *amqp091.Connection) (publisher.Channel, error) {
return &mockChannel{}, nil
}
t.Setenv("RABBITMQ_USERNAME", "guest")
t.Setenv("RABBITMQ_PASSWORD", "guest")
config := map[string]string{
"addr": "localhost",
"exchange": "test-exchange",
"routing_key": "test.key",
"durable": "true",
"use_tls": "false",
}
ctx := context.Background()
pub, cleanup, err := Provider.New(ctx, config)
if err != nil {
t.Fatalf("Provider.New returned error: %v", err)
}
if pub == nil {
t.Fatal("Expected non-nil publisher")
}
if cleanup == nil {
t.Fatal("Expected non-nil cleanup function")
}
if err := cleanup(); err != nil {
t.Errorf("Cleanup returned error: %v", err)
}
}
func TestPublisherProvider_New_Failure(t *testing.T) {
// Save and restore dialFunc
originalDialFunc := publisher.DialFunc
defer func() { publisher.DialFunc = originalDialFunc }()
// Simulate dial failure
publisher.DialFunc = func(url string) (*amqp091.Connection, error) {
return nil, errors.New("dial failed")
}
t.Setenv("RABBITMQ_USERNAME", "guest")
t.Setenv("RABBITMQ_PASSWORD", "guest")
config := map[string]string{
"addr": "localhost",
"exchange": "test-exchange",
"routing_key": "test.key",
"durable": "true",
}
ctx := context.Background()
pub, cleanup, err := Provider.New(ctx, config)
if err == nil {
t.Fatal("Expected error from Provider.New but got nil")
}
if !strings.Contains(err.Error(), "dial failed") {
t.Errorf("Expected 'dial failed' error, got: %v", err)
}
if pub != nil {
t.Errorf("Expected nil publisher, got: %v", pub)
}
if cleanup != nil {
t.Error("Expected nil cleanup, got non-nil")
}
}

View File

@@ -0,0 +1,196 @@
package publisher
import (
"context"
"errors"
"fmt"
"net/url"
"os"
"strings"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/rabbitmq/amqp091-go"
)
// Config holds the configuration required to establish a connection with RabbitMQ.
type Config struct {
Addr string
Exchange string
RoutingKey string
Durable bool
UseTLS bool
}
// Channel defines the interface for publishing messages to RabbitMQ.
type Channel interface {
PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error
ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error
Close() error
}
// Publisher manages the RabbitMQ connection and channel to publish messages.
type Publisher struct {
Conn *amqp091.Connection
Channel Channel
Config *Config
}
// Error variables representing different failure scenarios.
var (
ErrEmptyConfig = errors.New("empty config")
ErrAddrMissing = errors.New("missing required field 'Addr'")
ErrExchangeMissing = errors.New("missing required field 'Exchange'")
ErrCredentialMissing = errors.New("missing RabbitMQ credentials in environment")
ErrConnectionFailed = errors.New("failed to connect to RabbitMQ")
ErrChannelFailed = errors.New("failed to open channel")
ErrExchangeDeclare = errors.New("failed to declare exchange")
)
// Validate checks whether the provided Config is valid for connecting to RabbitMQ.
func Validate(cfg *Config) error {
if cfg == nil {
return model.NewBadReqErr(fmt.Errorf("config is nil"))
}
if strings.TrimSpace(cfg.Addr) == "" {
return model.NewBadReqErr(fmt.Errorf("missing config.Addr"))
}
if strings.TrimSpace(cfg.Exchange) == "" {
return model.NewBadReqErr(fmt.Errorf("missing config.Exchange"))
}
return nil
}
// GetConnURL constructs the RabbitMQ connection URL using the config and environment credentials.
func GetConnURL(cfg *Config) (string, error) {
user := os.Getenv("RABBITMQ_USERNAME")
pass := os.Getenv("RABBITMQ_PASSWORD")
if user == "" || pass == "" {
return "", model.NewBadReqErr(fmt.Errorf("missing RabbitMQ credentials in environment"))
}
parts := strings.SplitN(strings.TrimSpace(cfg.Addr), "/", 2)
hostPort := parts[0]
vhost := "/"
if len(parts) > 1 {
vhost = parts[1]
}
if !strings.Contains(hostPort, ":") {
if cfg.UseTLS {
hostPort += ":5671"
} else {
hostPort += ":5672"
}
}
encodedUser := url.QueryEscape(user)
encodedPass := url.QueryEscape(pass)
encodedVHost := url.QueryEscape(vhost)
protocol := "amqp"
if cfg.UseTLS {
protocol = "amqps"
}
connURL := fmt.Sprintf("%s://%s:%s@%s/%s", protocol, encodedUser, encodedPass, hostPort, encodedVHost)
log.Debugf(context.Background(), "Generated RabbitMQ connection details: protocol=%s, hostPort=%s, vhost=%s", protocol, hostPort, vhost)
return connURL, nil
}
// Publish sends a message to the configured RabbitMQ exchange with the specified routing key.
// If routingKey is empty, the default routing key from Config is used.
func (p *Publisher) Publish(ctx context.Context, routingKey string, msg []byte) error {
if routingKey == "" {
routingKey = p.Config.RoutingKey
}
log.Debugf(ctx, "Attempting to publish message. Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
err := p.Channel.PublishWithContext(
ctx,
p.Config.Exchange,
routingKey,
false,
false,
amqp091.Publishing{
ContentType: "application/json",
Body: msg,
},
)
if err != nil {
log.Errorf(ctx, err, "Publish failed for Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
return model.NewBadReqErr(fmt.Errorf("publish message failed: %w", err))
}
log.Infof(ctx, "Message published successfully to Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey)
return nil
}
// DialFunc is a function variable used to establish a connection to RabbitMQ.
var DialFunc = amqp091.Dial
// ChannelFunc is a function variable used to open a channel on the given RabbitMQ connection.
var ChannelFunc = func(conn *amqp091.Connection) (Channel, error) {
return conn.Channel()
}
// New initializes a new Publisher with the given config, opens a connection,
// channel, and declares the exchange. Returns the publisher and a cleanup function.
func New(cfg *Config) (*Publisher, func() error, error) {
// Step 1: Validate config
if err := Validate(cfg); err != nil {
return nil, nil, err
}
// Step 2: Build connection URL
connURL, err := GetConnURL(cfg)
if err != nil {
return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err)
}
// Step 3: Dial connection
conn, err := DialFunc(connURL)
if err != nil {
return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err)
}
// Step 4: Open channel
ch, err := ChannelFunc(conn)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("%w: %v", ErrChannelFailed, err)
}
// Step 5: Declare exchange
if err := ch.ExchangeDeclare(
cfg.Exchange,
"topic",
cfg.Durable,
false,
false,
false,
nil,
); err != nil {
ch.Close()
conn.Close()
return nil, nil, fmt.Errorf("%w: %v", ErrExchangeDeclare, err)
}
// Step 6: Construct publisher
pub := &Publisher{
Conn: conn,
Channel: ch,
Config: cfg,
}
cleanup := func() error {
if ch != nil {
_ = ch.Close()
}
if conn != nil {
return conn.Close()
}
return nil
}
return pub, cleanup, nil
}

View File

@@ -0,0 +1,362 @@
package publisher
import (
"context"
"fmt"
"strings"
"testing"
"github.com/rabbitmq/amqp091-go"
)
func TestGetConnURLSuccess(t *testing.T) {
tests := []struct {
name string
config *Config
}{
{
name: "Valid config with connection address",
config: &Config{
Addr: "localhost:5672",
UseTLS: false,
},
},
{
name: "Valid config with vhost",
config: &Config{
Addr: "localhost:5672/myvhost",
UseTLS: false,
},
},
{
name: "Addr with leading and trailing spaces",
config: &Config{
Addr: " localhost:5672/myvhost ",
UseTLS: false,
},
},
}
// Set valid credentials
t.Setenv("RABBITMQ_USERNAME", "guest")
t.Setenv("RABBITMQ_PASSWORD", "guest")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url, err := GetConnURL(tt.config)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if url == "" {
t.Error("expected non-empty URL, got empty string")
}
})
}
}
func TestGetConnURLFailure(t *testing.T) {
tests := []struct {
name string
username string
password string
config *Config
wantErr bool
}{
{
name: "Missing credentials",
username: "",
password: "",
config: &Config{Addr: "localhost:5672"},
wantErr: true,
},
{
name: "Missing config address",
username: "guest",
password: "guest",
config: &Config{}, // this won't error unless Validate() is called separately
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.username != "" {
t.Setenv("RABBITMQ_USERNAME", tt.username)
}
if tt.password != "" {
t.Setenv("RABBITMQ_PASSWORD", tt.password)
}
url, err := GetConnURL(tt.config)
if (err != nil) != tt.wantErr {
t.Errorf("unexpected error. gotErr = %v, wantErr = %v", err != nil, tt.wantErr)
}
if err == nil && url == "" {
t.Errorf("expected non-empty URL, got empty string")
}
})
}
}
func TestValidateSuccess(t *testing.T) {
tests := []struct {
name string
config *Config
}{
{
name: "Valid config with Addr and Exchange",
config: &Config{
Addr: "localhost:5672",
Exchange: "ex",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := Validate(tt.config); err != nil {
t.Errorf("expected no error, got: %v", err)
}
})
}
}
func TestValidateFailure(t *testing.T) {
tests := []struct {
name string
config *Config
expectedErrr string
}{
{
name: "Nil config",
config: nil,
expectedErrr: "config is nil",
},
{
name: "Missing Addr",
config: &Config{Exchange: "ex"},
expectedErrr: "missing config.Addr",
},
{
name: "Missing Exchange",
config: &Config{Addr: "localhost:5672"},
expectedErrr: "missing config.Exchange",
},
{
name: "Empty Addr and Exchange",
config: &Config{Addr: " ", Exchange: " "},
expectedErrr: "missing config.Addr",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := Validate(tt.config)
if err == nil {
t.Errorf("expected error for invalid config, got nil")
return
}
if !strings.Contains(err.Error(), tt.expectedErrr) {
t.Errorf("expected error to contain %q, got: %v", tt.expectedErrr, err)
}
})
}
}
type mockChannelForPublish struct {
published bool
exchange string
key string
body []byte
fail bool
}
func (m *mockChannelForPublish) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
if m.fail {
return fmt.Errorf("simulated publish failure")
}
m.published = true
m.exchange = exchange
m.key = key
m.body = msg.Body
return nil
}
func (m *mockChannelForPublish) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
return nil
}
func (m *mockChannelForPublish) Close() error {
return nil
}
func TestPublishSuccess(t *testing.T) {
mockCh := &mockChannelForPublish{}
p := &Publisher{
Channel: mockCh,
Config: &Config{
Exchange: "mock.exchange",
RoutingKey: "mock.key",
},
}
err := p.Publish(context.Background(), "", []byte(`{"test": true}`))
if err != nil {
t.Errorf("expected no error, got: %v", err)
}
if !mockCh.published {
t.Error("expected message to be published, but it wasn't")
}
if mockCh.exchange != "mock.exchange" || mockCh.key != "mock.key" {
t.Errorf("unexpected exchange or key. got (%s, %s)", mockCh.exchange, mockCh.key)
}
}
func TestPublishFailure(t *testing.T) {
mockCh := &mockChannelForPublish{fail: true}
p := &Publisher{
Channel: mockCh,
Config: &Config{
Exchange: "mock.exchange",
RoutingKey: "mock.key",
},
}
err := p.Publish(context.Background(), "", []byte(`{"test": true}`))
if err == nil {
t.Error("expected error from failed publish, got nil")
}
}
type mockChannel struct{}
func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error {
return nil
}
func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error {
return nil
}
func (m *mockChannel) Close() error {
return nil
}
func TestNewPublisherSucess(t *testing.T) {
originalDialFunc := DialFunc
originalChannelFunc := ChannelFunc
defer func() {
DialFunc = originalDialFunc
ChannelFunc = originalChannelFunc
}()
// mockedConn := &mockConnection{}
DialFunc = func(url string) (*amqp091.Connection, error) {
return nil, nil
}
ChannelFunc = func(conn *amqp091.Connection) (Channel, error) {
return &mockChannel{}, nil
}
cfg := &Config{
Addr: "localhost",
Exchange: "test-ex",
Durable: true,
RoutingKey: "test.key",
}
t.Setenv("RABBITMQ_USERNAME", "user")
t.Setenv("RABBITMQ_PASSWORD", "pass")
pub, cleanup, err := New(cfg)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
if pub == nil {
t.Fatal("Publisher should not be nil")
}
if cleanup == nil {
t.Fatal("Cleanup should not be nil")
}
if err := cleanup(); err != nil {
t.Errorf("Cleanup failed: %v", err)
}
}
func TestNewPublisherFailures(t *testing.T) {
tests := []struct {
name string
cfg *Config
dialFunc func(url string) (*amqp091.Connection, error) // Mocked dial function
envVars map[string]string
expectedError string
}{
{
name: "ValidateFailure",
cfg: &Config{}, // invalid config
expectedError: "missing config.Addr",
},
{
name: "GetConnURLFailure",
cfg: &Config{
Addr: "localhost",
Exchange: "test-ex",
Durable: true,
RoutingKey: "test.key",
},
envVars: map[string]string{
"RABBITMQ_USERNAME": "",
"RABBITMQ_PASSWORD": "",
},
expectedError: "missing RabbitMQ credentials in environment",
},
{
name: "ConnectionFailure",
cfg: &Config{
Addr: "localhost",
Exchange: "test-ex",
Durable: true,
RoutingKey: "test.key",
},
dialFunc: func(url string) (*amqp091.Connection, error) {
return nil, fmt.Errorf("simulated connection failure")
},
envVars: map[string]string{
"RABBITMQ_USERNAME": "user",
"RABBITMQ_PASSWORD": "pass",
},
expectedError: "failed to connect to RabbitMQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set environment variables
for key, value := range tt.envVars {
t.Setenv(key, value)
}
// Mock dialFunc if needed
originalDialFunc := DialFunc
if tt.dialFunc != nil {
DialFunc = tt.dialFunc
defer func() {
DialFunc = originalDialFunc
}()
}
_, _, err := New(tt.cfg)
if err == nil || (tt.expectedError != "" && !strings.Contains(err.Error(), tt.expectedError)) {
t.Errorf("Test %s failed: expected error containing %v, got: %v", tt.name, tt.expectedError, err)
}
})
}
}