diff --git a/pkg/model/model.go b/pkg/model/model.go index 8555046..13802f4 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -38,14 +38,18 @@ const ( UnaAuthorizedHeaderGateway string = "Proxy-Authenticate" ) +// ContextKey is a custom type used as a key for storing and retrieving values in a context. type ContextKey string const ( - // MsgIDKey is the context key used to store and retrieve the message ID in a request context. - MsgIDKey = ContextKey("message_id") + // ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context. + ContextKeyMsgID ContextKey = "message_id" - // SubscriberIDKey is the context key used to store and retrieve the subscriber ID in a request context. - SubscriberIDKey = ContextKey("subscriber_id") + // ContextKeySubscriberID is the context key used to store and retrieve the subscriber ID in a request context. + ContextKeySubscriberID ContextKey = "subscriber_id" + + // ContextKeyModelID is the context key for storing and retrieving the model ID from a request context. + ContextKeyModelID ContextKey = "model_id" ) // Role defines the type of participant in the network. diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go index 441a3a7..6653051 100644 --- a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go @@ -13,12 +13,16 @@ import ( "github.com/beckn/beckn-onix/pkg/model" ) +// Config represents the configuration for the request preprocessor middleware. type Config struct { Role string } -const contextKey = "context" +// contextKeyContext is the typed context key for request context. +var contextKeyContext = model.ContextKey("context") +// NewPreProcessor returns a middleware that processes the incoming request, +// extracts the context field from the body, and adds relevant values (like subscriber ID). func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { if err := validateConfig(cfg); err != nil { return nil, err @@ -40,7 +44,7 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { // Extract context from request. reqContext, ok := req["context"].(map[string]interface{}) if !ok { - http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest) + http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKeyContext), http.StatusBadRequest) return } var subID any @@ -51,8 +55,8 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { subID = reqContext["bpp_id"] } if subID != nil { - log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.SubscriberIDKey, subID) - ctx = context.WithValue(ctx, model.SubscriberIDKey, subID) + log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID) + ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID) } r.Body = io.NopCloser(bytes.NewBuffer(body)) diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go index 20beb18..ed78d2d 100644 --- a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go @@ -73,11 +73,11 @@ func TestNewPreProcessorSuccessCases(t *testing.T) { dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - gotSubID = ctx.Value(model.SubscriberIDKey) + gotSubID = ctx.Value(model.ContextKeySubscriberID) w.WriteHeader(http.StatusOK) // Verify subscriber ID - subID := ctx.Value(model.SubscriberIDKey) + subID := ctx.Value(model.ContextKeySubscriberID) if subID == nil { t.Errorf("Expected subscriber ID but got none %s", ctx) return @@ -233,53 +233,37 @@ func TestNewPreProcessorErrorCases(t *testing.T) { } } -// Mock handler to capture processed request context -func captureContextHandler(t *testing.T, expectedSubID string) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Retrieve subscriber_id from context - subID, ok := r.Context().Value(model.SubscriberIDKey).(string) - if !ok { - t.Error("subscriber_id should be set in context") - } else if subID != expectedSubID { - t.Errorf("expected subscriber_id %s, got %s", expectedSubID, subID) - } +func TestNewPreProcessorAddsSubscriberIDToContext(t *testing.T) { + cfg := &Config{Role: "bap"} + middleware, err := NewPreProcessor(cfg) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + samplePayload := map[string]interface{}{ + "context": map[string]interface{}{ + "bap_id": "bap.example.com", + }, + } + bodyBytes, _ := json.Marshal(samplePayload) + + var receivedSubscriberID interface{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedSubscriberID = r.Context().Value(model.ContextKeySubscriberID) w.WriteHeader(http.StatusOK) }) -} -// Test NewPreProcessor middleware -func TestNewPreProcessor(t *testing.T) { - testConfig := &Config{ - Role: "bap", - } - - testPayload := `{ - "context": { - "bap_id": "test-bap-id" - } - }` - - preProcessor, err := NewPreProcessor(testConfig) - if err != nil { - t.Fatalf("expected no error, got: %v", err) - } - - // Create test request - req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(testPayload)) + req := httptest.NewRequest("POST", "/", strings.NewReader(string(bodyBytes))) req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() - // Create response recorder - recorder := httptest.NewRecorder() + middleware(handler).ServeHTTP(rr, req) - // Wrap handler with middleware - handler := preProcessor(captureContextHandler(t, "test-bap-id")) - - // Serve request - handler.ServeHTTP(recorder, req) - - // Check response status - if recorder.Code != http.StatusOK { - t.Errorf("expected status %d, got %d", http.StatusOK, recorder.Code) + if rr.Code != http.StatusOK { + t.Fatalf("Expected status 200 OK, got %d", rr.Code) + } + if receivedSubscriberID != "bap.example.com" { + t.Errorf("Expected subscriber ID 'bap.example.com', got %v", receivedSubscriberID) } } diff --git a/pkg/response/response.go b/pkg/response/response.go index a5ab0c4..f9ac9d5 100644 --- a/pkg/response/response.go +++ b/pkg/response/response.go @@ -47,8 +47,8 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i w.WriteHeader(status) _, er := w.Write(data) if er != nil { - fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey)) - http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError) + fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.ContextKeyMsgID)) + http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), http.StatusInternalServerError) return } } @@ -57,7 +57,7 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i func internalServerError(ctx context.Context) *model.Error { return &model.Error{ Code: http.StatusText(http.StatusInternalServerError), - Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), + Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), } } diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go index 96f1caa..f0c1c6d 100644 --- a/pkg/response/response_test.go +++ b/pkg/response/response_test.go @@ -46,7 +46,7 @@ func TestSendAck(t *testing.T) { } func TestSendNack(t *testing.T) { - ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456") + ctx := context.WithValue(context.Background(), model.ContextKeyMsgID, "123456") tests := []struct { name string @@ -197,7 +197,7 @@ func TestNack_1(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345") + ctx := context.WithValue(req.Context(), model.ContextKeyMsgID, "12345") var w http.ResponseWriter if tt.useBadWrite {