code fix as per the review comments

This commit is contained in:
MohitKatare-protean
2025-04-04 15:45:15 +05:30
parent 49e460f61d
commit fc296b8ef3
5 changed files with 48 additions and 56 deletions

View File

@@ -38,14 +38,18 @@ const (
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate" UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
) )
// ContextKey is a custom type used as a key for storing and retrieving values in a context.
type ContextKey string type ContextKey string
const ( const (
// MsgIDKey is the context key used to store and retrieve the message ID in a request context. // ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context.
MsgIDKey = ContextKey("message_id") ContextKeyMsgID ContextKey = "message_id"
// SubscriberIDKey is the context key used to store and retrieve the subscriber ID in a request context. // ContextKeySubscriberID is the context key used to store and retrieve the subscriber ID in a request context.
SubscriberIDKey = ContextKey("subscriber_id") ContextKeySubscriberID ContextKey = "subscriber_id"
// ContextKeyModelID is the context key for storing and retrieving the model ID from a request context.
ContextKeyModelID ContextKey = "model_id"
) )
// Role defines the type of participant in the network. // Role defines the type of participant in the network.

View File

@@ -13,12 +13,16 @@ import (
"github.com/beckn/beckn-onix/pkg/model" "github.com/beckn/beckn-onix/pkg/model"
) )
// Config represents the configuration for the request preprocessor middleware.
type Config struct { type Config struct {
Role string Role string
} }
const contextKey = "context" // contextKeyContext is the typed context key for request context.
var contextKeyContext = model.ContextKey("context")
// NewPreProcessor returns a middleware that processes the incoming request,
// extracts the context field from the body, and adds relevant values (like subscriber ID).
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
if err := validateConfig(cfg); err != nil { if err := validateConfig(cfg); err != nil {
return nil, err return nil, err
@@ -40,7 +44,7 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
// Extract context from request. // Extract context from request.
reqContext, ok := req["context"].(map[string]interface{}) reqContext, ok := req["context"].(map[string]interface{})
if !ok { if !ok {
http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest) http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKeyContext), http.StatusBadRequest)
return return
} }
var subID any var subID any
@@ -51,8 +55,8 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
subID = reqContext["bpp_id"] subID = reqContext["bpp_id"]
} }
if subID != nil { if subID != nil {
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.SubscriberIDKey, subID) log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID)
ctx = context.WithValue(ctx, model.SubscriberIDKey, subID) ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID)
} }
r.Body = io.NopCloser(bytes.NewBuffer(body)) r.Body = io.NopCloser(bytes.NewBuffer(body))

View File

@@ -73,11 +73,11 @@ func TestNewPreProcessorSuccessCases(t *testing.T) {
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
gotSubID = ctx.Value(model.SubscriberIDKey) gotSubID = ctx.Value(model.ContextKeySubscriberID)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
// Verify subscriber ID // Verify subscriber ID
subID := ctx.Value(model.SubscriberIDKey) subID := ctx.Value(model.ContextKeySubscriberID)
if subID == nil { if subID == nil {
t.Errorf("Expected subscriber ID but got none %s", ctx) t.Errorf("Expected subscriber ID but got none %s", ctx)
return return
@@ -233,53 +233,37 @@ func TestNewPreProcessorErrorCases(t *testing.T) {
} }
} }
// Mock handler to capture processed request context func TestNewPreProcessorAddsSubscriberIDToContext(t *testing.T) {
func captureContextHandler(t *testing.T, expectedSubID string) http.Handler { cfg := &Config{Role: "bap"}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { middleware, err := NewPreProcessor(cfg)
// Retrieve subscriber_id from context if err != nil {
subID, ok := r.Context().Value(model.SubscriberIDKey).(string) t.Fatalf("Expected no error, got %v", err)
if !ok {
t.Error("subscriber_id should be set in context")
} else if subID != expectedSubID {
t.Errorf("expected subscriber_id %s, got %s", expectedSubID, subID)
} }
samplePayload := map[string]interface{}{
"context": map[string]interface{}{
"bap_id": "bap.example.com",
},
}
bodyBytes, _ := json.Marshal(samplePayload)
var receivedSubscriberID interface{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedSubscriberID = r.Context().Value(model.ContextKeySubscriberID)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
}
// Test NewPreProcessor middleware req := httptest.NewRequest("POST", "/", strings.NewReader(string(bodyBytes)))
func TestNewPreProcessor(t *testing.T) {
testConfig := &Config{
Role: "bap",
}
testPayload := `{
"context": {
"bap_id": "test-bap-id"
}
}`
preProcessor, err := NewPreProcessor(testConfig)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Create test request
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(testPayload))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
rr := httptest.NewRecorder()
// Create response recorder middleware(handler).ServeHTTP(rr, req)
recorder := httptest.NewRecorder()
// Wrap handler with middleware if rr.Code != http.StatusOK {
handler := preProcessor(captureContextHandler(t, "test-bap-id")) t.Fatalf("Expected status 200 OK, got %d", rr.Code)
}
// Serve request if receivedSubscriberID != "bap.example.com" {
handler.ServeHTTP(recorder, req) t.Errorf("Expected subscriber ID 'bap.example.com', got %v", receivedSubscriberID)
// Check response status
if recorder.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, recorder.Code)
} }
} }

View File

@@ -47,8 +47,8 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
w.WriteHeader(status) w.WriteHeader(status)
_, er := w.Write(data) _, er := w.Write(data)
if er != nil { if er != nil {
fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey)) fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.ContextKeyMsgID))
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), http.StatusInternalServerError)
return return
} }
} }
@@ -57,7 +57,7 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
func internalServerError(ctx context.Context) *model.Error { func internalServerError(ctx context.Context) *model.Error {
return &model.Error{ return &model.Error{
Code: http.StatusText(http.StatusInternalServerError), Code: http.StatusText(http.StatusInternalServerError),
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)),
} }
} }

View File

@@ -46,7 +46,7 @@ func TestSendAck(t *testing.T) {
} }
func TestSendNack(t *testing.T) { func TestSendNack(t *testing.T) {
ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456") ctx := context.WithValue(context.Background(), model.ContextKeyMsgID, "123456")
tests := []struct { tests := []struct {
name string name string
@@ -197,7 +197,7 @@ func TestNack_1(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345") ctx := context.WithValue(req.Context(), model.ContextKeyMsgID, "12345")
var w http.ResponseWriter var w http.ResponseWriter
if tt.useBadWrite { if tt.useBadWrite {