From f0e39e34e75842c5628a15704803b85cc0964b53 Mon Sep 17 00:00:00 2001 From: MohitKatare-protean Date: Sun, 30 Mar 2025 19:13:02 +0530 Subject: [PATCH] updated as per the review comments --- Dockerfile.adapter | 25 ++ cmd/adapter/main_test.go | 15 +- core/module/handler/config.go | 2 +- core/module/handler/stdHandler.go | 8 +- core/module/handler/step.go | 6 +- core/module/module_test.go | 2 +- pkg/model/model.go | 8 + pkg/plugin/definition/keymanager.go | 16 +- pkg/plugin/definition/publisher.go | 6 +- pkg/plugin/definition/registry.go | 11 + pkg/plugin/definition/router.go | 11 +- pkg/plugin/definition/signer.go | 7 - pkg/plugin/definition/signvalidator.go | 15 + .../implementation/decrypter/cmd/plugin.go | 8 +- .../decrypter/cmd/plugin_test.go | 2 +- .../implementation/encrypter/cmd/plugin.go | 8 +- .../encrypter/cmd/plugin_test.go | 2 +- .../reqpreprocessor/cmd/plugin.go | 24 ++ .../reqpreprocessor/cmd/plugin_test.go | 85 +++++ .../reqpreprocessor/reqpreprocessor.go | 108 ++++++ .../reqpreprocessor/reqpreprocessor_test.go | 178 +++++++++ .../implementation/router/cmd/plugin.go | 4 +- .../schemavalidator/cmd/plugin.go | 33 ++ .../schemavalidator/cmd/plugin_test.go | 150 ++++++++ .../schemavalidator/schemavalidator.go | 197 ++++++++++ .../schemavalidator/schemavalidator_test.go | 353 ++++++++++++++++++ .../implementation/signer/cmd/plugin.go | 2 +- pkg/plugin/implementation/signer/signer.go | 7 +- .../signvalidator/cmd/plugin.go | 24 ++ .../signvalidator/cmd/plugin_test.go | 89 +++++ .../signvalidator/signvalidator.go | 115 ++++++ .../signvalidator/signvalidator_test.go | 147 ++++++++ pkg/plugin/manager.go | 7 +- pkg/response/response.go | 39 -- pkg/response/response_test.go | 15 - 35 files changed, 1605 insertions(+), 124 deletions(-) create mode 100644 Dockerfile.adapter create mode 100644 pkg/plugin/definition/registry.go create mode 100644 pkg/plugin/definition/signvalidator.go create mode 100644 pkg/plugin/implementation/reqpreprocessor/cmd/plugin.go create mode 100644 pkg/plugin/implementation/reqpreprocessor/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go create mode 100644 pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go create mode 100644 pkg/plugin/implementation/schemavalidator/cmd/plugin.go create mode 100644 pkg/plugin/implementation/schemavalidator/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/schemavalidator/schemavalidator.go create mode 100644 pkg/plugin/implementation/schemavalidator/schemavalidator_test.go create mode 100644 pkg/plugin/implementation/signvalidator/cmd/plugin.go create mode 100644 pkg/plugin/implementation/signvalidator/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/signvalidator/signvalidator.go create mode 100644 pkg/plugin/implementation/signvalidator/signvalidator_test.go diff --git a/Dockerfile.adapter b/Dockerfile.adapter new file mode 100644 index 0000000..a8eb006 --- /dev/null +++ b/Dockerfile.adapter @@ -0,0 +1,25 @@ +FROM golang:1.24-bullseye AS builder + +WORKDIR /app +COPY cmd/adapter ./cmd/adapter +COPY core/ ./core +COPY pkg/ ./pkg +COPY go.mod . +COPY go.sum . +RUN go mod download + +RUN go build -o server cmd/adapter/main.go + +# Create a minimal runtime image +FROM cgr.dev/chainguard/wolfi-base +# ✅ Alpine is removed; using minimal Debian +WORKDIR /app + +# Copy only the built binary and plugin +COPY --from=builder /app/server . + +# Expose port 8080 +EXPOSE 8080 + +# Run the Go server with the config flag from environment variable. +CMD ["sh", "-c", "./server --config=${CONFIG_FILE}"] \ No newline at end of file diff --git a/cmd/adapter/main_test.go b/cmd/adapter/main_test.go index 1a3182d..b1b15d1 100644 --- a/cmd/adapter/main_test.go +++ b/cmd/adapter/main_test.go @@ -28,7 +28,7 @@ func (m *MockPluginManager) Middleware(ctx context.Context, cfg *plugin.Config) } // SignValidator returns a mock implementation of the Verifier interface. -func (m *MockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) { +func (m *MockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error) { return nil, nil } @@ -200,7 +200,7 @@ func TestRunFailure(t *testing.T) { tests := []struct { name string configData string - mockMgr func() (*plugin.Manager, func(), error) + mockMgr func() (*MockPluginManager, func(), error) mockLogger func(cfg *Config) error mockServer func(ctx context.Context, mgr handler.PluginManager, cfg *Config) (http.Handler, error) expectedErr string @@ -208,8 +208,8 @@ func TestRunFailure(t *testing.T) { { name: "Invalid Config File", configData: "invalid_config.yaml", - mockMgr: func() (*plugin.Manager, func(), error) { - return &plugin.Manager{}, func() {}, nil + mockMgr: func() (*MockPluginManager, func(), error) { + return &MockPluginManager{}, func() {}, nil }, mockLogger: func(cfg *Config) error { return nil @@ -236,9 +236,10 @@ func TestRunFailure(t *testing.T) { // Mock dependencies originalNewManager := newManagerFunc - newManagerFunc = func(ctx context.Context, cfg *plugin.ManagerConfig) (*plugin.Manager, func(), error) { - return tt.mockMgr() - } + // newManagerFunc = func(ctx context.Context, cfg *plugin.ManagerConfig) (*plugin.Manager, func(), error) { + // return tt.mockMgr() + // } + newManagerFunc = nil defer func() { newManagerFunc = originalNewManager }() originalNewServer := newServerFunc diff --git a/core/module/handler/config.go b/core/module/handler/config.go index fa2a966..16a2c0c 100644 --- a/core/module/handler/config.go +++ b/core/module/handler/config.go @@ -12,7 +12,7 @@ import ( // PluginManager defines an interface for managing plugins dynamically. type PluginManager interface { Middleware(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) - SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) + SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error) Validator(ctx context.Context, cfg *plugin.Config) (definition.SchemaValidator, error) Router(ctx context.Context, cfg *plugin.Config) (definition.Router, error) Publisher(ctx context.Context, cfg *plugin.Config) (definition.Publisher, error) diff --git a/core/module/handler/stdHandler.go b/core/module/handler/stdHandler.go index 32222cd..1a39d51 100644 --- a/core/module/handler/stdHandler.go +++ b/core/module/handler/stdHandler.go @@ -21,7 +21,7 @@ import ( type stdHandler struct { signer definition.Signer steps []definition.Step - signValidator definition.Verifier + signValidator definition.SignValidator cache definition.Cache km definition.KeyManager schemaValidator definition.SchemaValidator @@ -108,13 +108,15 @@ func (h *stdHandler) subID(ctx context.Context) string { return h.SubscriberID } +var proxyFunc = proxy + // route handles request forwarding or message publishing based on the routing type. func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb definition.Publisher) { log.Debugf(ctx, "Routing to ctx.Route to %#v", ctx.Route) switch ctx.Route.TargetType { case "url": log.Infof(ctx.Context, "Forwarding request to URL: %s", ctx.Route.URL) - proxy(r, w, ctx.Route.URL) + proxyFunc(r, w, ctx.Route.URL) return case "publisher": if pb == nil { @@ -124,7 +126,7 @@ func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb de return } log.Infof(ctx.Context, "Publishing message to: %s", ctx.Route.PublisherID) - if err := pb.Publish(ctx, ctx.Body); err != nil { + if err := pb.Publish(ctx, ctx.Route.PublisherID, ctx.Body); err != nil { log.Errorf(ctx.Context, err, "Failed to publish message") http.Error(w, "Error publishing message", http.StatusInternalServerError) response.SendNack(ctx, w, err) diff --git a/core/module/handler/step.go b/core/module/handler/step.go index f7f8abc..9074843 100644 --- a/core/module/handler/step.go +++ b/core/module/handler/step.go @@ -52,12 +52,12 @@ func (s *signStep) Run(ctx *model.StepContext) error { // validateSignStep represents the signature validation step. type validateSignStep struct { - validator definition.Verifier + validator definition.SignValidator km definition.KeyManager } // newValidateSignStep initializes and returns a new validate sign step. -func newValidateSignStep(signValidator definition.Verifier, km definition.KeyManager) (definition.Step, error) { +func newValidateSignStep(signValidator definition.SignValidator, km definition.KeyManager) (definition.Step, error) { if signValidator == nil { return nil, fmt.Errorf("invalid config: SignValidator plugin not configured") } @@ -102,7 +102,7 @@ func (s *validateSignStep) validate(ctx *model.StepContext, value string) error if err != nil { return fmt.Errorf("failed to get validation key: %w", err) } - if _, err := s.validator.Verify(ctx, ctx.Body, []byte(value), key); err != nil { + if err := s.validator.Validate(ctx, ctx.Body, value, key); err != nil { return fmt.Errorf("sign validation failed: %w", err) } return nil diff --git a/core/module/module_test.go b/core/module/module_test.go index 56901f8..a4e1106 100644 --- a/core/module/module_test.go +++ b/core/module/module_test.go @@ -23,7 +23,7 @@ func (m *mockPluginManager) Middleware(ctx context.Context, cfg *plugin.Config) } // SignValidator returns a mock verifier implementation. -func (m *mockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) { +func (m *mockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error) { return nil, nil } diff --git a/pkg/model/model.go b/pkg/model/model.go index 2038d94..ec5e29c 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -86,6 +86,14 @@ type Route struct { URL *url.URL // For API calls } +type Keyset struct { + UniqueKeyID string + SigningPrivate string + SigningPublic string + EncrPrivate string + EncrPublic string +} + // StepContext holds context information for a request processing step. type StepContext struct { context.Context diff --git a/pkg/plugin/definition/keymanager.go b/pkg/plugin/definition/keymanager.go index 8d037a4..f2c0e2f 100644 --- a/pkg/plugin/definition/keymanager.go +++ b/pkg/plugin/definition/keymanager.go @@ -6,18 +6,10 @@ import ( "github.com/beckn/beckn-onix/pkg/model" ) -type Keyset struct { - UniqueKeyID string - SigningPrivate string - SigningPublic string - EncrPrivate string - EncrPublic string -} - // KeyManager defines the interface for key management operations/methods. type KeyManager interface { - GenerateKeyPairs() (*Keyset, error) - StorePrivateKeys(ctx context.Context, keyID string, keys *Keyset) error + GenerateKeyPairs() (*model.Keyset, error) + StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) @@ -29,7 +21,3 @@ type KeyManager interface { type KeyManagerProvider interface { New(context.Context, Cache, RegistryLookup, map[string]string) (KeyManager, func() error, error) } - -type RegistryLookup interface { - Lookup(ctx context.Context, req *model.Subscription) ([]model.Subscription, error) -} diff --git a/pkg/plugin/definition/publisher.go b/pkg/plugin/definition/publisher.go index 93f9e21..4eba687 100644 --- a/pkg/plugin/definition/publisher.go +++ b/pkg/plugin/definition/publisher.go @@ -5,12 +5,10 @@ import "context" // Publisher defines the general publisher interface for messaging plugins. type Publisher interface { // Publish sends a message (as a byte slice) using the underlying messaging system. - Publish(ctx context.Context, msg []byte) error - - Close() error // Important for releasing resources. + Publish(context.Context, string, []byte) error } type PublisherProvider interface { // New initializes a new publisher instance with the given configuration. - New(ctx context.Context, config map[string]string) (Publisher, error) + New(ctx context.Context, config map[string]string) (Publisher, func(), error) } diff --git a/pkg/plugin/definition/registry.go b/pkg/plugin/definition/registry.go new file mode 100644 index 0000000..22881f3 --- /dev/null +++ b/pkg/plugin/definition/registry.go @@ -0,0 +1,11 @@ +package definition + +import ( + "context" + + "github.com/beckn/beckn-onix/pkg/model" +) + +type RegistryLookup interface { + Lookup(ctx context.Context, req *model.Subscription) ([]model.Subscription, error) +} diff --git a/pkg/plugin/definition/router.go b/pkg/plugin/definition/router.go index 05e2e30..f30a1ca 100644 --- a/pkg/plugin/definition/router.go +++ b/pkg/plugin/definition/router.go @@ -3,14 +3,9 @@ package definition import ( "context" "net/url" -) -// Route defines the structure for the Route returned. -type Route struct { - TargetType string // "url" or "msgq" or "bap" or "bpp" - PublisherID string // For message queues - URL *url.URL // For API calls -} + "github.com/beckn/beckn-onix/pkg/model" +) // RouterProvider initializes the a new Router instance with the given config. type RouterProvider interface { @@ -20,5 +15,5 @@ type RouterProvider interface { // Router defines the interface for routing requests. type Router interface { // Route determines the routing destination based on the request context. - Route(ctx context.Context, url *url.URL, body []byte) (*Route, error) + Route(ctx context.Context, url *url.URL, body []byte) (*model.Route, error) } diff --git a/pkg/plugin/definition/signer.go b/pkg/plugin/definition/signer.go index 84db5f5..eff7bae 100644 --- a/pkg/plugin/definition/signer.go +++ b/pkg/plugin/definition/signer.go @@ -8,7 +8,6 @@ type Signer interface { // The signature is created with the given timestamps: createdAt (signature creation time) // and expiresAt (signature expiration time). Sign(ctx context.Context, body []byte, privateKeyBase64 string, createdAt, expiresAt int64) (string, error) - Close() error // Close for releasing resources } // SignerProvider initializes a new signer instance with the given config. @@ -16,9 +15,3 @@ type SignerProvider interface { // New creates a new signer instance based on the provided config. New(ctx context.Context, config map[string]string) (Signer, func() error, error) } - -// PrivateKeyManager is the interface for key management plugin. -type PrivateKeyManager interface { - // PrivateKey retrieves the private key for the given subscriberID and keyID. - PrivateKey(ctx context.Context, subscriberID string, keyID string) (string, error) -} diff --git a/pkg/plugin/definition/signvalidator.go b/pkg/plugin/definition/signvalidator.go new file mode 100644 index 0000000..e900a37 --- /dev/null +++ b/pkg/plugin/definition/signvalidator.go @@ -0,0 +1,15 @@ +package definition + +import "context" + +// SignValidator defines the method for verifying signatures. +type SignValidator interface { + // Validate checks the validity of the signature for the given body. + Validate(ctx context.Context, body []byte, header string, publicKeyBase64 string) error +} + +// SignValidatorProvider initializes a new Verifier instance with the given config. +type SignValidatorProvider interface { + // New creates a new Verifier instance based on the provided config. + New(ctx context.Context, config map[string]string) (SignValidator, func() error, error) +} diff --git a/pkg/plugin/implementation/decrypter/cmd/plugin.go b/pkg/plugin/implementation/decrypter/cmd/plugin.go index cb988a9..628e2cb 100644 --- a/pkg/plugin/implementation/decrypter/cmd/plugin.go +++ b/pkg/plugin/implementation/decrypter/cmd/plugin.go @@ -7,13 +7,13 @@ import ( decrypter "github.com/beckn/beckn-onix/pkg/plugin/implementation/decrypter" ) -// DecrypterProvider implements the definition.DecrypterProvider interface. -type DecrypterProvider struct{} +// decrypterProvider implements the definition.decrypterProvider interface. +type decrypterProvider struct{} // New creates a new Decrypter instance using the provided configuration. -func (dp DecrypterProvider) New(ctx context.Context, config map[string]string) (definition.Decrypter, func() error, error) { +func (dp decrypterProvider) New(ctx context.Context, config map[string]string) (definition.Decrypter, func() error, error) { return decrypter.New(ctx) } // Provider is the exported symbol that the plugin manager will look for. -var Provider definition.DecrypterProvider = DecrypterProvider{} +var Provider = decrypterProvider{} diff --git a/pkg/plugin/implementation/decrypter/cmd/plugin_test.go b/pkg/plugin/implementation/decrypter/cmd/plugin_test.go index 6a4f168..0e8a079 100644 --- a/pkg/plugin/implementation/decrypter/cmd/plugin_test.go +++ b/pkg/plugin/implementation/decrypter/cmd/plugin_test.go @@ -25,7 +25,7 @@ func TestDecrypterProviderSuccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - provider := DecrypterProvider{} + provider := decrypterProvider{} decrypter, cleanup, err := provider.New(tt.ctx, tt.config) // Check error. diff --git a/pkg/plugin/implementation/encrypter/cmd/plugin.go b/pkg/plugin/implementation/encrypter/cmd/plugin.go index aad52ef..31e0044 100644 --- a/pkg/plugin/implementation/encrypter/cmd/plugin.go +++ b/pkg/plugin/implementation/encrypter/cmd/plugin.go @@ -7,12 +7,12 @@ import ( "github.com/beckn/beckn-onix/pkg/plugin/implementation/encrypter" ) -// EncrypterProvider implements the definition.EncrypterProvider interface. -type EncrypterProvider struct{} +// encrypterProvider implements the definition.encrypterProvider interface. +type encrypterProvider struct{} -func (ep EncrypterProvider) New(ctx context.Context, config map[string]string) (definition.Encrypter, func() error, error) { +func (ep encrypterProvider) New(ctx context.Context, config map[string]string) (definition.Encrypter, func() error, error) { return encrypter.New(ctx) } // Provider is the exported symbol that the plugin manager will look for. -var Provider definition.EncrypterProvider = EncrypterProvider{} +var Provider = encrypterProvider{} diff --git a/pkg/plugin/implementation/encrypter/cmd/plugin_test.go b/pkg/plugin/implementation/encrypter/cmd/plugin_test.go index cbb469e..1f65450 100644 --- a/pkg/plugin/implementation/encrypter/cmd/plugin_test.go +++ b/pkg/plugin/implementation/encrypter/cmd/plugin_test.go @@ -28,7 +28,7 @@ func TestEncrypterProviderSuccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Create provider and encrypter. - provider := EncrypterProvider{} + provider := encrypterProvider{} encrypter, cleanup, err := provider.New(tt.ctx, tt.config) if err != nil { t.Fatalf("EncrypterProvider.New() error = %v", err) diff --git a/pkg/plugin/implementation/reqpreprocessor/cmd/plugin.go b/pkg/plugin/implementation/reqpreprocessor/cmd/plugin.go new file mode 100644 index 0000000..b89b650 --- /dev/null +++ b/pkg/plugin/implementation/reqpreprocessor/cmd/plugin.go @@ -0,0 +1,24 @@ +package main + +import ( + "context" + "net/http" + "strings" + + "github.com/beckn/beckn-onix/pkg/plugin/implementation/reqpreprocessor" +) + +type provider struct{} + +func (p provider) New(ctx context.Context, c map[string]string) (func(http.Handler) http.Handler, error) { + config := &reqpreprocessor.Config{} + if contextKeysStr, ok := c["contextKeys"]; ok { + config.ContextKeys = strings.Split(contextKeysStr, ",") + } + if role, ok := c["role"]; ok { + config.Role = role + } + return reqpreprocessor.NewPreProcessor(config) +} + +var Provider = provider{} diff --git a/pkg/plugin/implementation/reqpreprocessor/cmd/plugin_test.go b/pkg/plugin/implementation/reqpreprocessor/cmd/plugin_test.go new file mode 100644 index 0000000..6044c44 --- /dev/null +++ b/pkg/plugin/implementation/reqpreprocessor/cmd/plugin_test.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO: Will Split this into success and fail (two test cases) +func TestProviderNew(t *testing.T) { + testCases := []struct { + name string + config map[string]string + expectedError bool + expectedStatus int + prepareRequest func(req *http.Request) + }{ + { + name: "No Config", + config: map[string]string{}, + expectedError: true, + expectedStatus: http.StatusOK, + prepareRequest: func(req *http.Request) { + // Add minimal required headers. + req.Header.Set("context", "test-context") + req.Header.Set("transaction_id", "test-transaction") + }, + }, + { + name: "With Check Keys", + config: map[string]string{ + "contextKeys": "message_id,transaction_id", + }, + expectedError: false, + expectedStatus: http.StatusOK, + prepareRequest: func(req *http.Request) { + // Add headers matching the check keys. + req.Header.Set("context", "test-context") + req.Header.Set("transaction_id", "test-transaction") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestBody := `{ + "context": { + "transaction_id": "abc" + } + }` + + p := provider{} + middleware, err := p.New(context.Background(), tc.config) + if tc.expectedError { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.NotNil(t, middleware) + + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("POST", "/", strings.NewReader(requestBody)) + req.Header.Set("Content-Type", "application/json") + if tc.prepareRequest != nil { + tc.prepareRequest(req) + } + + w := httptest.NewRecorder() + middlewaredHandler := middleware(testHandler) + middlewaredHandler.ServeHTTP(w, req) + assert.Equal(t, tc.expectedStatus, w.Code, "Unexpected response status") + responseBody := w.Body.String() + t.Logf("Response Body: %s", responseBody) + + }) + } +} diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go new file mode 100644 index 0000000..020df4d --- /dev/null +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go @@ -0,0 +1,108 @@ +package reqpreprocessor + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/google/uuid" +) + +type Config struct { + ContextKeys []string + Role string +} + +type becknRequest struct { + Context map[string]any `json:"context"` +} + +const contextKey = "context" +const subscriberIDKey = "subscriber_id" + +func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { + if err := validateConfig(cfg); err != nil { + return nil, err + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req becknRequest + ctx := r.Context() + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "Failed to decode request body", http.StatusBadRequest) + return + } + if req.Context == nil { + http.Error(w, fmt.Sprintf("%s field not found.", contextKey), http.StatusBadRequest) + return + } + var subID any + switch cfg.Role { + case "bap": + subID = req.Context["bap_id"] + case "bpp": + subID = req.Context["bpp_id"] + } + if subID != nil { + log.Debugf(ctx, "adding subscriberId to request:%s, %v", subscriberIDKey, subID) + ctx = context.WithValue(ctx, subscriberIDKey, subID) + } + for _, key := range cfg.ContextKeys { + value := uuid.NewString() + updatedValue := update(req.Context, key, value) + ctx = context.WithValue(ctx, key, updatedValue) + } + reqData := map[string]any{"context": req.Context} + updatedBody, _ := json.Marshal(reqData) + r.Body = io.NopCloser(bytes.NewBuffer(updatedBody)) + r.ContentLength = int64(len(updatedBody)) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + }, nil +} + +func update(wrapper map[string]any, key string, value any) any { + field, exists := wrapper[key] + if !exists || isEmpty(field) { + wrapper[key] = value + return value + } + + return field +} +func isEmpty(v any) bool { + switch v := v.(type) { + case string: + return v == "" + case nil: + return true + default: + return false + } +} + +func validateConfig(cfg *Config) error { + if cfg == nil { + return errors.New("config cannot be nil") + } + + // Check if ContextKeys is empty. + if len(cfg.ContextKeys) == 0 { + return errors.New("ContextKeys cannot be empty") + } + + // Validate that ContextKeys does not contain empty strings. + for _, key := range cfg.ContextKeys { + if key == "" { + return errors.New("ContextKeys cannot contain empty strings") + } + } + return nil +} diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go new file mode 100644 index 0000000..d70af8e --- /dev/null +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go @@ -0,0 +1,178 @@ +package reqpreprocessor + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewUUIDSetterSuccessCases(t *testing.T) { + tests := []struct { + name string + config *Config + requestBody map[string]any + expectedKeys []string + role string + }{ + { + name: "Valid keys, update missing keys with bap role", + config: &Config{ + ContextKeys: []string{"transaction_id", "message_id"}, + Role: "bap", + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "", + "message_id": nil, + "bap_id": "bap-123", + }, + }, + expectedKeys: []string{"transaction_id", "message_id", "bap_id"}, + role: "bap", + }, + { + name: "Valid keys, do not update existing keys with bpp role", + config: &Config{ + ContextKeys: []string{"transaction_id", "message_id"}, + Role: "bpp", + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "existing-transaction", + "message_id": "existing-message", + "bpp_id": "bpp-456", + }, + }, + expectedKeys: []string{"transaction_id", "message_id", "bpp_id"}, + role: "bpp", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewPreProcessor(tt.config) + if err != nil { + t.Fatalf("Unexpected error while creating middleware: %v", err) + } + + bodyBytes, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + w.WriteHeader(http.StatusOK) + + subID, ok := ctx.Value(subscriberIDKey).(string) + if !ok { + http.Error(w, "Subscriber ID not found", http.StatusInternalServerError) + return + } + + response := map[string]any{"subscriber_id": subID} + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + }) + + middleware(dummyHandler).ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Expected status code 200, but got %d", rec.Code) + return + } + + var responseBody map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &responseBody); err != nil { + t.Fatal("Failed to unmarshal response body:", err) + } + + expectedSubIDKey := "bap_id" + if tt.role == "bpp" { + expectedSubIDKey = "bpp_id" + } + + subID, ok := responseBody["subscriber_id"].(string) + if !ok { + t.Error("subscriber_id not found in response") + return + } + + expectedSubID := tt.requestBody["context"].(map[string]any)[expectedSubIDKey] + if subID != expectedSubID { + t.Errorf("Expected subscriber_id %v, but got %v", expectedSubID, subID) + } + }) + } +} + +func TestNewUUIDSetterErrorCases(t *testing.T) { + tests := []struct { + name string + config *Config + requestBody map[string]any + expectedCode int + }{ + { + name: "Missing context key", + config: &Config{ + ContextKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "otherKey": "value", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Invalid context type", + config: &Config{ + ContextKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "context": "not-a-map", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Nil config", + config: nil, + requestBody: map[string]any{}, + expectedCode: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewPreProcessor(tt.config) + if tt.config == nil { + if err == nil { + t.Error("Expected an error for nil config, but got none") + } + return + } + if err != nil { + t.Fatalf("Unexpected error while creating middleware: %v", err) + } + + bodyBytes, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + middleware(dummyHandler).ServeHTTP(rec, req) + + if rec.Code != tt.expectedCode { + t.Errorf("Expected status code %d, but got %d", tt.expectedCode, rec.Code) + } + }) + } +} diff --git a/pkg/plugin/implementation/router/cmd/plugin.go b/pkg/plugin/implementation/router/cmd/plugin.go index 556f129..d5d71e3 100644 --- a/pkg/plugin/implementation/router/cmd/plugin.go +++ b/pkg/plugin/implementation/router/cmd/plugin.go @@ -4,8 +4,8 @@ import ( "context" "errors" - definition "github.com/beckn/beckn-onix/pkg/plugin/definition" - router "github.com/beckn/beckn-onix/pkg/plugin/implementation/router" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/router" ) // RouterProvider provides instances of Router. diff --git a/pkg/plugin/implementation/schemavalidator/cmd/plugin.go b/pkg/plugin/implementation/schemavalidator/cmd/plugin.go new file mode 100644 index 0000000..f71aaaf --- /dev/null +++ b/pkg/plugin/implementation/schemavalidator/cmd/plugin.go @@ -0,0 +1,33 @@ +package main + +import ( + "context" + "errors" + + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/schemavalidator" +) + +// schemaValidatorProvider provides instances of schemaValidator. +type schemaValidatorProvider struct{} + +// New initializes a new Verifier instance. +func (vp schemaValidatorProvider) New(ctx context.Context, config map[string]string) (definition.SchemaValidator, func() error, error) { + if ctx == nil { + return nil, nil, errors.New("context cannot be nil") + } + + // Extract schemaDir from the config map + schemaDir, ok := config["schemaDir"] + if !ok || schemaDir == "" { + return nil, nil, errors.New("config must contain 'schemaDir'") + } + + // Create a new schemaValidator instance with the provided configuration + return schemavalidator.New(ctx, &schemavalidator.Config{ + SchemaDir: schemaDir, + }) +} + +// Provider is the exported symbol that the plugin manager will look for. +var Provider = schemaValidatorProvider{} diff --git a/pkg/plugin/implementation/schemavalidator/cmd/plugin_test.go b/pkg/plugin/implementation/schemavalidator/cmd/plugin_test.go new file mode 100644 index 0000000..75fdce0 --- /dev/null +++ b/pkg/plugin/implementation/schemavalidator/cmd/plugin_test.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// setupTestSchema creates a temporary directory and writes a sample schema file. +func setupTestSchema(t *testing.T) string { + t.Helper() + + // Create a temporary directory for the schema + schemaDir, err := os.MkdirTemp("", "schemas") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + + // Create the directory structure for the schema file + schemaFilePath := filepath.Join(schemaDir, "example", "1.0", "test_schema.json") + if err := os.MkdirAll(filepath.Dir(schemaFilePath), 0755); err != nil { + t.Fatalf("Failed to create schema directory structure: %v", err) + } + + // Define a sample schema + schemaContent := `{ + "type": "object", + "properties": { + "context": { + "type": "object", + "properties": { + "domain": {"type": "string"}, + "version": {"type": "string"} + }, + "required": ["domain", "version"] + } + }, + "required": ["context"] + }` + + // Write the schema to the file + if err := os.WriteFile(schemaFilePath, []byte(schemaContent), 0644); err != nil { + t.Fatalf("Failed to write schema file: %v", err) + } + + return schemaDir +} + +// TestValidatorProviderSuccess tests successful ValidatorProvider implementation. +func TestValidatorProviderSuccess(t *testing.T) { + schemaDir := setupTestSchema(t) + defer os.RemoveAll(schemaDir) + + // Define test cases. + tests := []struct { + name string + ctx context.Context + config map[string]string + expectedError string + }{ + { + name: "Valid schema directory", + ctx: context.Background(), // Valid context + config: map[string]string{"schemaDir": schemaDir}, + expectedError: "", + }, + } + + // Test using table-driven tests + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vp := schemaValidatorProvider{} + schemaValidator, _, err := vp.New(tt.ctx, tt.config) + + // Ensure no error occurred + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Ensure the schemaValidator is not nil + if schemaValidator == nil { + t.Error("expected a non-nil schemaValidator, got nil") + } + }) + } +} + +// TestValidatorProviderSuccess tests cases where ValidatorProvider creation should fail. +func TestValidatorProviderFailure(t *testing.T) { + schemaDir := setupTestSchema(t) + defer os.RemoveAll(schemaDir) + + // Define test cases. + tests := []struct { + name string + ctx context.Context + config map[string]string + expectedError string + }{ + { + name: "Config is empty", + ctx: context.Background(), + config: map[string]string{}, + expectedError: "config must contain 'schemaDir'", + }, + { + name: "schemaDir is empty", + ctx: context.Background(), + config: map[string]string{"schemaDir": ""}, + expectedError: "config must contain 'schemaDir'", + }, + { + name: "Invalid schema directory", + ctx: context.Background(), // Valid context + config: map[string]string{"schemaDir": "/invalid/dir"}, + expectedError: "failed to initialise schemaValidator: schema directory does not exist: /invalid/dir", + }, + { + name: "Nil context", + ctx: nil, // Nil context + config: map[string]string{"schemaDir": schemaDir}, + expectedError: "context cannot be nil", + }, + } + + // Test using table-driven tests + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vp := schemaValidatorProvider{} + _, _, err := vp.New(tt.ctx, tt.config) + + // Check for expected error + if tt.expectedError != "" { + if err == nil || !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error %q, got %v", tt.expectedError, err) + } + return + } + + // Ensure no error occurred + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + }) + } +} diff --git a/pkg/plugin/implementation/schemavalidator/schemavalidator.go b/pkg/plugin/implementation/schemavalidator/schemavalidator.go new file mode 100644 index 0000000..715def7 --- /dev/null +++ b/pkg/plugin/implementation/schemavalidator/schemavalidator.go @@ -0,0 +1,197 @@ +package schemavalidator + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "os" + "path" + "path/filepath" + "strings" + + "github.com/beckn/beckn-onix/pkg/model" + + "github.com/santhosh-tekuri/jsonschema/v6" +) + +// Payload represents the structure of the data payload with context information. +type payload struct { + Context struct { + Domain string `json:"domain"` + Version string `json:"version"` + } `json:"context"` +} + +// schemaValidator implements the Validator interface. +type schemaValidator struct { + config *Config + schemaCache map[string]*jsonschema.Schema +} + +// Config struct for SchemaValidator. +type Config struct { + SchemaDir string +} + +// New creates a new ValidatorProvider instance. +func New(ctx context.Context, config *Config) (*schemaValidator, func() error, error) { + // Check if config is nil + if config == nil { + return nil, nil, fmt.Errorf("config cannot be nil") + } + v := &schemaValidator{ + config: config, + schemaCache: make(map[string]*jsonschema.Schema), + } + + // Call Initialise function to load schemas and get validators + if err := v.initialise(); err != nil { + return nil, nil, fmt.Errorf("failed to initialise schemaValidator: %v", err) + } + return v, nil, nil +} + +// Validate validates the given data against the schema. +func (v *schemaValidator) Validate(ctx context.Context, url *url.URL, data []byte) error { + var payloadData payload + err := json.Unmarshal(data, &payloadData) + if err != nil { + return fmt.Errorf("failed to parse JSON payload: %v", err) + } + + // Extract domain, version, and endpoint from the payload and uri. + cxtDomain := payloadData.Context.Domain + version := payloadData.Context.Version + version = fmt.Sprintf("v%s", version) + + endpoint := path.Base(url.String()) + // ToDo Add debug log here + fmt.Println("Handling request for endpoint:", endpoint) + domain := strings.ToLower(cxtDomain) + domain = strings.ReplaceAll(domain, ":", "_") + + // Construct the schema file name. + schemaFileName := fmt.Sprintf("%s_%s_%s", domain, version, endpoint) + + // Retrieve the schema from the cache. + schema, exists := v.schemaCache[schemaFileName] + if !exists { + return fmt.Errorf("schema not found for domain: %s", schemaFileName) + } + + var jsonData any + if err := json.Unmarshal(data, &jsonData); err != nil { + return fmt.Errorf("failed to parse JSON data: %v", err) + } + err = schema.Validate(jsonData) + if err != nil { + // Handle schema validation errors + if validationErr, ok := err.(*jsonschema.ValidationError); ok { + // Convert validation errors into an array of SchemaValError + var schemaErrors []model.Error + for _, cause := range validationErr.Causes { + // Extract the path and message from the validation error + path := strings.Join(cause.InstanceLocation, ".") // JSON path to the invalid field + message := cause.Error() // Validation error message + + // Append the error to the schemaErrors array + schemaErrors = append(schemaErrors, model.Error{ + Paths: path, + Message: message, + }) + } + // Return the array of schema validation errors + return &model.SchemaValidationErr{Errors: schemaErrors} + } + // Return a generic error for non-validation errors + return fmt.Errorf("validation failed: %v", err) + } + + // Return nil if validation succeeds + return nil +} + +// ValidatorProvider provides instances of Validator. +type ValidatorProvider struct{} + +// Initialise initialises the validator provider by compiling all the JSON schema files +// from the specified directory and storing them in a cache indexed by their schema filenames. +func (v *schemaValidator) initialise() error { + schemaDir := v.config.SchemaDir + // Check if the directory exists and is accessible. + info, err := os.Stat(schemaDir) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("schema directory does not exist: %s", schemaDir) + } + return fmt.Errorf("failed to access schema directory: %v", err) + } + if !info.IsDir() { + return fmt.Errorf("provided schema path is not a directory: %s", schemaDir) + } + + compiler := jsonschema.NewCompiler() + + // Helper function to process directories recursively. + var processDir func(dir string) error + processDir = func(dir string) error { + entries, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("failed to read directory: %v", err) + } + + for _, entry := range entries { + path := filepath.Join(dir, entry.Name()) + if entry.IsDir() { + // Recursively process subdirectories. + if err := processDir(path); err != nil { + return err + } + } else if filepath.Ext(entry.Name()) == ".json" { + // Process JSON files. + compiledSchema, err := compiler.Compile(path) + if err != nil { + return fmt.Errorf("failed to compile JSON schema from file %s: %v", entry.Name(), err) + } + + // Use relative path from schemaDir to avoid absolute paths and make schema keys domain/version specific. + relativePath, err := filepath.Rel(schemaDir, path) + if err != nil { + return fmt.Errorf("failed to get relative path for file %s: %v", entry.Name(), err) + } + // Split the relative path to get domain, version, and schema. + parts := strings.Split(relativePath, string(os.PathSeparator)) + + // Ensure that the file path has at least 3 parts: domain, version, and schema file. + if len(parts) < 3 { + return fmt.Errorf("invalid schema file structure, expected domain/version/schema.json but got: %s", relativePath) + } + + // Extract domain, version, and schema filename from the parts. + // Validate that the extracted parts are non-empty. + domain := strings.TrimSpace(parts[0]) + version := strings.TrimSpace(parts[1]) + schemaFileName := strings.TrimSpace(parts[2]) + schemaFileName = strings.TrimSuffix(schemaFileName, ".json") + + if domain == "" || version == "" || schemaFileName == "" { + return fmt.Errorf("invalid schema file structure, one or more components are empty. Relative path: %s", relativePath) + } + + // Construct a unique key combining domain, version, and schema name (e.g., ondc_trv10_v2.0.0_schema). + uniqueKey := fmt.Sprintf("%s_%s_%s", domain, version, schemaFileName) + // Store the compiled schema in the SchemaCache using the unique key. + v.schemaCache[uniqueKey] = compiledSchema + } + } + return nil + } + + // Start processing from the root schema directory. + if err := processDir(schemaDir); err != nil { + return fmt.Errorf("failed to read schema directory: %v", err) + } + + return nil +} diff --git a/pkg/plugin/implementation/schemavalidator/schemavalidator_test.go b/pkg/plugin/implementation/schemavalidator/schemavalidator_test.go new file mode 100644 index 0000000..bdb4201 --- /dev/null +++ b/pkg/plugin/implementation/schemavalidator/schemavalidator_test.go @@ -0,0 +1,353 @@ +package schemavalidator + +import ( + "context" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/santhosh-tekuri/jsonschema/v6" +) + +// setupTestSchema creates a temporary directory and writes a sample schema file. +func setupTestSchema(t *testing.T) string { + t.Helper() + + // Create a temporary directory for the schema + schemaDir, err := os.MkdirTemp("", "schemas") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + + // Create the directory structure for the schema file + schemaFilePath := filepath.Join(schemaDir, "example", "v1.0", "endpoint.json") + if err := os.MkdirAll(filepath.Dir(schemaFilePath), 0755); err != nil { + t.Fatalf("Failed to create schema directory structure: %v", err) + } + + // Define a sample schema + schemaContent := `{ + "type": "object", + "properties": { + "context": { + "type": "object", + "properties": { + "domain": {"type": "string"}, + "version": {"type": "string"}, + "action": {"type": "string"} + }, + "required": ["domain", "version", "action"] + } + }, + "required": ["context"] + }` + + // Write the schema to the file + if err := os.WriteFile(schemaFilePath, []byte(schemaContent), 0644); err != nil { + t.Fatalf("Failed to write schema file: %v", err) + } + + return schemaDir +} + +func TestValidator_Validate_Success(t *testing.T) { + tests := []struct { + name string + url string + payload string + wantErr bool + }{ + { + name: "Valid payload", + url: "http://example.com/endpoint", + payload: `{"context": {"domain": "example", "version": "1.0", "action": "endpoint"}}`, + wantErr: false, + }, + } + + // Setup a temporary schema directory for testing + schemaDir := setupTestSchema(t) + defer os.RemoveAll(schemaDir) + + config := &Config{SchemaDir: schemaDir} + v, _, err := New(context.Background(), config) + if err != nil { + t.Fatalf("Failed to create validator: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.url) + err := v.Validate(context.Background(), u, []byte(tt.payload)) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } else { + t.Logf("Test %s passed with no errors", tt.name) + } + }) + } +} + +func TestValidator_Validate_Failure(t *testing.T) { + tests := []struct { + name string + url string + payload string + wantErr string + }{ + { + name: "Invalid JSON payload", + url: "http://example.com/endpoint", + payload: `{"context": {"domain": "example", "version": "1.0"`, + wantErr: "failed to parse JSON payload", + }, + { + name: "Schema validation failure", + url: "http://example.com/endpoint", + payload: `{"context": {"domain": "example", "version": "1.0"}}`, + wantErr: "context: at '/context': missing property 'action'", + }, + { + name: "Schema not found", + url: "http://example.com/unknown_endpoint", + payload: `{"context": {"domain": "example", "version": "1.0"}}`, + wantErr: "schema not found for domain", + }, + } + + // Setup a temporary schema directory for testing + schemaDir := setupTestSchema(t) + defer os.RemoveAll(schemaDir) + + config := &Config{SchemaDir: schemaDir} + v, _, err := New(context.Background(), config) + if err != nil { + t.Fatalf("Failed to create validator: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, _ := url.Parse(tt.url) + err := v.Validate(context.Background(), u, []byte(tt.payload)) + if tt.wantErr != "" { + if err == nil { + t.Errorf("Expected error containing '%s', but got nil", tt.wantErr) + } else if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("Expected error containing '%s', but got '%v'", tt.wantErr, err) + } else { + t.Logf("Test %s passed with expected error: %v", tt.name, err) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } else { + t.Logf("Test %s passed with no errors", tt.name) + } + } + }) + } +} + +func TestValidator_Initialise(t *testing.T) { + tests := []struct { + name string + setupFunc func(schemaDir string) error + wantErr string + }{ + { + name: "Schema directory does not exist", + setupFunc: func(schemaDir string) error { + // Do not create the schema directory + return nil + + }, + wantErr: "schema directory does not exist", + }, + { + name: "Schema path is not a directory", + setupFunc: func(schemaDir string) error { + // Create a file instead of a directory + return os.WriteFile(schemaDir, []byte{}, 0644) + }, + wantErr: "provided schema path is not a directory", + }, + { + name: "Invalid schema file structure", + setupFunc: func(schemaDir string) error { + // Create an invalid schema file structure + invalidSchemaFile := filepath.Join(schemaDir, "invalid_schema.json") + if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + return os.WriteFile(invalidSchemaFile, []byte(`{}`), 0644) + }, + wantErr: "invalid schema file structure", + }, + { + name: "Failed to compile JSON schema", + setupFunc: func(schemaDir string) error { + // Create a schema file with invalid JSON + invalidSchemaFile := filepath.Join(schemaDir, "example", "1.0", "endpoint.json") + if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + return os.WriteFile(invalidSchemaFile, []byte(`{invalid json}`), 0644) + }, + wantErr: "failed to compile JSON schema", + }, + { + name: "Invalid schema file structure with empty components", + setupFunc: func(schemaDir string) error { + // Create a schema file with empty domain, version, or schema name + invalidSchemaFile := filepath.Join(schemaDir, "", "1.0", "endpoint.json") + if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + return os.WriteFile(invalidSchemaFile, []byte(`{ + "type": "object", + "properties": { + "context": { + "type": "object", + "properties": { + "domain": {"type": "string"}, + "version": {"type": "string"} + }, + "required": ["domain", "version"] + } + }, + "required": ["context"] + }`), 0644) + }, + wantErr: "failed to read schema directory: invalid schema file structure, expected domain/version/schema.json but got: 1.0/endpoint.json", + }, + { + name: "Failed to read directory", + setupFunc: func(schemaDir string) error { + // Create a directory and remove read permissions + if err := os.MkdirAll(schemaDir, 0000); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + return nil + }, + wantErr: "failed to read directory", + }, + { + name: "Valid schema directory", + setupFunc: func(schemaDir string) error { + // Create a valid schema file + validSchemaFile := filepath.Join(schemaDir, "example", "1.0", "endpoint.json") + if err := os.MkdirAll(filepath.Dir(validSchemaFile), 0755); err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + return os.WriteFile(validSchemaFile, []byte(`{ + "type": "object", + "properties": { + "context": { + "type": "object", + "properties": { + "domain": {"type": "string"}, + "version": {"type": "string"} + }, + "required": ["domain", "version"] + } + }, + "required": ["context"] + }`), 0644) + }, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup a temporary schema directory for testing + schemaDir := filepath.Join(os.TempDir(), "schemas") + defer os.RemoveAll(schemaDir) + + // Run the setup function to prepare the test case + if err := tt.setupFunc(schemaDir); err != nil { + t.Fatalf("setupFunc() error = %v", err) + } + + config := &Config{SchemaDir: schemaDir} + v := &schemaValidator{ + config: config, + schemaCache: make(map[string]*jsonschema.Schema), + } + + err := v.initialise() + if (err != nil && !strings.Contains(err.Error(), tt.wantErr)) || (err == nil && tt.wantErr != "") { + t.Errorf("Error: initialise() returned error = %v, expected error = %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Test %s passed: validator initialized successfully", tt.name) + } else { + t.Logf("Test %s passed with expected error: %v", tt.name, err) + } + }) + } +} + +func TestValidatorNew_Success(t *testing.T) { + schemaDir := setupTestSchema(t) + defer os.RemoveAll(schemaDir) + + config := &Config{SchemaDir: schemaDir} + _, _, err := New(context.Background(), config) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestValidatorNewFailure(t *testing.T) { + tests := []struct { + name string + config *Config + setupFunc func(schemaDir string) error + wantErr string + }{ + { + name: "Config is nil", + config: nil, + setupFunc: func(schemaDir string) error { + return nil + }, + wantErr: "config cannot be nil", + }, + { + name: "Failed to initialise validators", + config: &Config{ + SchemaDir: "/invalid/path", + }, + setupFunc: func(schemaDir string) error { + // Do not create the schema directory + return nil + }, + wantErr: "ailed to initialise schemaValidator: schema directory does not exist: /invalid/path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Run the setup function if provided + if tt.setupFunc != nil { + schemaDir := "" + if tt.config != nil { + schemaDir = tt.config.SchemaDir + } + if err := tt.setupFunc(schemaDir); err != nil { + t.Fatalf("Setup function failed: %v", err) + } + } + + // Call the New function with the test config + _, _, err := New(context.Background(), tt.config) + if (err != nil && !strings.Contains(err.Error(), tt.wantErr)) || (err == nil && tt.wantErr != "") { + t.Errorf("Error: New() returned error = %v, expected error = %v", err, tt.wantErr) + } else { + t.Logf("Test %s passed with expected error: %v", tt.name, err) + } + }) + } +} diff --git a/pkg/plugin/implementation/signer/cmd/plugin.go b/pkg/plugin/implementation/signer/cmd/plugin.go index 2d78d98..1df515f 100644 --- a/pkg/plugin/implementation/signer/cmd/plugin.go +++ b/pkg/plugin/implementation/signer/cmd/plugin.go @@ -21,4 +21,4 @@ func (p SignerProvider) New(ctx context.Context, config map[string]string) (defi } // Provider is the exported symbol that the plugin manager will look for. -var Provider definition.SignerProvider = SignerProvider{} +var Provider = SignerProvider{} diff --git a/pkg/plugin/implementation/signer/signer.go b/pkg/plugin/implementation/signer/signer.go index c1f2af9..66015e4 100644 --- a/pkg/plugin/implementation/signer/signer.go +++ b/pkg/plugin/implementation/signer/signer.go @@ -23,7 +23,7 @@ type Signer struct { func New(ctx context.Context, config *Config) (*Signer, func() error, error) { s := &Signer{config: config} - return s, s.Close, nil + return s, nil, nil } // hash generates a signing string using BLAKE-512 hashing. @@ -70,8 +70,3 @@ func (s *Signer) Sign(ctx context.Context, body []byte, privateKeyBase64 string, return base64.StdEncoding.EncodeToString(signature), nil } - -// Close releases resources (mock implementation returning nil). -func (s *Signer) Close() error { - return nil -} diff --git a/pkg/plugin/implementation/signvalidator/cmd/plugin.go b/pkg/plugin/implementation/signvalidator/cmd/plugin.go new file mode 100644 index 0000000..947f956 --- /dev/null +++ b/pkg/plugin/implementation/signvalidator/cmd/plugin.go @@ -0,0 +1,24 @@ +package main + +import ( + "context" + "errors" + + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/signvalidator" +) + +// provider provides instances of Verifier. +type provider struct{} + +// New initializes a new Verifier instance. +func (vp provider) New(ctx context.Context, config map[string]string) (definition.SignValidator, func() error, error) { + if ctx == nil { + return nil, nil, errors.New("context cannot be nil") + } + + return signvalidator.New(ctx, &signvalidator.Config{}) +} + +// Provider is the exported symbol that the plugin manager will look for. +var Provider = provider{} diff --git a/pkg/plugin/implementation/signvalidator/cmd/plugin_test.go b/pkg/plugin/implementation/signvalidator/cmd/plugin_test.go new file mode 100644 index 0000000..a001ebf --- /dev/null +++ b/pkg/plugin/implementation/signvalidator/cmd/plugin_test.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "testing" +) + +// TestVerifierProviderSuccess tests successful creation of a verifier. +func TestVerifierProviderSuccess(t *testing.T) { + provider := provider{} + + tests := []struct { + name string + ctx context.Context + config map[string]string + }{ + { + name: "Successful creation", + ctx: context.Background(), + config: map[string]string{}, + }, + { + name: "Nil context", + ctx: context.TODO(), + config: map[string]string{}, + }, + { + name: "Empty config", + ctx: context.Background(), + config: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifier, close, err := provider.New(tt.ctx, tt.config) + + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + if verifier == nil { + t.Fatal("Expected verifier instance to be non-nil") + } + if close != nil { + if err := close(); err != nil { + t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err) + } + } + }) + } +} + +// TestVerifierProviderFailure tests cases where verifier creation should fail. +func TestVerifierProviderFailure(t *testing.T) { + provider := provider{} + + tests := []struct { + name string + ctx context.Context + config map[string]string + wantErr bool + }{ + { + name: "Nil context failure", + ctx: nil, + config: map[string]string{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifierInstance, close, err := provider.New(tt.ctx, tt.config) + + if (err != nil) != tt.wantErr { + t.Fatalf("Expected error: %v, but got: %v", tt.wantErr, err) + } + if verifierInstance != nil { + t.Fatal("Expected verifier instance to be nil") + } + if close != nil { + if err := close(); err != nil { + t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err) + } + } + + }) + } +} diff --git a/pkg/plugin/implementation/signvalidator/signvalidator.go b/pkg/plugin/implementation/signvalidator/signvalidator.go new file mode 100644 index 0000000..c381d40 --- /dev/null +++ b/pkg/plugin/implementation/signvalidator/signvalidator.go @@ -0,0 +1,115 @@ +package signvalidator + +import ( + "context" + "crypto/ed25519" + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" + + "golang.org/x/crypto/blake2b" +) + +// Config struct for Verifier. +type Config struct { +} + +// validator implements the validator interface. +type validator struct { + config *Config +} + +// New creates a new Verifier instance. +func New(ctx context.Context, config *Config) (*validator, func() error, error) { + v := &validator{config: config} + + return v, nil, nil +} + +// Verify checks the signature for the given payload and public key. +func (v *validator) Validate(ctx context.Context, body []byte, header string, publicKeyBase64 string) error { + createdTimestamp, expiredTimestamp, signature, err := parseAuthHeader(header) + if err != nil { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return fmt.Errorf("error parsing header: %w", err) + } + + signatureBytes, err := base64.StdEncoding.DecodeString(signature) + if err != nil { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return fmt.Errorf("error decoding signature: %w", err) + } + + currentTime := time.Now().Unix() + if createdTimestamp > currentTime || currentTime > expiredTimestamp { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return fmt.Errorf("signature is expired or not yet valid") + } + + createdTime := time.Unix(createdTimestamp, 0) + expiredTime := time.Unix(expiredTimestamp, 0) + + signingString := hash(body, createdTime.Unix(), expiredTime.Unix()) + + decodedPublicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) + if err != nil { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return fmt.Errorf("error decoding public key: %w", err) + } + + if !ed25519.Verify(ed25519.PublicKey(decodedPublicKey), []byte(signingString), signatureBytes) { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return fmt.Errorf("signature verification failed") + } + + return nil +} + +// parseAuthHeader extracts signature values from the Authorization header. +func parseAuthHeader(header string) (int64, int64, string, error) { + header = strings.TrimPrefix(header, "Signature ") + + parts := strings.Split(header, ",") + signatureMap := make(map[string]string) + + for _, part := range parts { + keyValue := strings.SplitN(strings.TrimSpace(part), "=", 2) + if len(keyValue) == 2 { + key := strings.TrimSpace(keyValue[0]) + value := strings.Trim(keyValue[1], "\"") + signatureMap[key] = value + } + } + + createdTimestamp, err := strconv.ParseInt(signatureMap["created"], 10, 64) + if err != nil { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return 0, 0, "", fmt.Errorf("invalid created timestamp: %w", err) + } + + expiredTimestamp, err := strconv.ParseInt(signatureMap["expires"], 10, 64) + if err != nil { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return 0, 0, "", fmt.Errorf("invalid expires timestamp: %w", err) + } + + signature := signatureMap["signature"] + if signature == "" { + // TODO: Return appropriate error code when Error Code Handling Module is ready + return 0, 0, "", fmt.Errorf("signature missing in header") + } + + return createdTimestamp, expiredTimestamp, signature, nil +} + +// hash constructs a signing string for verification. +func hash(payload []byte, createdTimestamp, expiredTimestamp int64) string { + hasher, _ := blake2b.New512(nil) + hasher.Write(payload) + hashSum := hasher.Sum(nil) + digestB64 := base64.StdEncoding.EncodeToString(hashSum) + + return fmt.Sprintf("(created): %d\n(expires): %d\ndigest: BLAKE-512=%s", createdTimestamp, expiredTimestamp, digestB64) +} diff --git a/pkg/plugin/implementation/signvalidator/signvalidator_test.go b/pkg/plugin/implementation/signvalidator/signvalidator_test.go new file mode 100644 index 0000000..160d28b --- /dev/null +++ b/pkg/plugin/implementation/signvalidator/signvalidator_test.go @@ -0,0 +1,147 @@ +package signvalidator + +import ( + "context" + "crypto/ed25519" + "encoding/base64" + "strconv" + "testing" + "time" +) + +// generateTestKeyPair generates a new ED25519 key pair for testing. +func generateTestKeyPair() (string, string) { + publicKey, privateKey, _ := ed25519.GenerateKey(nil) + return base64.StdEncoding.EncodeToString(privateKey), base64.StdEncoding.EncodeToString(publicKey) +} + +// signTestData creates a valid signature for test cases. +func signTestData(privateKeyBase64 string, body []byte, createdAt, expiresAt int64) string { + privateKeyBytes, _ := base64.StdEncoding.DecodeString(privateKeyBase64) + privateKey := ed25519.PrivateKey(privateKeyBytes) + + signingString := hash(body, createdAt, expiresAt) + signature := ed25519.Sign(privateKey, []byte(signingString)) + + return base64.StdEncoding.EncodeToString(signature) +} + +// TestVerifySuccessCases tests all valid signature verification cases. +func TestVerifySuccess(t *testing.T) { + privateKeyBase64, publicKeyBase64 := generateTestKeyPair() + + tests := []struct { + name string + body []byte + createdAt int64 + expiresAt int64 + }{ + { + name: "Valid Signature", + body: []byte("Test Payload"), + createdAt: time.Now().Unix(), + expiresAt: time.Now().Unix() + 3600, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signature := signTestData(privateKeyBase64, tt.body, tt.createdAt, tt.expiresAt) + header := `Signature created="` + strconv.FormatInt(tt.createdAt, 10) + + `", expires="` + strconv.FormatInt(tt.expiresAt, 10) + + `", signature="` + signature + `"` + + verifier, close, _ := New(context.Background(), &Config{}) + err := verifier.Validate(context.Background(), tt.body, header, publicKeyBase64) + + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + if close != nil { + if err := close(); err != nil { + t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err) + } + } + }) + } +} + +// TestVerifyFailureCases tests all invalid signature verification cases. +func TestVerifyFailure(t *testing.T) { + privateKeyBase64, publicKeyBase64 := generateTestKeyPair() + _, wrongPublicKeyBase64 := generateTestKeyPair() + + tests := []struct { + name string + body []byte + header string + pubKey string + }{ + { + name: "Missing Authorization Header", + body: []byte("Test Payload"), + header: "", + pubKey: publicKeyBase64, + }, + { + name: "Malformed Header", + body: []byte("Test Payload"), + header: `InvalidSignature created="wrong"`, + pubKey: publicKeyBase64, + }, + { + name: "Invalid Base64 Signature", + body: []byte("Test Payload"), + header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) + + `", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) + + `", signature="!!INVALIDBASE64!!"`, + pubKey: publicKeyBase64, + }, + { + name: "Expired Signature", + body: []byte("Test Payload"), + header: `Signature created="` + strconv.FormatInt(time.Now().Unix()-7200, 10) + + `", expires="` + strconv.FormatInt(time.Now().Unix()-3600, 10) + + `", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix()-7200, time.Now().Unix()-3600) + `"`, + pubKey: publicKeyBase64, + }, + { + name: "Invalid Public Key", + body: []byte("Test Payload"), + header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) + + `", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) + + `", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix(), time.Now().Unix()+3600) + `"`, + pubKey: wrongPublicKeyBase64, + }, + { + name: "Invalid Expires Timestamp", + body: []byte("Test Payload"), + header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) + + `", expires="invalid_timestamp"`, + pubKey: publicKeyBase64, + }, + { + name: "Signature Missing in Headers", + body: []byte("Test Payload"), + header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) + + `", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) + `"`, + pubKey: publicKeyBase64, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifier, close, _ := New(context.Background(), &Config{}) + err := verifier.Validate(context.Background(), tt.body, tt.header, tt.pubKey) + + if err == nil { + t.Fatal("Expected an error but got none") + } + if close != nil { + if err := close(); err != nil { + t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err) + } + } + }) + } +} diff --git a/pkg/plugin/manager.go b/pkg/plugin/manager.go index 2bf8c20..23b517e 100644 --- a/pkg/plugin/manager.go +++ b/pkg/plugin/manager.go @@ -112,10 +112,11 @@ func (m *Manager) Publisher(ctx context.Context, cfg *Config) (definition.Publis if err != nil { return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err) } - p, err := pp.New(ctx, cfg.Config) + p, closer, err := pp.New(ctx, cfg.Config) if err != nil { return nil, err } + m.addCloser(closer) return p, nil } @@ -256,8 +257,8 @@ func (m *Manager) Decryptor(ctx context.Context, cfg *Config) (definition.Decryp return decrypter, nil } -func (m *Manager) SignValidator(ctx context.Context, cfg *Config) (definition.Verifier, error) { - svp, err := provider[definition.VerifierProvider](m.plugins, cfg.ID) +func (m *Manager) SignValidator(ctx context.Context, cfg *Config) (definition.SignValidator, error) { + svp, err := provider[definition.SignValidatorProvider](m.plugins, cfg.ID) if err != nil { return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err) } diff --git a/pkg/response/response.go b/pkg/response/response.go index c72b475..a5ab0c4 100644 --- a/pkg/response/response.go +++ b/pkg/response/response.go @@ -7,48 +7,9 @@ import ( "fmt" "net/http" - "strings" - "github.com/beckn/beckn-onix/pkg/model" ) -// Error represents a standardized error response used across the system. -type Error struct { - // Code is a short, machine-readable error code. - Code string `json:"code,omitempty"` - - // Message provides a human-readable description of the error. - Message string `json:"message,omitempty"` - - // Paths indicates the specific field(s) or endpoint(s) related to the error. - Paths string `json:"paths,omitempty"` -} - -// SchemaValidationErr represents a collection of schema validation failures. -type SchemaValidationErr struct { - Errors []Error -} - -// Error implements the error interface for SchemaValidationErr. -func (e *SchemaValidationErr) Error() string { - var errorMessages []string - for _, err := range e.Errors { - errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", err.Paths, err.Message)) - } - return strings.Join(errorMessages, "; ") -} - -// Message represents a standard message structure with acknowledgment and error information. -type Message struct { - // Ack contains the acknowledgment status of the response. - Ack struct { - Status string `json:"status,omitempty"` - } `json:"ack,omitempty"` - - // Error holds error details if any occurred during processing. - Error *Error `json:"error,omitempty"` -} - // SendAck sends an acknowledgment response (ACK) to the client. func SendAck(w http.ResponseWriter) { resp := &model.Response{ diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go index 8b7f748..96f1caa 100644 --- a/pkg/response/response_test.go +++ b/pkg/response/response_test.go @@ -126,21 +126,6 @@ func TestSendNack(t *testing.T) { } } -func TestSchemaValidationErr_Error(t *testing.T) { - // Create sample validation errors - validationErrors := []Error{ - {Paths: "name", Message: "Name is required"}, - {Paths: "email", Message: "Invalid email format"}, - } - err := SchemaValidationErr{Errors: validationErrors} - expected := "name: Name is required; email: Invalid email format" - if err.Error() != expected { - t.Errorf("err.Error() = %s, want %s", - err.Error(), expected) - - } -} - func compareJSON(expected, actual map[string]interface{}) bool { expectedBytes, _ := json.Marshal(expected) actualBytes, _ := json.Marshal(actual)