updated as per the review comments

This commit is contained in:
MohitKatare-protean
2025-03-30 19:13:02 +05:30
parent 244a7be7c1
commit f0e39e34e7
35 changed files with 1605 additions and 124 deletions

25
Dockerfile.adapter Normal file
View File

@@ -0,0 +1,25 @@
FROM golang:1.24-bullseye AS builder
WORKDIR /app
COPY cmd/adapter ./cmd/adapter
COPY core/ ./core
COPY pkg/ ./pkg
COPY go.mod .
COPY go.sum .
RUN go mod download
RUN go build -o server cmd/adapter/main.go
# Create a minimal runtime image
FROM cgr.dev/chainguard/wolfi-base
# ✅ Alpine is removed; using minimal Debian
WORKDIR /app
# Copy only the built binary and plugin
COPY --from=builder /app/server .
# Expose port 8080
EXPOSE 8080
# Run the Go server with the config flag from environment variable.
CMD ["sh", "-c", "./server --config=${CONFIG_FILE}"]

View File

@@ -28,7 +28,7 @@ func (m *MockPluginManager) Middleware(ctx context.Context, cfg *plugin.Config)
} }
// SignValidator returns a mock implementation of the Verifier interface. // SignValidator returns a mock implementation of the Verifier interface.
func (m *MockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) { func (m *MockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error) {
return nil, nil return nil, nil
} }
@@ -200,7 +200,7 @@ func TestRunFailure(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
configData string configData string
mockMgr func() (*plugin.Manager, func(), error) mockMgr func() (*MockPluginManager, func(), error)
mockLogger func(cfg *Config) error mockLogger func(cfg *Config) error
mockServer func(ctx context.Context, mgr handler.PluginManager, cfg *Config) (http.Handler, error) mockServer func(ctx context.Context, mgr handler.PluginManager, cfg *Config) (http.Handler, error)
expectedErr string expectedErr string
@@ -208,8 +208,8 @@ func TestRunFailure(t *testing.T) {
{ {
name: "Invalid Config File", name: "Invalid Config File",
configData: "invalid_config.yaml", configData: "invalid_config.yaml",
mockMgr: func() (*plugin.Manager, func(), error) { mockMgr: func() (*MockPluginManager, func(), error) {
return &plugin.Manager{}, func() {}, nil return &MockPluginManager{}, func() {}, nil
}, },
mockLogger: func(cfg *Config) error { mockLogger: func(cfg *Config) error {
return nil return nil
@@ -236,9 +236,10 @@ func TestRunFailure(t *testing.T) {
// Mock dependencies // Mock dependencies
originalNewManager := newManagerFunc originalNewManager := newManagerFunc
newManagerFunc = func(ctx context.Context, cfg *plugin.ManagerConfig) (*plugin.Manager, func(), error) { // newManagerFunc = func(ctx context.Context, cfg *plugin.ManagerConfig) (*plugin.Manager, func(), error) {
return tt.mockMgr() // return tt.mockMgr()
} // }
newManagerFunc = nil
defer func() { newManagerFunc = originalNewManager }() defer func() { newManagerFunc = originalNewManager }()
originalNewServer := newServerFunc originalNewServer := newServerFunc

View File

@@ -12,7 +12,7 @@ import (
// PluginManager defines an interface for managing plugins dynamically. // PluginManager defines an interface for managing plugins dynamically.
type PluginManager interface { type PluginManager interface {
Middleware(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) Middleware(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error)
SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error)
Validator(ctx context.Context, cfg *plugin.Config) (definition.SchemaValidator, error) Validator(ctx context.Context, cfg *plugin.Config) (definition.SchemaValidator, error)
Router(ctx context.Context, cfg *plugin.Config) (definition.Router, error) Router(ctx context.Context, cfg *plugin.Config) (definition.Router, error)
Publisher(ctx context.Context, cfg *plugin.Config) (definition.Publisher, error) Publisher(ctx context.Context, cfg *plugin.Config) (definition.Publisher, error)

View File

@@ -21,7 +21,7 @@ import (
type stdHandler struct { type stdHandler struct {
signer definition.Signer signer definition.Signer
steps []definition.Step steps []definition.Step
signValidator definition.Verifier signValidator definition.SignValidator
cache definition.Cache cache definition.Cache
km definition.KeyManager km definition.KeyManager
schemaValidator definition.SchemaValidator schemaValidator definition.SchemaValidator
@@ -108,13 +108,15 @@ func (h *stdHandler) subID(ctx context.Context) string {
return h.SubscriberID return h.SubscriberID
} }
var proxyFunc = proxy
// route handles request forwarding or message publishing based on the routing type. // route handles request forwarding or message publishing based on the routing type.
func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb definition.Publisher) { func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb definition.Publisher) {
log.Debugf(ctx, "Routing to ctx.Route to %#v", ctx.Route) log.Debugf(ctx, "Routing to ctx.Route to %#v", ctx.Route)
switch ctx.Route.TargetType { switch ctx.Route.TargetType {
case "url": case "url":
log.Infof(ctx.Context, "Forwarding request to URL: %s", ctx.Route.URL) log.Infof(ctx.Context, "Forwarding request to URL: %s", ctx.Route.URL)
proxy(r, w, ctx.Route.URL) proxyFunc(r, w, ctx.Route.URL)
return return
case "publisher": case "publisher":
if pb == nil { if pb == nil {
@@ -124,7 +126,7 @@ func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb de
return return
} }
log.Infof(ctx.Context, "Publishing message to: %s", ctx.Route.PublisherID) log.Infof(ctx.Context, "Publishing message to: %s", ctx.Route.PublisherID)
if err := pb.Publish(ctx, ctx.Body); err != nil { if err := pb.Publish(ctx, ctx.Route.PublisherID, ctx.Body); err != nil {
log.Errorf(ctx.Context, err, "Failed to publish message") log.Errorf(ctx.Context, err, "Failed to publish message")
http.Error(w, "Error publishing message", http.StatusInternalServerError) http.Error(w, "Error publishing message", http.StatusInternalServerError)
response.SendNack(ctx, w, err) response.SendNack(ctx, w, err)

View File

@@ -52,12 +52,12 @@ func (s *signStep) Run(ctx *model.StepContext) error {
// validateSignStep represents the signature validation step. // validateSignStep represents the signature validation step.
type validateSignStep struct { type validateSignStep struct {
validator definition.Verifier validator definition.SignValidator
km definition.KeyManager km definition.KeyManager
} }
// newValidateSignStep initializes and returns a new validate sign step. // newValidateSignStep initializes and returns a new validate sign step.
func newValidateSignStep(signValidator definition.Verifier, km definition.KeyManager) (definition.Step, error) { func newValidateSignStep(signValidator definition.SignValidator, km definition.KeyManager) (definition.Step, error) {
if signValidator == nil { if signValidator == nil {
return nil, fmt.Errorf("invalid config: SignValidator plugin not configured") return nil, fmt.Errorf("invalid config: SignValidator plugin not configured")
} }
@@ -102,7 +102,7 @@ func (s *validateSignStep) validate(ctx *model.StepContext, value string) error
if err != nil { if err != nil {
return fmt.Errorf("failed to get validation key: %w", err) return fmt.Errorf("failed to get validation key: %w", err)
} }
if _, err := s.validator.Verify(ctx, ctx.Body, []byte(value), key); err != nil { if err := s.validator.Validate(ctx, ctx.Body, value, key); err != nil {
return fmt.Errorf("sign validation failed: %w", err) return fmt.Errorf("sign validation failed: %w", err)
} }
return nil return nil

View File

@@ -23,7 +23,7 @@ func (m *mockPluginManager) Middleware(ctx context.Context, cfg *plugin.Config)
} }
// SignValidator returns a mock verifier implementation. // SignValidator returns a mock verifier implementation.
func (m *mockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.Verifier, error) { func (m *mockPluginManager) SignValidator(ctx context.Context, cfg *plugin.Config) (definition.SignValidator, error) {
return nil, nil return nil, nil
} }

View File

@@ -86,6 +86,14 @@ type Route struct {
URL *url.URL // For API calls URL *url.URL // For API calls
} }
type Keyset struct {
UniqueKeyID string
SigningPrivate string
SigningPublic string
EncrPrivate string
EncrPublic string
}
// StepContext holds context information for a request processing step. // StepContext holds context information for a request processing step.
type StepContext struct { type StepContext struct {
context.Context context.Context

View File

@@ -6,18 +6,10 @@ import (
"github.com/beckn/beckn-onix/pkg/model" "github.com/beckn/beckn-onix/pkg/model"
) )
type Keyset struct {
UniqueKeyID string
SigningPrivate string
SigningPublic string
EncrPrivate string
EncrPublic string
}
// KeyManager defines the interface for key management operations/methods. // KeyManager defines the interface for key management operations/methods.
type KeyManager interface { type KeyManager interface {
GenerateKeyPairs() (*Keyset, error) GenerateKeyPairs() (*model.Keyset, error)
StorePrivateKeys(ctx context.Context, keyID string, keys *Keyset) error StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error
SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) SigningPrivateKey(ctx context.Context, keyID string) (string, string, error)
EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) EncrPrivateKey(ctx context.Context, keyID string) (string, string, error)
SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error)
@@ -29,7 +21,3 @@ type KeyManager interface {
type KeyManagerProvider interface { type KeyManagerProvider interface {
New(context.Context, Cache, RegistryLookup, map[string]string) (KeyManager, func() error, error) New(context.Context, Cache, RegistryLookup, map[string]string) (KeyManager, func() error, error)
} }
type RegistryLookup interface {
Lookup(ctx context.Context, req *model.Subscription) ([]model.Subscription, error)
}

View File

@@ -5,12 +5,10 @@ import "context"
// Publisher defines the general publisher interface for messaging plugins. // Publisher defines the general publisher interface for messaging plugins.
type Publisher interface { type Publisher interface {
// Publish sends a message (as a byte slice) using the underlying messaging system. // Publish sends a message (as a byte slice) using the underlying messaging system.
Publish(ctx context.Context, msg []byte) error Publish(context.Context, string, []byte) error
Close() error // Important for releasing resources.
} }
type PublisherProvider interface { type PublisherProvider interface {
// New initializes a new publisher instance with the given configuration. // New initializes a new publisher instance with the given configuration.
New(ctx context.Context, config map[string]string) (Publisher, error) New(ctx context.Context, config map[string]string) (Publisher, func(), error)
} }

View File

@@ -0,0 +1,11 @@
package definition
import (
"context"
"github.com/beckn/beckn-onix/pkg/model"
)
type RegistryLookup interface {
Lookup(ctx context.Context, req *model.Subscription) ([]model.Subscription, error)
}

View File

@@ -3,14 +3,9 @@ package definition
import ( import (
"context" "context"
"net/url" "net/url"
)
// Route defines the structure for the Route returned. "github.com/beckn/beckn-onix/pkg/model"
type Route struct { )
TargetType string // "url" or "msgq" or "bap" or "bpp"
PublisherID string // For message queues
URL *url.URL // For API calls
}
// RouterProvider initializes the a new Router instance with the given config. // RouterProvider initializes the a new Router instance with the given config.
type RouterProvider interface { type RouterProvider interface {
@@ -20,5 +15,5 @@ type RouterProvider interface {
// Router defines the interface for routing requests. // Router defines the interface for routing requests.
type Router interface { type Router interface {
// Route determines the routing destination based on the request context. // Route determines the routing destination based on the request context.
Route(ctx context.Context, url *url.URL, body []byte) (*Route, error) Route(ctx context.Context, url *url.URL, body []byte) (*model.Route, error)
} }

View File

@@ -8,7 +8,6 @@ type Signer interface {
// The signature is created with the given timestamps: createdAt (signature creation time) // The signature is created with the given timestamps: createdAt (signature creation time)
// and expiresAt (signature expiration time). // and expiresAt (signature expiration time).
Sign(ctx context.Context, body []byte, privateKeyBase64 string, createdAt, expiresAt int64) (string, error) Sign(ctx context.Context, body []byte, privateKeyBase64 string, createdAt, expiresAt int64) (string, error)
Close() error // Close for releasing resources
} }
// SignerProvider initializes a new signer instance with the given config. // SignerProvider initializes a new signer instance with the given config.
@@ -16,9 +15,3 @@ type SignerProvider interface {
// New creates a new signer instance based on the provided config. // New creates a new signer instance based on the provided config.
New(ctx context.Context, config map[string]string) (Signer, func() error, error) New(ctx context.Context, config map[string]string) (Signer, func() error, error)
} }
// PrivateKeyManager is the interface for key management plugin.
type PrivateKeyManager interface {
// PrivateKey retrieves the private key for the given subscriberID and keyID.
PrivateKey(ctx context.Context, subscriberID string, keyID string) (string, error)
}

View File

@@ -0,0 +1,15 @@
package definition
import "context"
// SignValidator defines the method for verifying signatures.
type SignValidator interface {
// Validate checks the validity of the signature for the given body.
Validate(ctx context.Context, body []byte, header string, publicKeyBase64 string) error
}
// SignValidatorProvider initializes a new Verifier instance with the given config.
type SignValidatorProvider interface {
// New creates a new Verifier instance based on the provided config.
New(ctx context.Context, config map[string]string) (SignValidator, func() error, error)
}

View File

@@ -7,13 +7,13 @@ import (
decrypter "github.com/beckn/beckn-onix/pkg/plugin/implementation/decrypter" decrypter "github.com/beckn/beckn-onix/pkg/plugin/implementation/decrypter"
) )
// DecrypterProvider implements the definition.DecrypterProvider interface. // decrypterProvider implements the definition.decrypterProvider interface.
type DecrypterProvider struct{} type decrypterProvider struct{}
// New creates a new Decrypter instance using the provided configuration. // New creates a new Decrypter instance using the provided configuration.
func (dp DecrypterProvider) New(ctx context.Context, config map[string]string) (definition.Decrypter, func() error, error) { func (dp decrypterProvider) New(ctx context.Context, config map[string]string) (definition.Decrypter, func() error, error) {
return decrypter.New(ctx) return decrypter.New(ctx)
} }
// Provider is the exported symbol that the plugin manager will look for. // Provider is the exported symbol that the plugin manager will look for.
var Provider definition.DecrypterProvider = DecrypterProvider{} var Provider = decrypterProvider{}

View File

@@ -25,7 +25,7 @@ func TestDecrypterProviderSuccess(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
provider := DecrypterProvider{} provider := decrypterProvider{}
decrypter, cleanup, err := provider.New(tt.ctx, tt.config) decrypter, cleanup, err := provider.New(tt.ctx, tt.config)
// Check error. // Check error.

View File

@@ -7,12 +7,12 @@ import (
"github.com/beckn/beckn-onix/pkg/plugin/implementation/encrypter" "github.com/beckn/beckn-onix/pkg/plugin/implementation/encrypter"
) )
// EncrypterProvider implements the definition.EncrypterProvider interface. // encrypterProvider implements the definition.encrypterProvider interface.
type EncrypterProvider struct{} type encrypterProvider struct{}
func (ep EncrypterProvider) New(ctx context.Context, config map[string]string) (definition.Encrypter, func() error, error) { func (ep encrypterProvider) New(ctx context.Context, config map[string]string) (definition.Encrypter, func() error, error) {
return encrypter.New(ctx) return encrypter.New(ctx)
} }
// Provider is the exported symbol that the plugin manager will look for. // Provider is the exported symbol that the plugin manager will look for.
var Provider definition.EncrypterProvider = EncrypterProvider{} var Provider = encrypterProvider{}

View File

@@ -28,7 +28,7 @@ func TestEncrypterProviderSuccess(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create provider and encrypter. // Create provider and encrypter.
provider := EncrypterProvider{} provider := encrypterProvider{}
encrypter, cleanup, err := provider.New(tt.ctx, tt.config) encrypter, cleanup, err := provider.New(tt.ctx, tt.config)
if err != nil { if err != nil {
t.Fatalf("EncrypterProvider.New() error = %v", err) t.Fatalf("EncrypterProvider.New() error = %v", err)

View File

@@ -0,0 +1,24 @@
package main
import (
"context"
"net/http"
"strings"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/reqpreprocessor"
)
type provider struct{}
func (p provider) New(ctx context.Context, c map[string]string) (func(http.Handler) http.Handler, error) {
config := &reqpreprocessor.Config{}
if contextKeysStr, ok := c["contextKeys"]; ok {
config.ContextKeys = strings.Split(contextKeysStr, ",")
}
if role, ok := c["role"]; ok {
config.Role = role
}
return reqpreprocessor.NewPreProcessor(config)
}
var Provider = provider{}

View File

@@ -0,0 +1,85 @@
package main
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO: Will Split this into success and fail (two test cases)
func TestProviderNew(t *testing.T) {
testCases := []struct {
name string
config map[string]string
expectedError bool
expectedStatus int
prepareRequest func(req *http.Request)
}{
{
name: "No Config",
config: map[string]string{},
expectedError: true,
expectedStatus: http.StatusOK,
prepareRequest: func(req *http.Request) {
// Add minimal required headers.
req.Header.Set("context", "test-context")
req.Header.Set("transaction_id", "test-transaction")
},
},
{
name: "With Check Keys",
config: map[string]string{
"contextKeys": "message_id,transaction_id",
},
expectedError: false,
expectedStatus: http.StatusOK,
prepareRequest: func(req *http.Request) {
// Add headers matching the check keys.
req.Header.Set("context", "test-context")
req.Header.Set("transaction_id", "test-transaction")
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
requestBody := `{
"context": {
"transaction_id": "abc"
}
}`
p := provider{}
middleware, err := p.New(context.Background(), tc.config)
if tc.expectedError {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.NotNil(t, middleware)
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("POST", "/", strings.NewReader(requestBody))
req.Header.Set("Content-Type", "application/json")
if tc.prepareRequest != nil {
tc.prepareRequest(req)
}
w := httptest.NewRecorder()
middlewaredHandler := middleware(testHandler)
middlewaredHandler.ServeHTTP(w, req)
assert.Equal(t, tc.expectedStatus, w.Code, "Unexpected response status")
responseBody := w.Body.String()
t.Logf("Response Body: %s", responseBody)
})
}
}

View File

@@ -0,0 +1,108 @@
package reqpreprocessor
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/beckn/beckn-onix/pkg/log"
"github.com/google/uuid"
)
type Config struct {
ContextKeys []string
Role string
}
type becknRequest struct {
Context map[string]any `json:"context"`
}
const contextKey = "context"
const subscriberIDKey = "subscriber_id"
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
if err := validateConfig(cfg); err != nil {
return nil, err
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
var req becknRequest
ctx := r.Context()
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Failed to decode request body", http.StatusBadRequest)
return
}
if req.Context == nil {
http.Error(w, fmt.Sprintf("%s field not found.", contextKey), http.StatusBadRequest)
return
}
var subID any
switch cfg.Role {
case "bap":
subID = req.Context["bap_id"]
case "bpp":
subID = req.Context["bpp_id"]
}
if subID != nil {
log.Debugf(ctx, "adding subscriberId to request:%s, %v", subscriberIDKey, subID)
ctx = context.WithValue(ctx, subscriberIDKey, subID)
}
for _, key := range cfg.ContextKeys {
value := uuid.NewString()
updatedValue := update(req.Context, key, value)
ctx = context.WithValue(ctx, key, updatedValue)
}
reqData := map[string]any{"context": req.Context}
updatedBody, _ := json.Marshal(reqData)
r.Body = io.NopCloser(bytes.NewBuffer(updatedBody))
r.ContentLength = int64(len(updatedBody))
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}, nil
}
func update(wrapper map[string]any, key string, value any) any {
field, exists := wrapper[key]
if !exists || isEmpty(field) {
wrapper[key] = value
return value
}
return field
}
func isEmpty(v any) bool {
switch v := v.(type) {
case string:
return v == ""
case nil:
return true
default:
return false
}
}
func validateConfig(cfg *Config) error {
if cfg == nil {
return errors.New("config cannot be nil")
}
// Check if ContextKeys is empty.
if len(cfg.ContextKeys) == 0 {
return errors.New("ContextKeys cannot be empty")
}
// Validate that ContextKeys does not contain empty strings.
for _, key := range cfg.ContextKeys {
if key == "" {
return errors.New("ContextKeys cannot contain empty strings")
}
}
return nil
}

View File

@@ -0,0 +1,178 @@
package reqpreprocessor
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestNewUUIDSetterSuccessCases(t *testing.T) {
tests := []struct {
name string
config *Config
requestBody map[string]any
expectedKeys []string
role string
}{
{
name: "Valid keys, update missing keys with bap role",
config: &Config{
ContextKeys: []string{"transaction_id", "message_id"},
Role: "bap",
},
requestBody: map[string]any{
"context": map[string]any{
"transaction_id": "",
"message_id": nil,
"bap_id": "bap-123",
},
},
expectedKeys: []string{"transaction_id", "message_id", "bap_id"},
role: "bap",
},
{
name: "Valid keys, do not update existing keys with bpp role",
config: &Config{
ContextKeys: []string{"transaction_id", "message_id"},
Role: "bpp",
},
requestBody: map[string]any{
"context": map[string]any{
"transaction_id": "existing-transaction",
"message_id": "existing-message",
"bpp_id": "bpp-456",
},
},
expectedKeys: []string{"transaction_id", "message_id", "bpp_id"},
role: "bpp",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware, err := NewPreProcessor(tt.config)
if err != nil {
t.Fatalf("Unexpected error while creating middleware: %v", err)
}
bodyBytes, _ := json.Marshal(tt.requestBody)
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
w.WriteHeader(http.StatusOK)
subID, ok := ctx.Value(subscriberIDKey).(string)
if !ok {
http.Error(w, "Subscriber ID not found", http.StatusInternalServerError)
return
}
response := map[string]any{"subscriber_id": subID}
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
middleware(dummyHandler).ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status code 200, but got %d", rec.Code)
return
}
var responseBody map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &responseBody); err != nil {
t.Fatal("Failed to unmarshal response body:", err)
}
expectedSubIDKey := "bap_id"
if tt.role == "bpp" {
expectedSubIDKey = "bpp_id"
}
subID, ok := responseBody["subscriber_id"].(string)
if !ok {
t.Error("subscriber_id not found in response")
return
}
expectedSubID := tt.requestBody["context"].(map[string]any)[expectedSubIDKey]
if subID != expectedSubID {
t.Errorf("Expected subscriber_id %v, but got %v", expectedSubID, subID)
}
})
}
}
func TestNewUUIDSetterErrorCases(t *testing.T) {
tests := []struct {
name string
config *Config
requestBody map[string]any
expectedCode int
}{
{
name: "Missing context key",
config: &Config{
ContextKeys: []string{"transaction_id"},
},
requestBody: map[string]any{
"otherKey": "value",
},
expectedCode: http.StatusBadRequest,
},
{
name: "Invalid context type",
config: &Config{
ContextKeys: []string{"transaction_id"},
},
requestBody: map[string]any{
"context": "not-a-map",
},
expectedCode: http.StatusBadRequest,
},
{
name: "Nil config",
config: nil,
requestBody: map[string]any{},
expectedCode: http.StatusInternalServerError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware, err := NewPreProcessor(tt.config)
if tt.config == nil {
if err == nil {
t.Error("Expected an error for nil config, but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error while creating middleware: %v", err)
}
bodyBytes, _ := json.Marshal(tt.requestBody)
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware(dummyHandler).ServeHTTP(rec, req)
if rec.Code != tt.expectedCode {
t.Errorf("Expected status code %d, but got %d", tt.expectedCode, rec.Code)
}
})
}
}

View File

@@ -4,8 +4,8 @@ import (
"context" "context"
"errors" "errors"
definition "github.com/beckn/beckn-onix/pkg/plugin/definition" "github.com/beckn/beckn-onix/pkg/plugin/definition"
router "github.com/beckn/beckn-onix/pkg/plugin/implementation/router" "github.com/beckn/beckn-onix/pkg/plugin/implementation/router"
) )
// RouterProvider provides instances of Router. // RouterProvider provides instances of Router.

View File

@@ -0,0 +1,33 @@
package main
import (
"context"
"errors"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/schemavalidator"
)
// schemaValidatorProvider provides instances of schemaValidator.
type schemaValidatorProvider struct{}
// New initializes a new Verifier instance.
func (vp schemaValidatorProvider) New(ctx context.Context, config map[string]string) (definition.SchemaValidator, func() error, error) {
if ctx == nil {
return nil, nil, errors.New("context cannot be nil")
}
// Extract schemaDir from the config map
schemaDir, ok := config["schemaDir"]
if !ok || schemaDir == "" {
return nil, nil, errors.New("config must contain 'schemaDir'")
}
// Create a new schemaValidator instance with the provided configuration
return schemavalidator.New(ctx, &schemavalidator.Config{
SchemaDir: schemaDir,
})
}
// Provider is the exported symbol that the plugin manager will look for.
var Provider = schemaValidatorProvider{}

View File

@@ -0,0 +1,150 @@
package main
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// setupTestSchema creates a temporary directory and writes a sample schema file.
func setupTestSchema(t *testing.T) string {
t.Helper()
// Create a temporary directory for the schema
schemaDir, err := os.MkdirTemp("", "schemas")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
// Create the directory structure for the schema file
schemaFilePath := filepath.Join(schemaDir, "example", "1.0", "test_schema.json")
if err := os.MkdirAll(filepath.Dir(schemaFilePath), 0755); err != nil {
t.Fatalf("Failed to create schema directory structure: %v", err)
}
// Define a sample schema
schemaContent := `{
"type": "object",
"properties": {
"context": {
"type": "object",
"properties": {
"domain": {"type": "string"},
"version": {"type": "string"}
},
"required": ["domain", "version"]
}
},
"required": ["context"]
}`
// Write the schema to the file
if err := os.WriteFile(schemaFilePath, []byte(schemaContent), 0644); err != nil {
t.Fatalf("Failed to write schema file: %v", err)
}
return schemaDir
}
// TestValidatorProviderSuccess tests successful ValidatorProvider implementation.
func TestValidatorProviderSuccess(t *testing.T) {
schemaDir := setupTestSchema(t)
defer os.RemoveAll(schemaDir)
// Define test cases.
tests := []struct {
name string
ctx context.Context
config map[string]string
expectedError string
}{
{
name: "Valid schema directory",
ctx: context.Background(), // Valid context
config: map[string]string{"schemaDir": schemaDir},
expectedError: "",
},
}
// Test using table-driven tests
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
vp := schemaValidatorProvider{}
schemaValidator, _, err := vp.New(tt.ctx, tt.config)
// Ensure no error occurred
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Ensure the schemaValidator is not nil
if schemaValidator == nil {
t.Error("expected a non-nil schemaValidator, got nil")
}
})
}
}
// TestValidatorProviderSuccess tests cases where ValidatorProvider creation should fail.
func TestValidatorProviderFailure(t *testing.T) {
schemaDir := setupTestSchema(t)
defer os.RemoveAll(schemaDir)
// Define test cases.
tests := []struct {
name string
ctx context.Context
config map[string]string
expectedError string
}{
{
name: "Config is empty",
ctx: context.Background(),
config: map[string]string{},
expectedError: "config must contain 'schemaDir'",
},
{
name: "schemaDir is empty",
ctx: context.Background(),
config: map[string]string{"schemaDir": ""},
expectedError: "config must contain 'schemaDir'",
},
{
name: "Invalid schema directory",
ctx: context.Background(), // Valid context
config: map[string]string{"schemaDir": "/invalid/dir"},
expectedError: "failed to initialise schemaValidator: schema directory does not exist: /invalid/dir",
},
{
name: "Nil context",
ctx: nil, // Nil context
config: map[string]string{"schemaDir": schemaDir},
expectedError: "context cannot be nil",
},
}
// Test using table-driven tests
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
vp := schemaValidatorProvider{}
_, _, err := vp.New(tt.ctx, tt.config)
// Check for expected error
if tt.expectedError != "" {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error %q, got %v", tt.expectedError, err)
}
return
}
// Ensure no error occurred
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
})
}
}

View File

@@ -0,0 +1,197 @@
package schemavalidator
import (
"context"
"encoding/json"
"fmt"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/santhosh-tekuri/jsonschema/v6"
)
// Payload represents the structure of the data payload with context information.
type payload struct {
Context struct {
Domain string `json:"domain"`
Version string `json:"version"`
} `json:"context"`
}
// schemaValidator implements the Validator interface.
type schemaValidator struct {
config *Config
schemaCache map[string]*jsonschema.Schema
}
// Config struct for SchemaValidator.
type Config struct {
SchemaDir string
}
// New creates a new ValidatorProvider instance.
func New(ctx context.Context, config *Config) (*schemaValidator, func() error, error) {
// Check if config is nil
if config == nil {
return nil, nil, fmt.Errorf("config cannot be nil")
}
v := &schemaValidator{
config: config,
schemaCache: make(map[string]*jsonschema.Schema),
}
// Call Initialise function to load schemas and get validators
if err := v.initialise(); err != nil {
return nil, nil, fmt.Errorf("failed to initialise schemaValidator: %v", err)
}
return v, nil, nil
}
// Validate validates the given data against the schema.
func (v *schemaValidator) Validate(ctx context.Context, url *url.URL, data []byte) error {
var payloadData payload
err := json.Unmarshal(data, &payloadData)
if err != nil {
return fmt.Errorf("failed to parse JSON payload: %v", err)
}
// Extract domain, version, and endpoint from the payload and uri.
cxtDomain := payloadData.Context.Domain
version := payloadData.Context.Version
version = fmt.Sprintf("v%s", version)
endpoint := path.Base(url.String())
// ToDo Add debug log here
fmt.Println("Handling request for endpoint:", endpoint)
domain := strings.ToLower(cxtDomain)
domain = strings.ReplaceAll(domain, ":", "_")
// Construct the schema file name.
schemaFileName := fmt.Sprintf("%s_%s_%s", domain, version, endpoint)
// Retrieve the schema from the cache.
schema, exists := v.schemaCache[schemaFileName]
if !exists {
return fmt.Errorf("schema not found for domain: %s", schemaFileName)
}
var jsonData any
if err := json.Unmarshal(data, &jsonData); err != nil {
return fmt.Errorf("failed to parse JSON data: %v", err)
}
err = schema.Validate(jsonData)
if err != nil {
// Handle schema validation errors
if validationErr, ok := err.(*jsonschema.ValidationError); ok {
// Convert validation errors into an array of SchemaValError
var schemaErrors []model.Error
for _, cause := range validationErr.Causes {
// Extract the path and message from the validation error
path := strings.Join(cause.InstanceLocation, ".") // JSON path to the invalid field
message := cause.Error() // Validation error message
// Append the error to the schemaErrors array
schemaErrors = append(schemaErrors, model.Error{
Paths: path,
Message: message,
})
}
// Return the array of schema validation errors
return &model.SchemaValidationErr{Errors: schemaErrors}
}
// Return a generic error for non-validation errors
return fmt.Errorf("validation failed: %v", err)
}
// Return nil if validation succeeds
return nil
}
// ValidatorProvider provides instances of Validator.
type ValidatorProvider struct{}
// Initialise initialises the validator provider by compiling all the JSON schema files
// from the specified directory and storing them in a cache indexed by their schema filenames.
func (v *schemaValidator) initialise() error {
schemaDir := v.config.SchemaDir
// Check if the directory exists and is accessible.
info, err := os.Stat(schemaDir)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf("schema directory does not exist: %s", schemaDir)
}
return fmt.Errorf("failed to access schema directory: %v", err)
}
if !info.IsDir() {
return fmt.Errorf("provided schema path is not a directory: %s", schemaDir)
}
compiler := jsonschema.NewCompiler()
// Helper function to process directories recursively.
var processDir func(dir string) error
processDir = func(dir string) error {
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("failed to read directory: %v", err)
}
for _, entry := range entries {
path := filepath.Join(dir, entry.Name())
if entry.IsDir() {
// Recursively process subdirectories.
if err := processDir(path); err != nil {
return err
}
} else if filepath.Ext(entry.Name()) == ".json" {
// Process JSON files.
compiledSchema, err := compiler.Compile(path)
if err != nil {
return fmt.Errorf("failed to compile JSON schema from file %s: %v", entry.Name(), err)
}
// Use relative path from schemaDir to avoid absolute paths and make schema keys domain/version specific.
relativePath, err := filepath.Rel(schemaDir, path)
if err != nil {
return fmt.Errorf("failed to get relative path for file %s: %v", entry.Name(), err)
}
// Split the relative path to get domain, version, and schema.
parts := strings.Split(relativePath, string(os.PathSeparator))
// Ensure that the file path has at least 3 parts: domain, version, and schema file.
if len(parts) < 3 {
return fmt.Errorf("invalid schema file structure, expected domain/version/schema.json but got: %s", relativePath)
}
// Extract domain, version, and schema filename from the parts.
// Validate that the extracted parts are non-empty.
domain := strings.TrimSpace(parts[0])
version := strings.TrimSpace(parts[1])
schemaFileName := strings.TrimSpace(parts[2])
schemaFileName = strings.TrimSuffix(schemaFileName, ".json")
if domain == "" || version == "" || schemaFileName == "" {
return fmt.Errorf("invalid schema file structure, one or more components are empty. Relative path: %s", relativePath)
}
// Construct a unique key combining domain, version, and schema name (e.g., ondc_trv10_v2.0.0_schema).
uniqueKey := fmt.Sprintf("%s_%s_%s", domain, version, schemaFileName)
// Store the compiled schema in the SchemaCache using the unique key.
v.schemaCache[uniqueKey] = compiledSchema
}
}
return nil
}
// Start processing from the root schema directory.
if err := processDir(schemaDir); err != nil {
return fmt.Errorf("failed to read schema directory: %v", err)
}
return nil
}

View File

@@ -0,0 +1,353 @@
package schemavalidator
import (
"context"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"github.com/santhosh-tekuri/jsonschema/v6"
)
// setupTestSchema creates a temporary directory and writes a sample schema file.
func setupTestSchema(t *testing.T) string {
t.Helper()
// Create a temporary directory for the schema
schemaDir, err := os.MkdirTemp("", "schemas")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
// Create the directory structure for the schema file
schemaFilePath := filepath.Join(schemaDir, "example", "v1.0", "endpoint.json")
if err := os.MkdirAll(filepath.Dir(schemaFilePath), 0755); err != nil {
t.Fatalf("Failed to create schema directory structure: %v", err)
}
// Define a sample schema
schemaContent := `{
"type": "object",
"properties": {
"context": {
"type": "object",
"properties": {
"domain": {"type": "string"},
"version": {"type": "string"},
"action": {"type": "string"}
},
"required": ["domain", "version", "action"]
}
},
"required": ["context"]
}`
// Write the schema to the file
if err := os.WriteFile(schemaFilePath, []byte(schemaContent), 0644); err != nil {
t.Fatalf("Failed to write schema file: %v", err)
}
return schemaDir
}
func TestValidator_Validate_Success(t *testing.T) {
tests := []struct {
name string
url string
payload string
wantErr bool
}{
{
name: "Valid payload",
url: "http://example.com/endpoint",
payload: `{"context": {"domain": "example", "version": "1.0", "action": "endpoint"}}`,
wantErr: false,
},
}
// Setup a temporary schema directory for testing
schemaDir := setupTestSchema(t)
defer os.RemoveAll(schemaDir)
config := &Config{SchemaDir: schemaDir}
v, _, err := New(context.Background(), config)
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.url)
err := v.Validate(context.Background(), u, []byte(tt.payload))
if err != nil {
t.Errorf("Unexpected error: %v", err)
} else {
t.Logf("Test %s passed with no errors", tt.name)
}
})
}
}
func TestValidator_Validate_Failure(t *testing.T) {
tests := []struct {
name string
url string
payload string
wantErr string
}{
{
name: "Invalid JSON payload",
url: "http://example.com/endpoint",
payload: `{"context": {"domain": "example", "version": "1.0"`,
wantErr: "failed to parse JSON payload",
},
{
name: "Schema validation failure",
url: "http://example.com/endpoint",
payload: `{"context": {"domain": "example", "version": "1.0"}}`,
wantErr: "context: at '/context': missing property 'action'",
},
{
name: "Schema not found",
url: "http://example.com/unknown_endpoint",
payload: `{"context": {"domain": "example", "version": "1.0"}}`,
wantErr: "schema not found for domain",
},
}
// Setup a temporary schema directory for testing
schemaDir := setupTestSchema(t)
defer os.RemoveAll(schemaDir)
config := &Config{SchemaDir: schemaDir}
v, _, err := New(context.Background(), config)
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, _ := url.Parse(tt.url)
err := v.Validate(context.Background(), u, []byte(tt.payload))
if tt.wantErr != "" {
if err == nil {
t.Errorf("Expected error containing '%s', but got nil", tt.wantErr)
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("Expected error containing '%s', but got '%v'", tt.wantErr, err)
} else {
t.Logf("Test %s passed with expected error: %v", tt.name, err)
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
} else {
t.Logf("Test %s passed with no errors", tt.name)
}
}
})
}
}
func TestValidator_Initialise(t *testing.T) {
tests := []struct {
name string
setupFunc func(schemaDir string) error
wantErr string
}{
{
name: "Schema directory does not exist",
setupFunc: func(schemaDir string) error {
// Do not create the schema directory
return nil
},
wantErr: "schema directory does not exist",
},
{
name: "Schema path is not a directory",
setupFunc: func(schemaDir string) error {
// Create a file instead of a directory
return os.WriteFile(schemaDir, []byte{}, 0644)
},
wantErr: "provided schema path is not a directory",
},
{
name: "Invalid schema file structure",
setupFunc: func(schemaDir string) error {
// Create an invalid schema file structure
invalidSchemaFile := filepath.Join(schemaDir, "invalid_schema.json")
if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
return os.WriteFile(invalidSchemaFile, []byte(`{}`), 0644)
},
wantErr: "invalid schema file structure",
},
{
name: "Failed to compile JSON schema",
setupFunc: func(schemaDir string) error {
// Create a schema file with invalid JSON
invalidSchemaFile := filepath.Join(schemaDir, "example", "1.0", "endpoint.json")
if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
return os.WriteFile(invalidSchemaFile, []byte(`{invalid json}`), 0644)
},
wantErr: "failed to compile JSON schema",
},
{
name: "Invalid schema file structure with empty components",
setupFunc: func(schemaDir string) error {
// Create a schema file with empty domain, version, or schema name
invalidSchemaFile := filepath.Join(schemaDir, "", "1.0", "endpoint.json")
if err := os.MkdirAll(filepath.Dir(invalidSchemaFile), 0755); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
return os.WriteFile(invalidSchemaFile, []byte(`{
"type": "object",
"properties": {
"context": {
"type": "object",
"properties": {
"domain": {"type": "string"},
"version": {"type": "string"}
},
"required": ["domain", "version"]
}
},
"required": ["context"]
}`), 0644)
},
wantErr: "failed to read schema directory: invalid schema file structure, expected domain/version/schema.json but got: 1.0/endpoint.json",
},
{
name: "Failed to read directory",
setupFunc: func(schemaDir string) error {
// Create a directory and remove read permissions
if err := os.MkdirAll(schemaDir, 0000); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
return nil
},
wantErr: "failed to read directory",
},
{
name: "Valid schema directory",
setupFunc: func(schemaDir string) error {
// Create a valid schema file
validSchemaFile := filepath.Join(schemaDir, "example", "1.0", "endpoint.json")
if err := os.MkdirAll(filepath.Dir(validSchemaFile), 0755); err != nil {
t.Fatalf("Failed to create directory: %v", err)
}
return os.WriteFile(validSchemaFile, []byte(`{
"type": "object",
"properties": {
"context": {
"type": "object",
"properties": {
"domain": {"type": "string"},
"version": {"type": "string"}
},
"required": ["domain", "version"]
}
},
"required": ["context"]
}`), 0644)
},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup a temporary schema directory for testing
schemaDir := filepath.Join(os.TempDir(), "schemas")
defer os.RemoveAll(schemaDir)
// Run the setup function to prepare the test case
if err := tt.setupFunc(schemaDir); err != nil {
t.Fatalf("setupFunc() error = %v", err)
}
config := &Config{SchemaDir: schemaDir}
v := &schemaValidator{
config: config,
schemaCache: make(map[string]*jsonschema.Schema),
}
err := v.initialise()
if (err != nil && !strings.Contains(err.Error(), tt.wantErr)) || (err == nil && tt.wantErr != "") {
t.Errorf("Error: initialise() returned error = %v, expected error = %v", err, tt.wantErr)
} else if err == nil {
t.Logf("Test %s passed: validator initialized successfully", tt.name)
} else {
t.Logf("Test %s passed with expected error: %v", tt.name, err)
}
})
}
}
func TestValidatorNew_Success(t *testing.T) {
schemaDir := setupTestSchema(t)
defer os.RemoveAll(schemaDir)
config := &Config{SchemaDir: schemaDir}
_, _, err := New(context.Background(), config)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
func TestValidatorNewFailure(t *testing.T) {
tests := []struct {
name string
config *Config
setupFunc func(schemaDir string) error
wantErr string
}{
{
name: "Config is nil",
config: nil,
setupFunc: func(schemaDir string) error {
return nil
},
wantErr: "config cannot be nil",
},
{
name: "Failed to initialise validators",
config: &Config{
SchemaDir: "/invalid/path",
},
setupFunc: func(schemaDir string) error {
// Do not create the schema directory
return nil
},
wantErr: "ailed to initialise schemaValidator: schema directory does not exist: /invalid/path",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Run the setup function if provided
if tt.setupFunc != nil {
schemaDir := ""
if tt.config != nil {
schemaDir = tt.config.SchemaDir
}
if err := tt.setupFunc(schemaDir); err != nil {
t.Fatalf("Setup function failed: %v", err)
}
}
// Call the New function with the test config
_, _, err := New(context.Background(), tt.config)
if (err != nil && !strings.Contains(err.Error(), tt.wantErr)) || (err == nil && tt.wantErr != "") {
t.Errorf("Error: New() returned error = %v, expected error = %v", err, tt.wantErr)
} else {
t.Logf("Test %s passed with expected error: %v", tt.name, err)
}
})
}
}

View File

@@ -21,4 +21,4 @@ func (p SignerProvider) New(ctx context.Context, config map[string]string) (defi
} }
// Provider is the exported symbol that the plugin manager will look for. // Provider is the exported symbol that the plugin manager will look for.
var Provider definition.SignerProvider = SignerProvider{} var Provider = SignerProvider{}

View File

@@ -23,7 +23,7 @@ type Signer struct {
func New(ctx context.Context, config *Config) (*Signer, func() error, error) { func New(ctx context.Context, config *Config) (*Signer, func() error, error) {
s := &Signer{config: config} s := &Signer{config: config}
return s, s.Close, nil return s, nil, nil
} }
// hash generates a signing string using BLAKE-512 hashing. // hash generates a signing string using BLAKE-512 hashing.
@@ -70,8 +70,3 @@ func (s *Signer) Sign(ctx context.Context, body []byte, privateKeyBase64 string,
return base64.StdEncoding.EncodeToString(signature), nil return base64.StdEncoding.EncodeToString(signature), nil
} }
// Close releases resources (mock implementation returning nil).
func (s *Signer) Close() error {
return nil
}

View File

@@ -0,0 +1,24 @@
package main
import (
"context"
"errors"
"github.com/beckn/beckn-onix/pkg/plugin/definition"
"github.com/beckn/beckn-onix/pkg/plugin/implementation/signvalidator"
)
// provider provides instances of Verifier.
type provider struct{}
// New initializes a new Verifier instance.
func (vp provider) New(ctx context.Context, config map[string]string) (definition.SignValidator, func() error, error) {
if ctx == nil {
return nil, nil, errors.New("context cannot be nil")
}
return signvalidator.New(ctx, &signvalidator.Config{})
}
// Provider is the exported symbol that the plugin manager will look for.
var Provider = provider{}

View File

@@ -0,0 +1,89 @@
package main
import (
"context"
"testing"
)
// TestVerifierProviderSuccess tests successful creation of a verifier.
func TestVerifierProviderSuccess(t *testing.T) {
provider := provider{}
tests := []struct {
name string
ctx context.Context
config map[string]string
}{
{
name: "Successful creation",
ctx: context.Background(),
config: map[string]string{},
},
{
name: "Nil context",
ctx: context.TODO(),
config: map[string]string{},
},
{
name: "Empty config",
ctx: context.Background(),
config: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verifier, close, err := provider.New(tt.ctx, tt.config)
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}
if verifier == nil {
t.Fatal("Expected verifier instance to be non-nil")
}
if close != nil {
if err := close(); err != nil {
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
}
}
})
}
}
// TestVerifierProviderFailure tests cases where verifier creation should fail.
func TestVerifierProviderFailure(t *testing.T) {
provider := provider{}
tests := []struct {
name string
ctx context.Context
config map[string]string
wantErr bool
}{
{
name: "Nil context failure",
ctx: nil,
config: map[string]string{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verifierInstance, close, err := provider.New(tt.ctx, tt.config)
if (err != nil) != tt.wantErr {
t.Fatalf("Expected error: %v, but got: %v", tt.wantErr, err)
}
if verifierInstance != nil {
t.Fatal("Expected verifier instance to be nil")
}
if close != nil {
if err := close(); err != nil {
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
}
}
})
}
}

View File

@@ -0,0 +1,115 @@
package signvalidator
import (
"context"
"crypto/ed25519"
"encoding/base64"
"fmt"
"strconv"
"strings"
"time"
"golang.org/x/crypto/blake2b"
)
// Config struct for Verifier.
type Config struct {
}
// validator implements the validator interface.
type validator struct {
config *Config
}
// New creates a new Verifier instance.
func New(ctx context.Context, config *Config) (*validator, func() error, error) {
v := &validator{config: config}
return v, nil, nil
}
// Verify checks the signature for the given payload and public key.
func (v *validator) Validate(ctx context.Context, body []byte, header string, publicKeyBase64 string) error {
createdTimestamp, expiredTimestamp, signature, err := parseAuthHeader(header)
if err != nil {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return fmt.Errorf("error parsing header: %w", err)
}
signatureBytes, err := base64.StdEncoding.DecodeString(signature)
if err != nil {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return fmt.Errorf("error decoding signature: %w", err)
}
currentTime := time.Now().Unix()
if createdTimestamp > currentTime || currentTime > expiredTimestamp {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return fmt.Errorf("signature is expired or not yet valid")
}
createdTime := time.Unix(createdTimestamp, 0)
expiredTime := time.Unix(expiredTimestamp, 0)
signingString := hash(body, createdTime.Unix(), expiredTime.Unix())
decodedPublicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
if err != nil {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return fmt.Errorf("error decoding public key: %w", err)
}
if !ed25519.Verify(ed25519.PublicKey(decodedPublicKey), []byte(signingString), signatureBytes) {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return fmt.Errorf("signature verification failed")
}
return nil
}
// parseAuthHeader extracts signature values from the Authorization header.
func parseAuthHeader(header string) (int64, int64, string, error) {
header = strings.TrimPrefix(header, "Signature ")
parts := strings.Split(header, ",")
signatureMap := make(map[string]string)
for _, part := range parts {
keyValue := strings.SplitN(strings.TrimSpace(part), "=", 2)
if len(keyValue) == 2 {
key := strings.TrimSpace(keyValue[0])
value := strings.Trim(keyValue[1], "\"")
signatureMap[key] = value
}
}
createdTimestamp, err := strconv.ParseInt(signatureMap["created"], 10, 64)
if err != nil {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return 0, 0, "", fmt.Errorf("invalid created timestamp: %w", err)
}
expiredTimestamp, err := strconv.ParseInt(signatureMap["expires"], 10, 64)
if err != nil {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return 0, 0, "", fmt.Errorf("invalid expires timestamp: %w", err)
}
signature := signatureMap["signature"]
if signature == "" {
// TODO: Return appropriate error code when Error Code Handling Module is ready
return 0, 0, "", fmt.Errorf("signature missing in header")
}
return createdTimestamp, expiredTimestamp, signature, nil
}
// hash constructs a signing string for verification.
func hash(payload []byte, createdTimestamp, expiredTimestamp int64) string {
hasher, _ := blake2b.New512(nil)
hasher.Write(payload)
hashSum := hasher.Sum(nil)
digestB64 := base64.StdEncoding.EncodeToString(hashSum)
return fmt.Sprintf("(created): %d\n(expires): %d\ndigest: BLAKE-512=%s", createdTimestamp, expiredTimestamp, digestB64)
}

View File

@@ -0,0 +1,147 @@
package signvalidator
import (
"context"
"crypto/ed25519"
"encoding/base64"
"strconv"
"testing"
"time"
)
// generateTestKeyPair generates a new ED25519 key pair for testing.
func generateTestKeyPair() (string, string) {
publicKey, privateKey, _ := ed25519.GenerateKey(nil)
return base64.StdEncoding.EncodeToString(privateKey), base64.StdEncoding.EncodeToString(publicKey)
}
// signTestData creates a valid signature for test cases.
func signTestData(privateKeyBase64 string, body []byte, createdAt, expiresAt int64) string {
privateKeyBytes, _ := base64.StdEncoding.DecodeString(privateKeyBase64)
privateKey := ed25519.PrivateKey(privateKeyBytes)
signingString := hash(body, createdAt, expiresAt)
signature := ed25519.Sign(privateKey, []byte(signingString))
return base64.StdEncoding.EncodeToString(signature)
}
// TestVerifySuccessCases tests all valid signature verification cases.
func TestVerifySuccess(t *testing.T) {
privateKeyBase64, publicKeyBase64 := generateTestKeyPair()
tests := []struct {
name string
body []byte
createdAt int64
expiresAt int64
}{
{
name: "Valid Signature",
body: []byte("Test Payload"),
createdAt: time.Now().Unix(),
expiresAt: time.Now().Unix() + 3600,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
signature := signTestData(privateKeyBase64, tt.body, tt.createdAt, tt.expiresAt)
header := `Signature created="` + strconv.FormatInt(tt.createdAt, 10) +
`", expires="` + strconv.FormatInt(tt.expiresAt, 10) +
`", signature="` + signature + `"`
verifier, close, _ := New(context.Background(), &Config{})
err := verifier.Validate(context.Background(), tt.body, header, publicKeyBase64)
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}
if close != nil {
if err := close(); err != nil {
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
}
}
})
}
}
// TestVerifyFailureCases tests all invalid signature verification cases.
func TestVerifyFailure(t *testing.T) {
privateKeyBase64, publicKeyBase64 := generateTestKeyPair()
_, wrongPublicKeyBase64 := generateTestKeyPair()
tests := []struct {
name string
body []byte
header string
pubKey string
}{
{
name: "Missing Authorization Header",
body: []byte("Test Payload"),
header: "",
pubKey: publicKeyBase64,
},
{
name: "Malformed Header",
body: []byte("Test Payload"),
header: `InvalidSignature created="wrong"`,
pubKey: publicKeyBase64,
},
{
name: "Invalid Base64 Signature",
body: []byte("Test Payload"),
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) +
`", signature="!!INVALIDBASE64!!"`,
pubKey: publicKeyBase64,
},
{
name: "Expired Signature",
body: []byte("Test Payload"),
header: `Signature created="` + strconv.FormatInt(time.Now().Unix()-7200, 10) +
`", expires="` + strconv.FormatInt(time.Now().Unix()-3600, 10) +
`", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix()-7200, time.Now().Unix()-3600) + `"`,
pubKey: publicKeyBase64,
},
{
name: "Invalid Public Key",
body: []byte("Test Payload"),
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) +
`", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix(), time.Now().Unix()+3600) + `"`,
pubKey: wrongPublicKeyBase64,
},
{
name: "Invalid Expires Timestamp",
body: []byte("Test Payload"),
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
`", expires="invalid_timestamp"`,
pubKey: publicKeyBase64,
},
{
name: "Signature Missing in Headers",
body: []byte("Test Payload"),
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) + `"`,
pubKey: publicKeyBase64,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verifier, close, _ := New(context.Background(), &Config{})
err := verifier.Validate(context.Background(), tt.body, tt.header, tt.pubKey)
if err == nil {
t.Fatal("Expected an error but got none")
}
if close != nil {
if err := close(); err != nil {
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
}
}
})
}
}

View File

@@ -112,10 +112,11 @@ func (m *Manager) Publisher(ctx context.Context, cfg *Config) (definition.Publis
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err) return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err)
} }
p, err := pp.New(ctx, cfg.Config) p, closer, err := pp.New(ctx, cfg.Config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.addCloser(closer)
return p, nil return p, nil
} }
@@ -256,8 +257,8 @@ func (m *Manager) Decryptor(ctx context.Context, cfg *Config) (definition.Decryp
return decrypter, nil return decrypter, nil
} }
func (m *Manager) SignValidator(ctx context.Context, cfg *Config) (definition.Verifier, error) { func (m *Manager) SignValidator(ctx context.Context, cfg *Config) (definition.SignValidator, error) {
svp, err := provider[definition.VerifierProvider](m.plugins, cfg.ID) svp, err := provider[definition.SignValidatorProvider](m.plugins, cfg.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err) return nil, fmt.Errorf("failed to load provider for %s: %w", cfg.ID, err)
} }

View File

@@ -7,48 +7,9 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/beckn/beckn-onix/pkg/model" "github.com/beckn/beckn-onix/pkg/model"
) )
// Error represents a standardized error response used across the system.
type Error struct {
// Code is a short, machine-readable error code.
Code string `json:"code,omitempty"`
// Message provides a human-readable description of the error.
Message string `json:"message,omitempty"`
// Paths indicates the specific field(s) or endpoint(s) related to the error.
Paths string `json:"paths,omitempty"`
}
// SchemaValidationErr represents a collection of schema validation failures.
type SchemaValidationErr struct {
Errors []Error
}
// Error implements the error interface for SchemaValidationErr.
func (e *SchemaValidationErr) Error() string {
var errorMessages []string
for _, err := range e.Errors {
errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", err.Paths, err.Message))
}
return strings.Join(errorMessages, "; ")
}
// Message represents a standard message structure with acknowledgment and error information.
type Message struct {
// Ack contains the acknowledgment status of the response.
Ack struct {
Status string `json:"status,omitempty"`
} `json:"ack,omitempty"`
// Error holds error details if any occurred during processing.
Error *Error `json:"error,omitempty"`
}
// SendAck sends an acknowledgment response (ACK) to the client. // SendAck sends an acknowledgment response (ACK) to the client.
func SendAck(w http.ResponseWriter) { func SendAck(w http.ResponseWriter) {
resp := &model.Response{ resp := &model.Response{

View File

@@ -126,21 +126,6 @@ func TestSendNack(t *testing.T) {
} }
} }
func TestSchemaValidationErr_Error(t *testing.T) {
// Create sample validation errors
validationErrors := []Error{
{Paths: "name", Message: "Name is required"},
{Paths: "email", Message: "Invalid email format"},
}
err := SchemaValidationErr{Errors: validationErrors}
expected := "name: Name is required; email: Invalid email format"
if err.Error() != expected {
t.Errorf("err.Error() = %s, want %s",
err.Error(), expected)
}
}
func compareJSON(expected, actual map[string]interface{}) bool { func compareJSON(expected, actual map[string]interface{}) bool {
expectedBytes, _ := json.Marshal(expected) expectedBytes, _ := json.Marshal(expected)
actualBytes, _ := json.Marshal(actual) actualBytes, _ := json.Marshal(actual)