resolved reqpreprocessing review comments

1. added logs for contextKey parsing and unmarshaling yml file.
2. resolved router test case issue.
2. resolved logger test case issues.
This commit is contained in:
MohitKatare-protean
2025-04-06 17:09:59 +05:30
parent 278e217c64
commit ec62a3242b
7 changed files with 176 additions and 70 deletions

View File

@@ -101,7 +101,7 @@ func (h *stdHandler) stepCtx(r *http.Request, rh http.Header) (*model.StepContex
// subID retrieves the subscriber ID from the request context. // subID retrieves the subscriber ID from the request context.
func (h *stdHandler) subID(ctx context.Context) string { func (h *stdHandler) subID(ctx context.Context) string {
rSubID, ok := ctx.Value("subscriber_id").(string) rSubID, ok := ctx.Value(model.ContextKeySubscriberID).(string)
if ok { if ok {
return rSubID return rSubID
} }

View File

@@ -12,6 +12,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
@@ -54,7 +55,7 @@ var logLevels = map[level]zerolog.Level{
type Config struct { type Config struct {
Level level `yaml:"level"` Level level `yaml:"level"`
Destinations []destination `yaml:"destinations"` Destinations []destination `yaml:"destinations"`
ContextKeys []string `yaml:"contextKeys"` ContextKeys []model.ContextKey `yaml:"contextKeys"`
} }
var ( var (
@@ -277,7 +278,7 @@ func addCtx(ctx context.Context, event *zerolog.Event) {
if !ok { if !ok {
continue continue
} }
keyStr := key keyStr := string(key)
event.Any(keyStr, val) event.Any(keyStr, val)
} }
} }

View File

@@ -13,6 +13,8 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/beckn/beckn-onix/pkg/model"
) )
type ctxKey any type ctxKey any
@@ -63,7 +65,12 @@ func setupLogger(t *testing.T, l level) string {
}, },
}, },
}, },
ContextKeys: []string{"userID", "requestID"}, ContextKeys: []model.ContextKey{
model.ContextKeyTxnID,
model.ContextKeyMsgID,
model.ContextKeySubscriberID,
model.ContextKeyModelID,
},
} }
// Initialize logger with the given config // Initialize logger with the given config
@@ -97,7 +104,7 @@ func parseLogLine(t *testing.T, line string) map[string]interface{} {
func TestDebug(t *testing.T) { func TestDebug(t *testing.T) {
t.Helper() t.Helper()
logPath := setupLogger(t, DebugLevel) logPath := setupLogger(t, DebugLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Debug(ctx, "Debug message") Debug(ctx, "Debug message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -105,7 +112,7 @@ func TestDebug(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "debug", "level": "debug",
"userID": "12345", "subscriber_id": "12345",
"message": "Debug message", "message": "Debug message",
} }
@@ -129,7 +136,7 @@ func TestDebug(t *testing.T) {
func TestInfo(t *testing.T) { func TestInfo(t *testing.T) {
logPath := setupLogger(t, InfoLevel) logPath := setupLogger(t, InfoLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Info(ctx, "Info message") Info(ctx, "Info message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -137,7 +144,7 @@ func TestInfo(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "info", "level": "info",
"userID": "12345", "subscriber_id": "12345",
"message": "Info message", "message": "Info message",
} }
@@ -161,7 +168,7 @@ func TestInfo(t *testing.T) {
func TestWarn(t *testing.T) { func TestWarn(t *testing.T) {
logPath := setupLogger(t, WarnLevel) logPath := setupLogger(t, WarnLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Warn(ctx, "Warning message") Warn(ctx, "Warning message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -169,7 +176,7 @@ func TestWarn(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "warn", "level": "warn",
"userID": "12345", "subscriber_id": "12345",
"message": "Warning message", "message": "Warning message",
} }
@@ -189,7 +196,7 @@ func TestWarn(t *testing.T) {
func TestError(t *testing.T) { func TestError(t *testing.T) {
logPath := setupLogger(t, ErrorLevel) logPath := setupLogger(t, ErrorLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Error(ctx, fmt.Errorf("test error"), "Error message") Error(ctx, fmt.Errorf("test error"), "Error message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -197,7 +204,7 @@ func TestError(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "error", "level": "error",
"userID": "12345", "subscriber_id": "12345",
"message": "Error message", "message": "Error message",
"error": "test error", "error": "test error",
} }
@@ -277,7 +284,7 @@ func TestResponse(t *testing.T) {
func TestFatal(t *testing.T) { func TestFatal(t *testing.T) {
logPath := setupLogger(t, FatalLevel) logPath := setupLogger(t, FatalLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message") Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -285,7 +292,7 @@ func TestFatal(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "fatal", "level": "fatal",
"userID": "12345", "subscriber_id": "12345",
"message": "Fatal message", "message": "Fatal message",
"error": "fatal error", "error": "fatal error",
} }
@@ -308,7 +315,7 @@ func TestFatal(t *testing.T) {
func TestPanic(t *testing.T) { func TestPanic(t *testing.T) {
logPath := setupLogger(t, PanicLevel) logPath := setupLogger(t, PanicLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Panic(ctx, fmt.Errorf("panic error"), "Panic message") Panic(ctx, fmt.Errorf("panic error"), "Panic message")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -316,7 +323,7 @@ func TestPanic(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "panic", "level": "panic",
"userID": "12345", "subscriber_id": "12345",
"message": "Panic message", "message": "Panic message",
"error": "panic error", "error": "panic error",
} }
@@ -339,7 +346,7 @@ func TestPanic(t *testing.T) {
func TestDebugf(t *testing.T) { func TestDebugf(t *testing.T) {
logPath := setupLogger(t, DebugLevel) logPath := setupLogger(t, DebugLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Debugf(ctx, "Debugf message: %s", "test") Debugf(ctx, "Debugf message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -347,7 +354,7 @@ func TestDebugf(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "debug", "level": "debug",
"userID": "12345", "subscriber_id": "12345",
"message": "Debugf message: test", "message": "Debugf message: test",
} }
@@ -370,7 +377,7 @@ func TestDebugf(t *testing.T) {
func TestInfof(t *testing.T) { func TestInfof(t *testing.T) {
logPath := setupLogger(t, InfoLevel) logPath := setupLogger(t, InfoLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Infof(ctx, "Infof message: %s", "test") Infof(ctx, "Infof message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -378,7 +385,7 @@ func TestInfof(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "info", "level": "info",
"userID": "12345", "subscriber_id": "12345",
"message": "Infof message: test", "message": "Infof message: test",
} }
@@ -400,7 +407,7 @@ func TestInfof(t *testing.T) {
func TestWarnf(t *testing.T) { func TestWarnf(t *testing.T) {
logPath := setupLogger(t, WarnLevel) logPath := setupLogger(t, WarnLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Warnf(ctx, "Warnf message: %s", "test") Warnf(ctx, "Warnf message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
if len(lines) == 0 { if len(lines) == 0 {
@@ -408,7 +415,7 @@ func TestWarnf(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "warn", "level": "warn",
"userID": "12345", "subscriber_id": "12345",
"message": "Warnf message: test", "message": "Warnf message: test",
} }
@@ -430,7 +437,7 @@ func TestWarnf(t *testing.T) {
func TestErrorf(t *testing.T) { func TestErrorf(t *testing.T) {
logPath := setupLogger(t, ErrorLevel) logPath := setupLogger(t, ErrorLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("error message") err := fmt.Errorf("error message")
Errorf(ctx, err, "Errorf message: %s", "test") Errorf(ctx, err, "Errorf message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
@@ -439,7 +446,7 @@ func TestErrorf(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "error", "level": "error",
"userID": "12345", "subscriber_id": "12345",
"message": "Errorf message: test", "message": "Errorf message: test",
"error": "error message", "error": "error message",
} }
@@ -462,7 +469,7 @@ func TestErrorf(t *testing.T) {
func TestFatalf(t *testing.T) { func TestFatalf(t *testing.T) {
logPath := setupLogger(t, FatalLevel) logPath := setupLogger(t, FatalLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("fatal error") err := fmt.Errorf("fatal error")
Fatalf(ctx, err, "Fatalf message: %s", "test") Fatalf(ctx, err, "Fatalf message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
@@ -471,7 +478,7 @@ func TestFatalf(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "fatal", "level": "fatal",
"userID": "12345", "subscriber_id": "12345",
"message": "Fatalf message: test", "message": "Fatalf message: test",
"error": "fatal error", "error": "fatal error",
} }
@@ -494,7 +501,7 @@ func TestFatalf(t *testing.T) {
func TestPanicf(t *testing.T) { func TestPanicf(t *testing.T) {
logPath := setupLogger(t, PanicLevel) logPath := setupLogger(t, PanicLevel)
ctx := context.WithValue(context.Background(), userID, "12345") ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("panic error") err := fmt.Errorf("panic error")
Panicf(ctx, err, "Panicf message: %s", "test") Panicf(ctx, err, "Panicf message: %s", "test")
lines := readLogFile(t, logPath) lines := readLogFile(t, logPath)
@@ -504,7 +511,7 @@ func TestPanicf(t *testing.T) {
} }
expected := map[string]interface{}{ expected := map[string]interface{}{
"level": "panic", "level": "panic",
"userID": "12345", "subscriber_id": "12345",
"message": "Panicf message: test", "message": "Panicf message: test",
"error": "panic error", "error": "panic error",
} }

View File

@@ -3,8 +3,8 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"testing"
"net/http" "net/http"
"testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@@ -198,3 +198,54 @@ func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) {
t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode) t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode)
} }
} }
func TestParseContextKey_ValidKeys(t *testing.T) {
tests := []struct {
input string
expected ContextKey
}{
{"message_id", ContextKeyMsgID},
{"subscriber_id", ContextKeySubscriberID},
{"model_id", ContextKeyModelID},
}
for _, tt := range tests {
key, err := ParseContextKey(tt.input)
if err != nil {
t.Errorf("unexpected error for input %s: %v", tt.input, err)
}
if key != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, key)
}
}
}
func TestParseContextKey_InvalidKey(t *testing.T) {
_, err := ParseContextKey("invalid_key")
if err == nil {
t.Error("expected error for invalid context key, got nil")
}
}
func TestContextKey_UnmarshalYAML_Valid(t *testing.T) {
yamlData := []byte("message_id")
var key ContextKey
err := yaml.Unmarshal(yamlData, &key)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if key != ContextKeyMsgID {
t.Errorf("expected %s, got %s", ContextKeyMsgID, key)
}
}
func TestContextKey_UnmarshalYAML_Invalid(t *testing.T) {
yamlData := []byte("invalid_key")
var key ContextKey
err := yaml.Unmarshal(yamlData, &key)
if err == nil {
t.Error("expected error for invalid context key, got nil")
}
}

View File

@@ -42,6 +42,9 @@ const (
type ContextKey string type ContextKey string
const ( const (
// ContextKeyTxnID is the context key used to store and retrieve the transaction ID in a request context.
ContextKeyTxnID ContextKey = "transaction_id"
// ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context. // ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context.
ContextKeyMsgID ContextKey = "message_id" ContextKeyMsgID ContextKey = "message_id"
@@ -52,6 +55,38 @@ const (
ContextKeyModelID ContextKey = "model_id" ContextKeyModelID ContextKey = "model_id"
) )
var contextKeys = map[string]ContextKey{
"message_id": ContextKeyMsgID,
"subscriber_id": ContextKeySubscriberID,
"model_id": ContextKeyModelID,
}
// ParseContextKey converts a string into a valid ContextKey.
func ParseContextKey(v string) (ContextKey, error) {
key, ok := contextKeys[v]
if !ok {
return "", fmt.Errorf("invalid context key: %s", key)
}
return key, nil
}
// UnmarshalYAML ensures that only known context keys are accepted during YAML unmarshalling.
func (k *ContextKey) UnmarshalYAML(unmarshal func(interface{}) error) error {
var keyStr string
if err := unmarshal(&keyStr); err != nil {
return err
}
parsedKey, err := ParseContextKey(keyStr)
if err != nil {
return err
}
*k = parsedKey
return nil
}
// Role defines the type of participant in the network. // Role defines the type of participant in the network.
type Role string type Role string

View File

@@ -16,6 +16,7 @@ import (
// Config represents the configuration for the request preprocessor middleware. // Config represents the configuration for the request preprocessor middleware.
type Config struct { type Config struct {
Role string Role string
ContextKeys []string
} }
const contextKey = "context" const contextKey = "context"
@@ -57,7 +58,12 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID) log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID)
ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID) ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID)
} }
for _, key := range cfg.ContextKeys {
ctxKey, _ := model.ParseContextKey(key)
if v, ok := reqContext[key]; ok {
ctx = context.WithValue(ctx, ctxKey, v)
}
}
r.Body = io.NopCloser(bytes.NewBuffer(body)) r.Body = io.NopCloser(bytes.NewBuffer(body))
r.ContentLength = int64(len(body)) r.ContentLength = int64(len(body))
r = r.WithContext(ctx) r = r.WithContext(ctx)
@@ -74,5 +80,11 @@ func validateConfig(cfg *Config) error {
if cfg.Role != "bap" && cfg.Role != "bpp" { if cfg.Role != "bap" && cfg.Role != "bpp" {
return errors.New("role must be either 'bap' or 'bpp'") return errors.New("role must be either 'bap' or 'bpp'")
} }
for _, key := range cfg.ContextKeys {
if _, err := model.ParseContextKey(key); err != nil {
return err
}
}
return nil return nil
} }

View File

@@ -296,7 +296,7 @@ func TestValidateRulesFailure(t *testing.T) {
Endpoints: []string{"search"}, Endpoints: []string{"search"},
}, },
}, },
wantErr: `invalid URI htp:// invalid-url.com in request body for url: invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`, wantErr: `invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`,
}, },
{ {
name: "Missing topic_id for targetType: publisher", name: "Missing topic_id for targetType: publisher",
@@ -321,12 +321,12 @@ func TestValidateRulesFailure(t *testing.T) {
Version: "1.0.0", Version: "1.0.0",
TargetType: "bpp", TargetType: "bpp",
Target: target{ Target: target{
URL: "htp://invalid-url.com", // Invalid URL URL: "htp:// invalid-url.com", // Invalid URL
}, },
Endpoints: []string{"search"}, Endpoints: []string{"search"},
}, },
}, },
wantErr: "invalid URI htp://invalid-url.com in request body for bpp: URL 'htp://invalid-url.com' must use https scheme", wantErr: `invalid URL - htp:// invalid-url.com defined in routing config for target type bpp: parse "htp:// invalid-url.com": invalid character " " in host name`,
}, },
{ {
name: "Invalid URL for BAP targetType", name: "Invalid URL for BAP targetType",
@@ -336,12 +336,12 @@ func TestValidateRulesFailure(t *testing.T) {
Version: "1.0.0", Version: "1.0.0",
TargetType: "bap", TargetType: "bap",
Target: target{ Target: target{
URL: "http://[invalid].com", // Invalid host URL: "http:// [invalid].com", // Invalid host
}, },
Endpoints: []string{"search"}, Endpoints: []string{"search"},
}, },
}, },
wantErr: "invalid URI http://[invalid].com in request body for bap", wantErr: `invalid URL - http:// [invalid].com defined in routing config for target type bap: parse "http:// [invalid].com": invalid character " " in host name`,
}, },
} }
@@ -464,8 +464,8 @@ func TestRouteFailure(t *testing.T) {
name: "Invalid bpp_uri format in request", name: "Invalid bpp_uri format in request",
configFile: "bap_caller.yaml", configFile: "bap_caller.yaml",
url: "https://example.com/v1/ondc/select", url: "https://example.com/v1/ondc/select",
body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp://invalid-url"}}`, // Invalid scheme (htp instead of http) body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp:// invalid-url"}}`, // Invalid scheme (htp instead of http)
wantErr: "invalid BPP URI - htp://invalid-url in request body for select: URL 'htp://invalid-url' must use https scheme", wantErr: `invalid BPP URI - htp:// invalid-url in request body for select: parse "htp:// invalid-url": invalid character " " in host name`,
}, },
} }