code fix as per the review comments
This commit is contained in:
@@ -38,14 +38,18 @@ const (
|
|||||||
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
|
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ContextKey is a custom type used as a key for storing and retrieving values in a context.
|
||||||
type ContextKey string
|
type ContextKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// MsgIDKey is the context key used to store and retrieve the message ID in a request context.
|
// ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context.
|
||||||
MsgIDKey = ContextKey("message_id")
|
ContextKeyMsgID ContextKey = "message_id"
|
||||||
|
|
||||||
// SubscriberIDKey is the context key used to store and retrieve the subscriber ID in a request context.
|
// ContextKeySubscriberID is the context key used to store and retrieve the subscriber ID in a request context.
|
||||||
SubscriberIDKey = ContextKey("subscriber_id")
|
ContextKeySubscriberID ContextKey = "subscriber_id"
|
||||||
|
|
||||||
|
// ContextKeyModelID is the context key for storing and retrieving the model ID from a request context.
|
||||||
|
ContextKeyModelID ContextKey = "model_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Role defines the type of participant in the network.
|
// Role defines the type of participant in the network.
|
||||||
|
|||||||
@@ -13,12 +13,16 @@ import (
|
|||||||
"github.com/beckn/beckn-onix/pkg/model"
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config represents the configuration for the request preprocessor middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Role string
|
Role string
|
||||||
}
|
}
|
||||||
|
|
||||||
const contextKey = "context"
|
// contextKeyContext is the typed context key for request context.
|
||||||
|
var contextKeyContext = model.ContextKey("context")
|
||||||
|
|
||||||
|
// NewPreProcessor returns a middleware that processes the incoming request,
|
||||||
|
// extracts the context field from the body, and adds relevant values (like subscriber ID).
|
||||||
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
||||||
if err := validateConfig(cfg); err != nil {
|
if err := validateConfig(cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -40,7 +44,7 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
|||||||
// Extract context from request.
|
// Extract context from request.
|
||||||
reqContext, ok := req["context"].(map[string]interface{})
|
reqContext, ok := req["context"].(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKeyContext), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var subID any
|
var subID any
|
||||||
@@ -51,8 +55,8 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
|||||||
subID = reqContext["bpp_id"]
|
subID = reqContext["bpp_id"]
|
||||||
}
|
}
|
||||||
if subID != nil {
|
if subID != nil {
|
||||||
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.SubscriberIDKey, subID)
|
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID)
|
||||||
ctx = context.WithValue(ctx, model.SubscriberIDKey, subID)
|
ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
|||||||
@@ -73,11 +73,11 @@ func TestNewPreProcessorSuccessCases(t *testing.T) {
|
|||||||
|
|
||||||
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
gotSubID = ctx.Value(model.SubscriberIDKey)
|
gotSubID = ctx.Value(model.ContextKeySubscriberID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
// Verify subscriber ID
|
// Verify subscriber ID
|
||||||
subID := ctx.Value(model.SubscriberIDKey)
|
subID := ctx.Value(model.ContextKeySubscriberID)
|
||||||
if subID == nil {
|
if subID == nil {
|
||||||
t.Errorf("Expected subscriber ID but got none %s", ctx)
|
t.Errorf("Expected subscriber ID but got none %s", ctx)
|
||||||
return
|
return
|
||||||
@@ -233,53 +233,37 @@ func TestNewPreProcessorErrorCases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mock handler to capture processed request context
|
func TestNewPreProcessorAddsSubscriberIDToContext(t *testing.T) {
|
||||||
func captureContextHandler(t *testing.T, expectedSubID string) http.Handler {
|
cfg := &Config{Role: "bap"}
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
middleware, err := NewPreProcessor(cfg)
|
||||||
// Retrieve subscriber_id from context
|
if err != nil {
|
||||||
subID, ok := r.Context().Value(model.SubscriberIDKey).(string)
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
if !ok {
|
}
|
||||||
t.Error("subscriber_id should be set in context")
|
|
||||||
} else if subID != expectedSubID {
|
|
||||||
t.Errorf("expected subscriber_id %s, got %s", expectedSubID, subID)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
samplePayload := map[string]interface{}{
|
||||||
|
"context": map[string]interface{}{
|
||||||
|
"bap_id": "bap.example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bodyBytes, _ := json.Marshal(samplePayload)
|
||||||
|
|
||||||
|
var receivedSubscriberID interface{}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedSubscriberID = r.Context().Value(model.ContextKeySubscriberID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
}
|
|
||||||
|
|
||||||
// Test NewPreProcessor middleware
|
req := httptest.NewRequest("POST", "/", strings.NewReader(string(bodyBytes)))
|
||||||
func TestNewPreProcessor(t *testing.T) {
|
|
||||||
testConfig := &Config{
|
|
||||||
Role: "bap",
|
|
||||||
}
|
|
||||||
|
|
||||||
testPayload := `{
|
|
||||||
"context": {
|
|
||||||
"bap_id": "test-bap-id"
|
|
||||||
}
|
|
||||||
}`
|
|
||||||
|
|
||||||
preProcessor, err := NewPreProcessor(testConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("expected no error, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test request
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(testPayload))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
// Create response recorder
|
middleware(handler).ServeHTTP(rr, req)
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Wrap handler with middleware
|
if rr.Code != http.StatusOK {
|
||||||
handler := preProcessor(captureContextHandler(t, "test-bap-id"))
|
t.Fatalf("Expected status 200 OK, got %d", rr.Code)
|
||||||
|
}
|
||||||
// Serve request
|
if receivedSubscriberID != "bap.example.com" {
|
||||||
handler.ServeHTTP(recorder, req)
|
t.Errorf("Expected subscriber ID 'bap.example.com', got %v", receivedSubscriberID)
|
||||||
|
|
||||||
// Check response status
|
|
||||||
if recorder.Code != http.StatusOK {
|
|
||||||
t.Errorf("expected status %d, got %d", http.StatusOK, recorder.Code)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
|
|||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
_, er := w.Write(data)
|
_, er := w.Write(data)
|
||||||
if er != nil {
|
if er != nil {
|
||||||
fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey))
|
fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.ContextKeyMsgID))
|
||||||
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -57,7 +57,7 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
|
|||||||
func internalServerError(ctx context.Context) *model.Error {
|
func internalServerError(ctx context.Context) *model.Error {
|
||||||
return &model.Error{
|
return &model.Error{
|
||||||
Code: http.StatusText(http.StatusInternalServerError),
|
Code: http.StatusText(http.StatusInternalServerError),
|
||||||
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)),
|
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestSendAck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendNack(t *testing.T) {
|
func TestSendNack(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456")
|
ctx := context.WithValue(context.Background(), model.ContextKeyMsgID, "123456")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -197,7 +197,7 @@ func TestNack_1(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345")
|
ctx := context.WithValue(req.Context(), model.ContextKeyMsgID, "12345")
|
||||||
|
|
||||||
var w http.ResponseWriter
|
var w http.ResponseWriter
|
||||||
if tt.useBadWrite {
|
if tt.useBadWrite {
|
||||||
|
|||||||
Reference in New Issue
Block a user