resolved reqpreprocessing review comments

1. added logs for contextKey parsing and unmarshaling yml file.
2. resolved router test case issue.
2. resolved logger test case issues.
This commit is contained in:
MohitKatare-protean
2025-04-06 17:09:59 +05:30
parent 278e217c64
commit ec62a3242b
7 changed files with 176 additions and 70 deletions

View File

@@ -101,7 +101,7 @@ func (h *stdHandler) stepCtx(r *http.Request, rh http.Header) (*model.StepContex
// subID retrieves the subscriber ID from the request context.
func (h *stdHandler) subID(ctx context.Context) string {
rSubID, ok := ctx.Value("subscriber_id").(string)
rSubID, ok := ctx.Value(model.ContextKeySubscriberID).(string)
if ok {
return rSubID
}

View File

@@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/beckn/beckn-onix/pkg/model"
"github.com/rs/zerolog"
"gopkg.in/natefinch/lumberjack.v2"
)
@@ -54,7 +55,7 @@ var logLevels = map[level]zerolog.Level{
type Config struct {
Level level `yaml:"level"`
Destinations []destination `yaml:"destinations"`
ContextKeys []string `yaml:"contextKeys"`
ContextKeys []model.ContextKey `yaml:"contextKeys"`
}
var (
@@ -277,7 +278,7 @@ func addCtx(ctx context.Context, event *zerolog.Event) {
if !ok {
continue
}
keyStr := key
keyStr := string(key)
event.Any(keyStr, val)
}
}

View File

@@ -13,6 +13,8 @@ import (
"strings"
"testing"
"time"
"github.com/beckn/beckn-onix/pkg/model"
)
type ctxKey any
@@ -63,7 +65,12 @@ func setupLogger(t *testing.T, l level) string {
},
},
},
ContextKeys: []string{"userID", "requestID"},
ContextKeys: []model.ContextKey{
model.ContextKeyTxnID,
model.ContextKeyMsgID,
model.ContextKeySubscriberID,
model.ContextKeyModelID,
},
}
// Initialize logger with the given config
@@ -97,7 +104,7 @@ func parseLogLine(t *testing.T, line string) map[string]interface{} {
func TestDebug(t *testing.T) {
t.Helper()
logPath := setupLogger(t, DebugLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Debug(ctx, "Debug message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -105,7 +112,7 @@ func TestDebug(t *testing.T) {
}
expected := map[string]interface{}{
"level": "debug",
"userID": "12345",
"subscriber_id": "12345",
"message": "Debug message",
}
@@ -129,7 +136,7 @@ func TestDebug(t *testing.T) {
func TestInfo(t *testing.T) {
logPath := setupLogger(t, InfoLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Info(ctx, "Info message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -137,7 +144,7 @@ func TestInfo(t *testing.T) {
}
expected := map[string]interface{}{
"level": "info",
"userID": "12345",
"subscriber_id": "12345",
"message": "Info message",
}
@@ -161,7 +168,7 @@ func TestInfo(t *testing.T) {
func TestWarn(t *testing.T) {
logPath := setupLogger(t, WarnLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Warn(ctx, "Warning message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -169,7 +176,7 @@ func TestWarn(t *testing.T) {
}
expected := map[string]interface{}{
"level": "warn",
"userID": "12345",
"subscriber_id": "12345",
"message": "Warning message",
}
@@ -189,7 +196,7 @@ func TestWarn(t *testing.T) {
func TestError(t *testing.T) {
logPath := setupLogger(t, ErrorLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Error(ctx, fmt.Errorf("test error"), "Error message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -197,7 +204,7 @@ func TestError(t *testing.T) {
}
expected := map[string]interface{}{
"level": "error",
"userID": "12345",
"subscriber_id": "12345",
"message": "Error message",
"error": "test error",
}
@@ -277,7 +284,7 @@ func TestResponse(t *testing.T) {
func TestFatal(t *testing.T) {
logPath := setupLogger(t, FatalLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -285,7 +292,7 @@ func TestFatal(t *testing.T) {
}
expected := map[string]interface{}{
"level": "fatal",
"userID": "12345",
"subscriber_id": "12345",
"message": "Fatal message",
"error": "fatal error",
}
@@ -308,7 +315,7 @@ func TestFatal(t *testing.T) {
func TestPanic(t *testing.T) {
logPath := setupLogger(t, PanicLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Panic(ctx, fmt.Errorf("panic error"), "Panic message")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -316,7 +323,7 @@ func TestPanic(t *testing.T) {
}
expected := map[string]interface{}{
"level": "panic",
"userID": "12345",
"subscriber_id": "12345",
"message": "Panic message",
"error": "panic error",
}
@@ -339,7 +346,7 @@ func TestPanic(t *testing.T) {
func TestDebugf(t *testing.T) {
logPath := setupLogger(t, DebugLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Debugf(ctx, "Debugf message: %s", "test")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -347,7 +354,7 @@ func TestDebugf(t *testing.T) {
}
expected := map[string]interface{}{
"level": "debug",
"userID": "12345",
"subscriber_id": "12345",
"message": "Debugf message: test",
}
@@ -370,7 +377,7 @@ func TestDebugf(t *testing.T) {
func TestInfof(t *testing.T) {
logPath := setupLogger(t, InfoLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Infof(ctx, "Infof message: %s", "test")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -378,7 +385,7 @@ func TestInfof(t *testing.T) {
}
expected := map[string]interface{}{
"level": "info",
"userID": "12345",
"subscriber_id": "12345",
"message": "Infof message: test",
}
@@ -400,7 +407,7 @@ func TestInfof(t *testing.T) {
func TestWarnf(t *testing.T) {
logPath := setupLogger(t, WarnLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
Warnf(ctx, "Warnf message: %s", "test")
lines := readLogFile(t, logPath)
if len(lines) == 0 {
@@ -408,7 +415,7 @@ func TestWarnf(t *testing.T) {
}
expected := map[string]interface{}{
"level": "warn",
"userID": "12345",
"subscriber_id": "12345",
"message": "Warnf message: test",
}
@@ -430,7 +437,7 @@ func TestWarnf(t *testing.T) {
func TestErrorf(t *testing.T) {
logPath := setupLogger(t, ErrorLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("error message")
Errorf(ctx, err, "Errorf message: %s", "test")
lines := readLogFile(t, logPath)
@@ -439,7 +446,7 @@ func TestErrorf(t *testing.T) {
}
expected := map[string]interface{}{
"level": "error",
"userID": "12345",
"subscriber_id": "12345",
"message": "Errorf message: test",
"error": "error message",
}
@@ -462,7 +469,7 @@ func TestErrorf(t *testing.T) {
func TestFatalf(t *testing.T) {
logPath := setupLogger(t, FatalLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("fatal error")
Fatalf(ctx, err, "Fatalf message: %s", "test")
lines := readLogFile(t, logPath)
@@ -471,7 +478,7 @@ func TestFatalf(t *testing.T) {
}
expected := map[string]interface{}{
"level": "fatal",
"userID": "12345",
"subscriber_id": "12345",
"message": "Fatalf message: test",
"error": "fatal error",
}
@@ -494,7 +501,7 @@ func TestFatalf(t *testing.T) {
func TestPanicf(t *testing.T) {
logPath := setupLogger(t, PanicLevel)
ctx := context.WithValue(context.Background(), userID, "12345")
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
err := fmt.Errorf("panic error")
Panicf(ctx, err, "Panicf message: %s", "test")
lines := readLogFile(t, logPath)
@@ -504,7 +511,7 @@ func TestPanicf(t *testing.T) {
}
expected := map[string]interface{}{
"level": "panic",
"userID": "12345",
"subscriber_id": "12345",
"message": "Panicf message: test",
"error": "panic error",
}

View File

@@ -3,8 +3,8 @@ package model
import (
"errors"
"fmt"
"testing"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
@@ -198,3 +198,54 @@ func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) {
t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode)
}
}
func TestParseContextKey_ValidKeys(t *testing.T) {
tests := []struct {
input string
expected ContextKey
}{
{"message_id", ContextKeyMsgID},
{"subscriber_id", ContextKeySubscriberID},
{"model_id", ContextKeyModelID},
}
for _, tt := range tests {
key, err := ParseContextKey(tt.input)
if err != nil {
t.Errorf("unexpected error for input %s: %v", tt.input, err)
}
if key != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, key)
}
}
}
func TestParseContextKey_InvalidKey(t *testing.T) {
_, err := ParseContextKey("invalid_key")
if err == nil {
t.Error("expected error for invalid context key, got nil")
}
}
func TestContextKey_UnmarshalYAML_Valid(t *testing.T) {
yamlData := []byte("message_id")
var key ContextKey
err := yaml.Unmarshal(yamlData, &key)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if key != ContextKeyMsgID {
t.Errorf("expected %s, got %s", ContextKeyMsgID, key)
}
}
func TestContextKey_UnmarshalYAML_Invalid(t *testing.T) {
yamlData := []byte("invalid_key")
var key ContextKey
err := yaml.Unmarshal(yamlData, &key)
if err == nil {
t.Error("expected error for invalid context key, got nil")
}
}

View File

@@ -42,6 +42,9 @@ const (
type ContextKey string
const (
// ContextKeyTxnID is the context key used to store and retrieve the transaction ID in a request context.
ContextKeyTxnID ContextKey = "transaction_id"
// ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context.
ContextKeyMsgID ContextKey = "message_id"
@@ -52,6 +55,38 @@ const (
ContextKeyModelID ContextKey = "model_id"
)
var contextKeys = map[string]ContextKey{
"message_id": ContextKeyMsgID,
"subscriber_id": ContextKeySubscriberID,
"model_id": ContextKeyModelID,
}
// ParseContextKey converts a string into a valid ContextKey.
func ParseContextKey(v string) (ContextKey, error) {
key, ok := contextKeys[v]
if !ok {
return "", fmt.Errorf("invalid context key: %s", key)
}
return key, nil
}
// UnmarshalYAML ensures that only known context keys are accepted during YAML unmarshalling.
func (k *ContextKey) UnmarshalYAML(unmarshal func(interface{}) error) error {
var keyStr string
if err := unmarshal(&keyStr); err != nil {
return err
}
parsedKey, err := ParseContextKey(keyStr)
if err != nil {
return err
}
*k = parsedKey
return nil
}
// Role defines the type of participant in the network.
type Role string

View File

@@ -16,6 +16,7 @@ import (
// Config represents the configuration for the request preprocessor middleware.
type Config struct {
Role string
ContextKeys []string
}
const contextKey = "context"
@@ -57,7 +58,12 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID)
ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID)
}
for _, key := range cfg.ContextKeys {
ctxKey, _ := model.ParseContextKey(key)
if v, ok := reqContext[key]; ok {
ctx = context.WithValue(ctx, ctxKey, v)
}
}
r.Body = io.NopCloser(bytes.NewBuffer(body))
r.ContentLength = int64(len(body))
r = r.WithContext(ctx)
@@ -74,5 +80,11 @@ func validateConfig(cfg *Config) error {
if cfg.Role != "bap" && cfg.Role != "bpp" {
return errors.New("role must be either 'bap' or 'bpp'")
}
for _, key := range cfg.ContextKeys {
if _, err := model.ParseContextKey(key); err != nil {
return err
}
}
return nil
}

View File

@@ -296,7 +296,7 @@ func TestValidateRulesFailure(t *testing.T) {
Endpoints: []string{"search"},
},
},
wantErr: `invalid URI htp:// invalid-url.com in request body for url: invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`,
wantErr: `invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`,
},
{
name: "Missing topic_id for targetType: publisher",
@@ -326,7 +326,7 @@ func TestValidateRulesFailure(t *testing.T) {
Endpoints: []string{"search"},
},
},
wantErr: "invalid URI htp://invalid-url.com in request body for bpp: URL 'htp://invalid-url.com' must use https scheme",
wantErr: `invalid URL - htp:// invalid-url.com defined in routing config for target type bpp: parse "htp:// invalid-url.com": invalid character " " in host name`,
},
{
name: "Invalid URL for BAP targetType",
@@ -341,7 +341,7 @@ func TestValidateRulesFailure(t *testing.T) {
Endpoints: []string{"search"},
},
},
wantErr: "invalid URI http://[invalid].com in request body for bap",
wantErr: `invalid URL - http:// [invalid].com defined in routing config for target type bap: parse "http:// [invalid].com": invalid character " " in host name`,
},
}
@@ -465,7 +465,7 @@ func TestRouteFailure(t *testing.T) {
configFile: "bap_caller.yaml",
url: "https://example.com/v1/ondc/select",
body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp:// invalid-url"}}`, // Invalid scheme (htp instead of http)
wantErr: "invalid BPP URI - htp://invalid-url in request body for select: URL 'htp://invalid-url' must use https scheme",
wantErr: `invalid BPP URI - htp:// invalid-url in request body for select: parse "htp:// invalid-url": invalid character " " in host name`,
},
}