From 889619fba5c532d4daba851fc079db15e2b0b7c4 Mon Sep 17 00:00:00 2001 From: "mayur.popli" Date: Fri, 28 Mar 2025 17:17:55 +0530 Subject: [PATCH] fix: resolved comments --- go.mod | 3 + pkg/model/error.go | 28 ++- pkg/model/error_test.go | 18 +- .../requestPreProcessor/cmd/plugin.go | 21 +++ .../requestPreProcessor/cmd/plugin_test.go | 85 +++++++++ .../requestPreProcessor/reqpreprocessor.go | 105 +++++++++++ .../reqpreprocessor_test.go | 178 ++++++++++++++++++ 7 files changed, 420 insertions(+), 18 deletions(-) create mode 100644 pkg/plugin/implementation/requestPreProcessor/cmd/plugin.go create mode 100644 pkg/plugin/implementation/requestPreProcessor/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/requestPreProcessor/reqpreprocessor.go create mode 100644 pkg/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go diff --git a/go.mod b/go.mod index d7ab3fd..8ae8ef8 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,9 @@ require ( github.com/zenazn/pkcs7pad v0.0.0-20170308005700-253a5b1f0e03 github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.10.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/pkg/model/error.go b/pkg/model/error.go index 5c5f450..0c49765 100644 --- a/pkg/model/error.go +++ b/pkg/model/error.go @@ -13,7 +13,7 @@ type Error struct { Message string `json:"message"` } -// Error implements the error interface for the Error struct. +// This implements the error interface for the Error struct. func (e *Error) Error() string { return fmt.Sprintf("Error: Code=%s, Path=%s, Message=%s", e.Code, e.Paths, e.Message) } @@ -23,7 +23,7 @@ type SchemaValidationErr struct { Errors []Error } -// Error implements the error interface for SchemaValidationErr. +// This implements the error interface for SchemaValidationErr. func (e *SchemaValidationErr) Error() string { var errorMessages []string for _, err := range e.Errors { @@ -57,19 +57,17 @@ func (e *SchemaValidationErr) BecknError() *Error { } } -// SignalidationErr represents a collection of schema validation failures. +// SignValidationErr represents a collection of schema validation failures. type SignValidationErr struct { error } -func NewSignValidationErrf(format string, a ...any) *SignValidationErr { - return &SignValidationErr{fmt.Errorf(format, a...)} -} - +// NewSignValidationErr creates a new instance of SignValidationErr from an error. func NewSignValidationErr(e error) *SignValidationErr { return &SignValidationErr{e} } +// BecknError converts the SignValidationErr to an instance of Error. func (e *SignValidationErr) BecknError() *Error { return &Error{ Code: http.StatusText(http.StatusUnauthorized), @@ -77,19 +75,17 @@ func (e *SignValidationErr) BecknError() *Error { } } -// SignalidationErr represents a collection of schema validation failures. +// SignValidationErr represents a collection of schema validation failures. type BadReqErr struct { error } +// NewBadReqErr creates a new instance of BadReqErr from an error. func NewBadReqErr(err error) *BadReqErr { return &BadReqErr{err} } -func NewBadReqErrf(format string, a ...any) *BadReqErr { - return &BadReqErr{fmt.Errorf(format, a...)} -} - +// BecknError converts the BadReqErr to an instance of Error. func (e *BadReqErr) BecknError() *Error { return &Error{ Code: http.StatusText(http.StatusBadRequest), @@ -97,19 +93,17 @@ func (e *BadReqErr) BecknError() *Error { } } -// SignalidationErr represents a collection of schema validation failures. +// SignValidationErr represents a collection of schema validation failures. type NotFoundErr struct { error } +// NewNotFoundErr creates a new instance of NotFoundErr from an error. func NewNotFoundErr(err error) *NotFoundErr { return &NotFoundErr{err} } -func NewNotFoundErrf(format string, a ...any) *NotFoundErr { - return &NotFoundErr{fmt.Errorf(format, a...)} -} - +// BecknError converts the NotFoundErr to an instance of Error. func (e *NotFoundErr) BecknError() *Error { return &Error{ Code: http.StatusText(http.StatusNotFound), diff --git a/pkg/model/error_test.go b/pkg/model/error_test.go index aa6ffb8..3c0727a 100644 --- a/pkg/model/error_test.go +++ b/pkg/model/error_test.go @@ -2,12 +2,28 @@ package model import ( "errors" + "fmt" "testing" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) +// NewSignValidationErrf formats an error message according to a format specifier and arguments,and returns a new instance of SignValidationErr. +func NewSignValidationErrf(format string, a ...any) *SignValidationErr { + return &SignValidationErr{fmt.Errorf(format, a...)} +} + +// NewNotFoundErrf formats an error message according to a format specifier and arguments, and returns a new instance of NotFoundErr. +func NewNotFoundErrf(format string, a ...any) *NotFoundErr { + return &NotFoundErr{fmt.Errorf(format, a...)} +} + +// NewBadReqErrf formats an error message according to a format specifier and arguments, and returns a new instance of BadReqErr. +func NewBadReqErrf(format string, a ...any) *BadReqErr { + return &BadReqErr{fmt.Errorf(format, a...)} +} + func TestError_Error(t *testing.T) { err := &Error{ Code: "404", @@ -154,7 +170,7 @@ func TestRole_UnmarshalYAML_ValidRole(t *testing.T) { yamlData := []byte("bap") err := yaml.Unmarshal(yamlData, &role) - assert.NoError(t, err) + assert.NoError(t, err) //TODO: should replace assert here assert.Equal(t, RoleBAP, role) } diff --git a/pkg/plugin/implementation/requestPreProcessor/cmd/plugin.go b/pkg/plugin/implementation/requestPreProcessor/cmd/plugin.go new file mode 100644 index 0000000..4a05ecc --- /dev/null +++ b/pkg/plugin/implementation/requestPreProcessor/cmd/plugin.go @@ -0,0 +1,21 @@ +package main + +import ( + "context" + "net/http" + "strings" + + requestpreprocessor "github.com/beckn/beckn-onix/pkg/plugin/implementation/requestPreProcessor" +) + +type provider struct{} + +func (p provider) New(ctx context.Context, c map[string]string) (func(http.Handler) http.Handler, error) { + config := &requestpreprocessor.Config{} + if contextKeysStr, ok := c["ContextKeys"]; ok { + config.ContextKeys = strings.Split(contextKeysStr, ",") + } + return requestpreprocessor.NewUUIDSetter(config) +} + +var Provider = provider{} diff --git a/pkg/plugin/implementation/requestPreProcessor/cmd/plugin_test.go b/pkg/plugin/implementation/requestPreProcessor/cmd/plugin_test.go new file mode 100644 index 0000000..0890dbc --- /dev/null +++ b/pkg/plugin/implementation/requestPreProcessor/cmd/plugin_test.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO: Will Split this into success and fail (two test cases) +func TestProviderNew(t *testing.T) { + testCases := []struct { + name string + config map[string]string + expectedError bool + expectedStatus int + prepareRequest func(req *http.Request) + }{ + { + name: "No Config", + config: map[string]string{}, + expectedError: true, + expectedStatus: http.StatusOK, + prepareRequest: func(req *http.Request) { + // Add minimal required headers. + req.Header.Set("context", "test-context") + req.Header.Set("transaction_id", "test-transaction") + }, + }, + { + name: "With Check Keys", + config: map[string]string{ + "ContextKeys": "message_id,transaction_id", + }, + expectedError: false, + expectedStatus: http.StatusOK, + prepareRequest: func(req *http.Request) { + // Add headers matching the check keys. + req.Header.Set("context", "test-context") + req.Header.Set("transaction_id", "test-transaction") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestBody := `{ + "context": { + "transaction_id": "abc" + } + }` + + p := provider{} + middleware, err := p.New(context.Background(), tc.config) + if tc.expectedError { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.NotNil(t, middleware) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("POST", "/", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + if tc.prepareRequest != nil { + tc.prepareRequest(req) + } + + w := httptest.NewRecorder() + middlewaredHandler := middleware(testHandler) + middlewaredHandler.ServeHTTP(w, req) + assert.Equal(t, tc.expectedStatus, w.Code, "Unexpected response status") + responseBody := w.Body.String() + t.Logf("Response Body: %s", responseBody) + + }) + } +} diff --git a/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor.go b/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor.go new file mode 100644 index 0000000..13d4da0 --- /dev/null +++ b/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor.go @@ -0,0 +1,105 @@ +package requestpreprocessor + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/google/uuid" +) + +type Config struct { + ContextKeys []string + Role string +} + +type becknRequest struct { + Context map[string]any `json:"context"` +} + +type contextKeyType string + +const contextKey = "context" +const subscriberIDKey contextKeyType = "subscriber_id" + +func NewUUIDSetter(cfg *Config) (func(http.Handler) http.Handler, error) { + if err := validateConfig(cfg); err != nil { + return nil, err + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req becknRequest + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "Failed to decode request body", http.StatusBadRequest) + return + } + if req.Context == nil { + http.Error(w, fmt.Sprintf("%s field not found.", contextKey), http.StatusBadRequest) + return + } + var subID any + switch cfg.Role { + case "bap": + subID = req.Context["bap_id"] + case "bpp": + subID = req.Context["bpp_id"] + } + ctx := context.WithValue(r.Context(), subscriberIDKey, subID) + for _, key := range cfg.ContextKeys { + value := uuid.NewString() + updatedValue := update(req.Context, key, value) + ctx = context.WithValue(ctx, contextKeyType(key), updatedValue) + } + reqData := map[string]any{"context": req.Context} + updatedBody, _ := json.Marshal(reqData) + r.Body = io.NopCloser(bytes.NewBuffer(updatedBody)) + r.ContentLength = int64(len(updatedBody)) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + }, nil +} + +func update(wrapper map[string]any, key string, value any) any { + field, exists := wrapper[key] + if !exists || isEmpty(field) { + wrapper[key] = value + return value + } + + return field +} +func isEmpty(v any) bool { + switch v := v.(type) { + case string: + return v == "" + case nil: + return true + default: + return false + } +} + +func validateConfig(cfg *Config) error { + if cfg == nil { + return errors.New("config cannot be nil") + } + + // Check if ContextKeys is empty. + if len(cfg.ContextKeys) == 0 { + return errors.New("ContextKeys cannot be empty") + } + + // Validate that ContextKeys does not contain empty strings. + for _, key := range cfg.ContextKeys { + if key == "" { + return errors.New("ContextKeys cannot contain empty strings") + } + } + return nil +} diff --git a/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go b/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go new file mode 100644 index 0000000..307a7e7 --- /dev/null +++ b/pkg/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go @@ -0,0 +1,178 @@ +package requestpreprocessor + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewUUIDSetterSuccessCases(t *testing.T) { + tests := []struct { + name string + config *Config + requestBody map[string]any + expectedKeys []string + role string + }{ + { + name: "Valid keys, update missing keys with bap role", + config: &Config{ + ContextKeys: []string{"transaction_id", "message_id"}, + Role: "bap", + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "", + "message_id": nil, + "bap_id": "bap-123", + }, + }, + expectedKeys: []string{"transaction_id", "message_id", "bap_id"}, + role: "bap", + }, + { + name: "Valid keys, do not update existing keys with bpp role", + config: &Config{ + ContextKeys: []string{"transaction_id", "message_id"}, + Role: "bpp", + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "existing-transaction", + "message_id": "existing-message", + "bpp_id": "bpp-456", + }, + }, + expectedKeys: []string{"transaction_id", "message_id", "bpp_id"}, + role: "bpp", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewUUIDSetter(tt.config) + if err != nil { + t.Fatalf("Unexpected error while creating middleware: %v", err) + } + + bodyBytes, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + w.WriteHeader(http.StatusOK) + + subID, ok := ctx.Value(subscriberIDKey).(string) + if !ok { + http.Error(w, "Subscriber ID not found", http.StatusInternalServerError) + return + } + + response := map[string]any{"subscriber_id": subID} + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + }) + + middleware(dummyHandler).ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Expected status code 200, but got %d", rec.Code) + return + } + + var responseBody map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &responseBody); err != nil { + t.Fatal("Failed to unmarshal response body:", err) + } + + expectedSubIDKey := "bap_id" + if tt.role == "bpp" { + expectedSubIDKey = "bpp_id" + } + + subID, ok := responseBody["subscriber_id"].(string) + if !ok { + t.Error("subscriber_id not found in response") + return + } + + expectedSubID := tt.requestBody["context"].(map[string]any)[expectedSubIDKey] + if subID != expectedSubID { + t.Errorf("Expected subscriber_id %v, but got %v", expectedSubID, subID) + } + }) + } +} + +func TestNewUUIDSetterErrorCases(t *testing.T) { + tests := []struct { + name string + config *Config + requestBody map[string]any + expectedCode int + }{ + { + name: "Missing context key", + config: &Config{ + ContextKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "otherKey": "value", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Invalid context type", + config: &Config{ + ContextKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "context": "not-a-map", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Nil config", + config: nil, + requestBody: map[string]any{}, + expectedCode: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewUUIDSetter(tt.config) + if tt.config == nil { + if err == nil { + t.Error("Expected an error for nil config, but got none") + } + return + } + if err != nil { + t.Fatalf("Unexpected error while creating middleware: %v", err) + } + + bodyBytes, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware(dummyHandler).ServeHTTP(rec, req) + + if rec.Code != tt.expectedCode { + t.Errorf("Expected status code %d, but got %d", tt.expectedCode, rec.Code) + } + }) + } +}