diff --git a/go.mod b/go.mod index 12fad60..29412d4 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( require ( github.com/hashicorp/go-retryablehttp v0.7.7 + github.com/rabbitmq/amqp091-go v1.10.0 github.com/rs/zerolog v1.34.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index 821e117..d6ddf7e 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsK github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= +github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= diff --git a/pkg/plugin/definition/publisher.go b/pkg/plugin/definition/publisher.go index 1e744da..55ed217 100644 --- a/pkg/plugin/definition/publisher.go +++ b/pkg/plugin/definition/publisher.go @@ -8,6 +8,7 @@ type Publisher interface { Publish(context.Context, string, []byte) error } +// PublisherProvider is the interface for creating new Publisher instances. type PublisherProvider interface { // New initializes a new publisher instance with the given configuration. New(ctx context.Context, config map[string]string) (Publisher, func() error, error) diff --git a/pkg/plugin/implementation/publisher/cmd/plugin.go b/pkg/plugin/implementation/publisher/cmd/plugin.go new file mode 100644 index 0000000..ccf87fa --- /dev/null +++ b/pkg/plugin/implementation/publisher/cmd/plugin.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher" +) + +// publisherProvider implements the PublisherProvider interface. +// It is responsible for creating a new Publisher instance. +type publisherProvider struct{} + +// New creates a new Publisher instance based on the provided configuration. +func (p *publisherProvider) New(ctx context.Context, config map[string]string) (definition.Publisher, func() error, error) { + cfg := &publisher.Config{ + Addr: config["addr"], + Exchange: config["exchange"], + RoutingKey: config["routing_key"], + Durable: config["durable"] == "true", + UseTLS: config["use_tls"] == "true", + } + log.Debugf(ctx, "Publisher config mapped: %+v", cfg) + + pub, cleanup, err := publisher.New(cfg) + if err != nil { + log.Errorf(ctx, err, "Failed to create publisher instance") + return nil, nil, err + } + + log.Infof(ctx, "Publisher instance created successfully") + return pub, cleanup, nil +} + +// Provider is the instance of publisherProvider that implements the PublisherProvider interface. +var Provider = publisherProvider{} diff --git a/pkg/plugin/implementation/publisher/cmd/plugin_test.go b/pkg/plugin/implementation/publisher/cmd/plugin_test.go new file mode 100644 index 0000000..e3f9837 --- /dev/null +++ b/pkg/plugin/implementation/publisher/cmd/plugin_test.go @@ -0,0 +1,106 @@ +package main + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher" + "github.com/rabbitmq/amqp091-go" +) + +type mockChannel struct{} + +func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + return nil +} +func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} +func (m *mockChannel) Close() error { + return nil +} + +func TestPublisherProvider_New_Success(t *testing.T) { + // Save original dialFunc and channelFunc + originalDialFunc := publisher.DialFunc + originalChannelFunc := publisher.ChannelFunc + defer func() { + publisher.DialFunc = originalDialFunc + publisher.ChannelFunc = originalChannelFunc + }() + + // Override mocks + publisher.DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, nil + } + publisher.ChannelFunc = func(conn *amqp091.Connection) (publisher.Channel, error) { + return &mockChannel{}, nil + } + + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + config := map[string]string{ + "addr": "localhost", + "exchange": "test-exchange", + "routing_key": "test.key", + "durable": "true", + "use_tls": "false", + } + + ctx := context.Background() + pub, cleanup, err := Provider.New(ctx, config) + + if err != nil { + t.Fatalf("Provider.New returned error: %v", err) + } + if pub == nil { + t.Fatal("Expected non-nil publisher") + } + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + if err := cleanup(); err != nil { + t.Errorf("Cleanup returned error: %v", err) + } +} + +func TestPublisherProvider_New_Failure(t *testing.T) { + // Save and restore dialFunc + originalDialFunc := publisher.DialFunc + defer func() { publisher.DialFunc = originalDialFunc }() + + // Simulate dial failure + publisher.DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, errors.New("dial failed") + } + + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + config := map[string]string{ + "addr": "localhost", + "exchange": "test-exchange", + "routing_key": "test.key", + "durable": "true", + } + + ctx := context.Background() + pub, cleanup, err := Provider.New(ctx, config) + + if err == nil { + t.Fatal("Expected error from Provider.New but got nil") + } + if !strings.Contains(err.Error(), "dial failed") { + t.Errorf("Expected 'dial failed' error, got: %v", err) + } + if pub != nil { + t.Errorf("Expected nil publisher, got: %v", pub) + } + if cleanup != nil { + t.Error("Expected nil cleanup, got non-nil") + } +} diff --git a/pkg/plugin/implementation/publisher/publisher.go b/pkg/plugin/implementation/publisher/publisher.go new file mode 100644 index 0000000..db3e577 --- /dev/null +++ b/pkg/plugin/implementation/publisher/publisher.go @@ -0,0 +1,196 @@ +package publisher + +import ( + "context" + "errors" + "fmt" + "net/url" + "os" + "strings" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/model" + "github.com/rabbitmq/amqp091-go" +) + +// Config holds the configuration required to establish a connection with RabbitMQ. +type Config struct { + Addr string + Exchange string + RoutingKey string + Durable bool + UseTLS bool +} + +// Channel defines the interface for publishing messages to RabbitMQ. +type Channel interface { + PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error + ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error + Close() error +} + +// Publisher manages the RabbitMQ connection and channel to publish messages. +type Publisher struct { + Conn *amqp091.Connection + Channel Channel + Config *Config +} + +// Error variables representing different failure scenarios. +var ( + ErrEmptyConfig = errors.New("empty config") + ErrAddrMissing = errors.New("missing required field 'Addr'") + ErrExchangeMissing = errors.New("missing required field 'Exchange'") + ErrCredentialMissing = errors.New("missing RabbitMQ credentials in environment") + ErrConnectionFailed = errors.New("failed to connect to RabbitMQ") + ErrChannelFailed = errors.New("failed to open channel") + ErrExchangeDeclare = errors.New("failed to declare exchange") +) + +// Validate checks whether the provided Config is valid for connecting to RabbitMQ. +func Validate(cfg *Config) error { + if cfg == nil { + return model.NewBadReqErr(fmt.Errorf("config is nil")) + } + if strings.TrimSpace(cfg.Addr) == "" { + return model.NewBadReqErr(fmt.Errorf("missing config.Addr")) + } + if strings.TrimSpace(cfg.Exchange) == "" { + return model.NewBadReqErr(fmt.Errorf("missing config.Exchange")) + } + return nil +} + +// GetConnURL constructs the RabbitMQ connection URL using the config and environment credentials. +func GetConnURL(cfg *Config) (string, error) { + user := os.Getenv("RABBITMQ_USERNAME") + pass := os.Getenv("RABBITMQ_PASSWORD") + if user == "" || pass == "" { + return "", model.NewBadReqErr(fmt.Errorf("missing RabbitMQ credentials in environment")) + } + parts := strings.SplitN(strings.TrimSpace(cfg.Addr), "/", 2) + hostPort := parts[0] + vhost := "/" + if len(parts) > 1 { + vhost = parts[1] + } + + if !strings.Contains(hostPort, ":") { + if cfg.UseTLS { + hostPort += ":5671" + } else { + hostPort += ":5672" + } + } + + encodedUser := url.QueryEscape(user) + encodedPass := url.QueryEscape(pass) + encodedVHost := url.QueryEscape(vhost) + protocol := "amqp" + if cfg.UseTLS { + protocol = "amqps" + } + + connURL := fmt.Sprintf("%s://%s:%s@%s/%s", protocol, encodedUser, encodedPass, hostPort, encodedVHost) + log.Debugf(context.Background(), "Generated RabbitMQ connection details: protocol=%s, hostPort=%s, vhost=%s", protocol, hostPort, vhost) + + return connURL, nil +} + +// Publish sends a message to the configured RabbitMQ exchange with the specified routing key. +// If routingKey is empty, the default routing key from Config is used. +func (p *Publisher) Publish(ctx context.Context, routingKey string, msg []byte) error { + if routingKey == "" { + routingKey = p.Config.RoutingKey + } + log.Debugf(ctx, "Attempting to publish message. Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + err := p.Channel.PublishWithContext( + ctx, + p.Config.Exchange, + routingKey, + false, + false, + amqp091.Publishing{ + ContentType: "application/json", + Body: msg, + }, + ) + + if err != nil { + log.Errorf(ctx, err, "Publish failed for Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + return model.NewBadReqErr(fmt.Errorf("publish message failed: %w", err)) + } + + log.Infof(ctx, "Message published successfully to Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + return nil +} + +// DialFunc is a function variable used to establish a connection to RabbitMQ. +var DialFunc = amqp091.Dial + +// ChannelFunc is a function variable used to open a channel on the given RabbitMQ connection. +var ChannelFunc = func(conn *amqp091.Connection) (Channel, error) { + return conn.Channel() +} + +// New initializes a new Publisher with the given config, opens a connection, +// channel, and declares the exchange. Returns the publisher and a cleanup function. +func New(cfg *Config) (*Publisher, func() error, error) { + // Step 1: Validate config + if err := Validate(cfg); err != nil { + return nil, nil, err + } + + // Step 2: Build connection URL + connURL, err := GetConnURL(cfg) + if err != nil { + return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err) + } + + // Step 3: Dial connection + conn, err := DialFunc(connURL) + if err != nil { + return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err) + } + + // Step 4: Open channel + ch, err := ChannelFunc(conn) + if err != nil { + conn.Close() + return nil, nil, fmt.Errorf("%w: %v", ErrChannelFailed, err) + } + + // Step 5: Declare exchange + if err := ch.ExchangeDeclare( + cfg.Exchange, + "topic", + cfg.Durable, + false, + false, + false, + nil, + ); err != nil { + ch.Close() + conn.Close() + return nil, nil, fmt.Errorf("%w: %v", ErrExchangeDeclare, err) + } + + // Step 6: Construct publisher + pub := &Publisher{ + Conn: conn, + Channel: ch, + Config: cfg, + } + + cleanup := func() error { + if ch != nil { + _ = ch.Close() + } + if conn != nil { + return conn.Close() + } + return nil + } + + return pub, cleanup, nil +} diff --git a/pkg/plugin/implementation/publisher/publisher_test.go b/pkg/plugin/implementation/publisher/publisher_test.go new file mode 100644 index 0000000..82b8404 --- /dev/null +++ b/pkg/plugin/implementation/publisher/publisher_test.go @@ -0,0 +1,362 @@ +package publisher + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/rabbitmq/amqp091-go" +) + +func TestGetConnURLSuccess(t *testing.T) { + tests := []struct { + name string + config *Config + }{ + { + name: "Valid config with connection address", + config: &Config{ + Addr: "localhost:5672", + UseTLS: false, + }, + }, + + { + name: "Valid config with vhost", + config: &Config{ + Addr: "localhost:5672/myvhost", + UseTLS: false, + }, + }, + { + name: "Addr with leading and trailing spaces", + config: &Config{ + Addr: " localhost:5672/myvhost ", + UseTLS: false, + }, + }, + } + + // Set valid credentials + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url, err := GetConnURL(tt.config) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if url == "" { + t.Error("expected non-empty URL, got empty string") + } + }) + } +} + +func TestGetConnURLFailure(t *testing.T) { + tests := []struct { + name string + username string + password string + config *Config + wantErr bool + }{ + { + name: "Missing credentials", + username: "", + password: "", + config: &Config{Addr: "localhost:5672"}, + wantErr: true, + }, + { + name: "Missing config address", + username: "guest", + password: "guest", + config: &Config{}, // this won't error unless Validate() is called separately + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.username != "" { + t.Setenv("RABBITMQ_USERNAME", tt.username) + } + + if tt.password != "" { + t.Setenv("RABBITMQ_PASSWORD", tt.password) + } + + url, err := GetConnURL(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("unexpected error. gotErr = %v, wantErr = %v", err != nil, tt.wantErr) + } + + if err == nil && url == "" { + t.Errorf("expected non-empty URL, got empty string") + } + }) + } +} + +func TestValidateSuccess(t *testing.T) { + tests := []struct { + name string + config *Config + }{ + { + name: "Valid config with Addr and Exchange", + config: &Config{ + Addr: "localhost:5672", + Exchange: "ex", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Validate(tt.config); err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + } +} + +func TestValidateFailure(t *testing.T) { + tests := []struct { + name string + config *Config + expectedErrr string + }{ + { + name: "Nil config", + config: nil, + expectedErrr: "config is nil", + }, + { + name: "Missing Addr", + config: &Config{Exchange: "ex"}, + expectedErrr: "missing config.Addr", + }, + { + name: "Missing Exchange", + config: &Config{Addr: "localhost:5672"}, + expectedErrr: "missing config.Exchange", + }, + { + name: "Empty Addr and Exchange", + config: &Config{Addr: " ", Exchange: " "}, + expectedErrr: "missing config.Addr", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Validate(tt.config) + if err == nil { + t.Errorf("expected error for invalid config, got nil") + return + } + if !strings.Contains(err.Error(), tt.expectedErrr) { + t.Errorf("expected error to contain %q, got: %v", tt.expectedErrr, err) + } + }) + } +} + +type mockChannelForPublish struct { + published bool + exchange string + key string + body []byte + fail bool +} + +func (m *mockChannelForPublish) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + if m.fail { + return fmt.Errorf("simulated publish failure") + } + m.published = true + m.exchange = exchange + m.key = key + m.body = msg.Body + return nil +} + +func (m *mockChannelForPublish) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} + +func (m *mockChannelForPublish) Close() error { + return nil +} + +func TestPublishSuccess(t *testing.T) { + mockCh := &mockChannelForPublish{} + + p := &Publisher{ + Channel: mockCh, + Config: &Config{ + Exchange: "mock.exchange", + RoutingKey: "mock.key", + }, + } + + err := p.Publish(context.Background(), "", []byte(`{"test": true}`)) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + + if !mockCh.published { + t.Error("expected message to be published, but it wasn't") + } + + if mockCh.exchange != "mock.exchange" || mockCh.key != "mock.key" { + t.Errorf("unexpected exchange or key. got (%s, %s)", mockCh.exchange, mockCh.key) + } +} + +func TestPublishFailure(t *testing.T) { + mockCh := &mockChannelForPublish{fail: true} + + p := &Publisher{ + Channel: mockCh, + Config: &Config{ + Exchange: "mock.exchange", + RoutingKey: "mock.key", + }, + } + + err := p.Publish(context.Background(), "", []byte(`{"test": true}`)) + if err == nil { + t.Error("expected error from failed publish, got nil") + } +} + +type mockChannel struct{} + +func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + return nil +} +func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} +func (m *mockChannel) Close() error { + return nil +} + +func TestNewPublisherSucess(t *testing.T) { + originalDialFunc := DialFunc + originalChannelFunc := ChannelFunc + defer func() { + DialFunc = originalDialFunc + ChannelFunc = originalChannelFunc + }() + + // mockedConn := &mockConnection{} + + DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, nil + } + + ChannelFunc = func(conn *amqp091.Connection) (Channel, error) { + return &mockChannel{}, nil + } + + cfg := &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + } + + t.Setenv("RABBITMQ_USERNAME", "user") + t.Setenv("RABBITMQ_PASSWORD", "pass") + + pub, cleanup, err := New(cfg) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + if pub == nil { + t.Fatal("Publisher should not be nil") + } + if cleanup == nil { + t.Fatal("Cleanup should not be nil") + } + if err := cleanup(); err != nil { + t.Errorf("Cleanup failed: %v", err) + } +} + +func TestNewPublisherFailures(t *testing.T) { + tests := []struct { + name string + cfg *Config + dialFunc func(url string) (*amqp091.Connection, error) // Mocked dial function + envVars map[string]string + expectedError string + }{ + { + name: "ValidateFailure", + cfg: &Config{}, // invalid config + expectedError: "missing config.Addr", + }, + { + name: "GetConnURLFailure", + cfg: &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + }, + envVars: map[string]string{ + "RABBITMQ_USERNAME": "", + "RABBITMQ_PASSWORD": "", + }, + expectedError: "missing RabbitMQ credentials in environment", + }, + { + name: "ConnectionFailure", + cfg: &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + }, + dialFunc: func(url string) (*amqp091.Connection, error) { + return nil, fmt.Errorf("simulated connection failure") + }, + envVars: map[string]string{ + "RABBITMQ_USERNAME": "user", + "RABBITMQ_PASSWORD": "pass", + }, + expectedError: "failed to connect to RabbitMQ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set environment variables + for key, value := range tt.envVars { + t.Setenv(key, value) + } + + // Mock dialFunc if needed + originalDialFunc := DialFunc + if tt.dialFunc != nil { + DialFunc = tt.dialFunc + defer func() { + DialFunc = originalDialFunc + }() + } + + _, _, err := New(tt.cfg) + + if err == nil || (tt.expectedError != "" && !strings.Contains(err.Error(), tt.expectedError)) { + t.Errorf("Test %s failed: expected error containing %v, got: %v", tt.name, tt.expectedError, err) + } + }) + } +}