Merge pull request #444 from beckn/feature/contex_id
Request pre-processing changes for contextKeys
This commit is contained in:
@@ -101,7 +101,7 @@ func (h *stdHandler) stepCtx(r *http.Request, rh http.Header) (*model.StepContex
|
|||||||
|
|
||||||
// subID retrieves the subscriber ID from the request context.
|
// subID retrieves the subscriber ID from the request context.
|
||||||
func (h *stdHandler) subID(ctx context.Context) string {
|
func (h *stdHandler) subID(ctx context.Context) string {
|
||||||
rSubID, ok := ctx.Value("subscriber_id").(string)
|
rSubID, ok := ctx.Value(model.ContextKeySubscriberID).(string)
|
||||||
if ok {
|
if ok {
|
||||||
return rSubID
|
return rSubID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func addMiddleware(ctx context.Context, mgr handler.PluginManager, handler http.
|
|||||||
|
|
||||||
func moduleCtxMiddleware(moduleName string, next http.Handler) http.Handler {
|
func moduleCtxMiddleware(moduleName string, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := context.WithValue(r.Context(), model.ContextKeyModuleId, moduleName)
|
ctx := context.WithValue(r.Context(), model.ContextKeyModuleID, moduleName)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func TestRegisterSuccess(t *testing.T) {
|
|||||||
// Create a handler that extracts context
|
// Create a handler that extracts context
|
||||||
var capturedModuleName any
|
var capturedModuleName any
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
capturedModuleName = r.Context().Value(model.ContextKeyModuleId)
|
capturedModuleName = r.Context().Value(model.ContextKeyModuleID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
)
|
)
|
||||||
@@ -54,7 +55,7 @@ var logLevels = map[level]zerolog.Level{
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
Level level `yaml:"level"`
|
Level level `yaml:"level"`
|
||||||
Destinations []destination `yaml:"destinations"`
|
Destinations []destination `yaml:"destinations"`
|
||||||
ContextKeys []string `yaml:"contextKeys"`
|
ContextKeys []model.ContextKey `yaml:"contextKeys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -277,7 +278,7 @@ func addCtx(ctx context.Context, event *zerolog.Event) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
keyStr := key
|
keyStr := string(key)
|
||||||
event.Any(keyStr, val)
|
event.Any(keyStr, val)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ctxKey any
|
type ctxKey any
|
||||||
|
|
||||||
var requestID ctxKey = "requestID"
|
var requestID ctxKey = "requestID"
|
||||||
var userID ctxKey = "userID"
|
|
||||||
|
|
||||||
const testLogFilePath = "./test_logs/test.log"
|
const testLogFilePath = "./test_logs/test.log"
|
||||||
|
|
||||||
@@ -63,7 +64,12 @@ func setupLogger(t *testing.T, l level) string {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ContextKeys: []string{"userID", "requestID"},
|
ContextKeys: []model.ContextKey{
|
||||||
|
model.ContextKeyTxnID,
|
||||||
|
model.ContextKeyMsgID,
|
||||||
|
model.ContextKeySubscriberID,
|
||||||
|
model.ContextKeyModuleID,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize logger with the given config
|
// Initialize logger with the given config
|
||||||
@@ -97,7 +103,7 @@ func parseLogLine(t *testing.T, line string) map[string]interface{} {
|
|||||||
func TestDebug(t *testing.T) {
|
func TestDebug(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
logPath := setupLogger(t, DebugLevel)
|
logPath := setupLogger(t, DebugLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Debug(ctx, "Debug message")
|
Debug(ctx, "Debug message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -105,7 +111,7 @@ func TestDebug(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "debug",
|
"level": "debug",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Debug message",
|
"message": "Debug message",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +135,7 @@ func TestDebug(t *testing.T) {
|
|||||||
|
|
||||||
func TestInfo(t *testing.T) {
|
func TestInfo(t *testing.T) {
|
||||||
logPath := setupLogger(t, InfoLevel)
|
logPath := setupLogger(t, InfoLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Info(ctx, "Info message")
|
Info(ctx, "Info message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -137,7 +143,7 @@ func TestInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "info",
|
"level": "info",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Info message",
|
"message": "Info message",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,7 +167,7 @@ func TestInfo(t *testing.T) {
|
|||||||
|
|
||||||
func TestWarn(t *testing.T) {
|
func TestWarn(t *testing.T) {
|
||||||
logPath := setupLogger(t, WarnLevel)
|
logPath := setupLogger(t, WarnLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Warn(ctx, "Warning message")
|
Warn(ctx, "Warning message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -169,7 +175,7 @@ func TestWarn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "warn",
|
"level": "warn",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Warning message",
|
"message": "Warning message",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +195,7 @@ func TestWarn(t *testing.T) {
|
|||||||
|
|
||||||
func TestError(t *testing.T) {
|
func TestError(t *testing.T) {
|
||||||
logPath := setupLogger(t, ErrorLevel)
|
logPath := setupLogger(t, ErrorLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Error(ctx, fmt.Errorf("test error"), "Error message")
|
Error(ctx, fmt.Errorf("test error"), "Error message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -197,7 +203,7 @@ func TestError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "error",
|
"level": "error",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Error message",
|
"message": "Error message",
|
||||||
"error": "test error",
|
"error": "test error",
|
||||||
}
|
}
|
||||||
@@ -277,7 +283,7 @@ func TestResponse(t *testing.T) {
|
|||||||
|
|
||||||
func TestFatal(t *testing.T) {
|
func TestFatal(t *testing.T) {
|
||||||
logPath := setupLogger(t, FatalLevel)
|
logPath := setupLogger(t, FatalLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message")
|
Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -285,7 +291,7 @@ func TestFatal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "fatal",
|
"level": "fatal",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Fatal message",
|
"message": "Fatal message",
|
||||||
"error": "fatal error",
|
"error": "fatal error",
|
||||||
}
|
}
|
||||||
@@ -308,7 +314,7 @@ func TestFatal(t *testing.T) {
|
|||||||
|
|
||||||
func TestPanic(t *testing.T) {
|
func TestPanic(t *testing.T) {
|
||||||
logPath := setupLogger(t, PanicLevel)
|
logPath := setupLogger(t, PanicLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Panic(ctx, fmt.Errorf("panic error"), "Panic message")
|
Panic(ctx, fmt.Errorf("panic error"), "Panic message")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -316,7 +322,7 @@ func TestPanic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "panic",
|
"level": "panic",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Panic message",
|
"message": "Panic message",
|
||||||
"error": "panic error",
|
"error": "panic error",
|
||||||
}
|
}
|
||||||
@@ -339,7 +345,7 @@ func TestPanic(t *testing.T) {
|
|||||||
|
|
||||||
func TestDebugf(t *testing.T) {
|
func TestDebugf(t *testing.T) {
|
||||||
logPath := setupLogger(t, DebugLevel)
|
logPath := setupLogger(t, DebugLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Debugf(ctx, "Debugf message: %s", "test")
|
Debugf(ctx, "Debugf message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -347,7 +353,7 @@ func TestDebugf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "debug",
|
"level": "debug",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Debugf message: test",
|
"message": "Debugf message: test",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +376,7 @@ func TestDebugf(t *testing.T) {
|
|||||||
|
|
||||||
func TestInfof(t *testing.T) {
|
func TestInfof(t *testing.T) {
|
||||||
logPath := setupLogger(t, InfoLevel)
|
logPath := setupLogger(t, InfoLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Infof(ctx, "Infof message: %s", "test")
|
Infof(ctx, "Infof message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -378,7 +384,7 @@ func TestInfof(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "info",
|
"level": "info",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Infof message: test",
|
"message": "Infof message: test",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -400,7 +406,7 @@ func TestInfof(t *testing.T) {
|
|||||||
|
|
||||||
func TestWarnf(t *testing.T) {
|
func TestWarnf(t *testing.T) {
|
||||||
logPath := setupLogger(t, WarnLevel)
|
logPath := setupLogger(t, WarnLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
Warnf(ctx, "Warnf message: %s", "test")
|
Warnf(ctx, "Warnf message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
if len(lines) == 0 {
|
if len(lines) == 0 {
|
||||||
@@ -408,7 +414,7 @@ func TestWarnf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "warn",
|
"level": "warn",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Warnf message: test",
|
"message": "Warnf message: test",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,7 +436,7 @@ func TestWarnf(t *testing.T) {
|
|||||||
|
|
||||||
func TestErrorf(t *testing.T) {
|
func TestErrorf(t *testing.T) {
|
||||||
logPath := setupLogger(t, ErrorLevel)
|
logPath := setupLogger(t, ErrorLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
err := fmt.Errorf("error message")
|
err := fmt.Errorf("error message")
|
||||||
Errorf(ctx, err, "Errorf message: %s", "test")
|
Errorf(ctx, err, "Errorf message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
@@ -439,7 +445,7 @@ func TestErrorf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "error",
|
"level": "error",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Errorf message: test",
|
"message": "Errorf message: test",
|
||||||
"error": "error message",
|
"error": "error message",
|
||||||
}
|
}
|
||||||
@@ -462,7 +468,7 @@ func TestErrorf(t *testing.T) {
|
|||||||
|
|
||||||
func TestFatalf(t *testing.T) {
|
func TestFatalf(t *testing.T) {
|
||||||
logPath := setupLogger(t, FatalLevel)
|
logPath := setupLogger(t, FatalLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
err := fmt.Errorf("fatal error")
|
err := fmt.Errorf("fatal error")
|
||||||
Fatalf(ctx, err, "Fatalf message: %s", "test")
|
Fatalf(ctx, err, "Fatalf message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
@@ -471,7 +477,7 @@ func TestFatalf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "fatal",
|
"level": "fatal",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Fatalf message: test",
|
"message": "Fatalf message: test",
|
||||||
"error": "fatal error",
|
"error": "fatal error",
|
||||||
}
|
}
|
||||||
@@ -494,7 +500,7 @@ func TestFatalf(t *testing.T) {
|
|||||||
|
|
||||||
func TestPanicf(t *testing.T) {
|
func TestPanicf(t *testing.T) {
|
||||||
logPath := setupLogger(t, PanicLevel)
|
logPath := setupLogger(t, PanicLevel)
|
||||||
ctx := context.WithValue(context.Background(), userID, "12345")
|
ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345")
|
||||||
err := fmt.Errorf("panic error")
|
err := fmt.Errorf("panic error")
|
||||||
Panicf(ctx, err, "Panicf message: %s", "test")
|
Panicf(ctx, err, "Panicf message: %s", "test")
|
||||||
lines := readLogFile(t, logPath)
|
lines := readLogFile(t, logPath)
|
||||||
@@ -504,7 +510,7 @@ func TestPanicf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expected := map[string]interface{}{
|
expected := map[string]interface{}{
|
||||||
"level": "panic",
|
"level": "panic",
|
||||||
"userID": "12345",
|
"subscriber_id": "12345",
|
||||||
"message": "Panicf message: test",
|
"message": "Panicf message: test",
|
||||||
"error": "panic error",
|
"error": "panic error",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -198,3 +198,55 @@ func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) {
|
|||||||
t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode)
|
t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseContextKey_ValidKeys(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected ContextKey
|
||||||
|
}{
|
||||||
|
{"transaction_id", ContextKeyTxnID},
|
||||||
|
{"message_id", ContextKeyMsgID},
|
||||||
|
{"subscriber_id", ContextKeySubscriberID},
|
||||||
|
{"module_id", ContextKeyModuleID},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
key, err := ParseContextKey(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error for input %s: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
if key != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseContextKey_InvalidKey(t *testing.T) {
|
||||||
|
_, err := ParseContextKey("invalid_key")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid context key, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextKey_UnmarshalYAML_Valid(t *testing.T) {
|
||||||
|
yamlData := []byte("message_id")
|
||||||
|
var key ContextKey
|
||||||
|
|
||||||
|
err := yaml.Unmarshal(yamlData, &key)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if key != ContextKeyMsgID {
|
||||||
|
t.Errorf("expected %s, got %s", ContextKeyMsgID, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextKey_UnmarshalYAML_Invalid(t *testing.T) {
|
||||||
|
yamlData := []byte("invalid_key")
|
||||||
|
var key ContextKey
|
||||||
|
|
||||||
|
err := yaml.Unmarshal(yamlData, &key)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid context key, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,10 +16,6 @@ type Subscriber struct {
|
|||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextKey string
|
|
||||||
|
|
||||||
const ContextKeyModuleId ContextKey = "module_id"
|
|
||||||
|
|
||||||
// Subscription represents subscription details of a network participant.
|
// Subscription represents subscription details of a network participant.
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
Subscriber `json:",inline"`
|
Subscriber `json:",inline"`
|
||||||
@@ -42,10 +38,55 @@ const (
|
|||||||
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
|
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
// ContextKey is a custom type used as a key for storing and retrieving values in a context.
|
||||||
|
type ContextKey string
|
||||||
|
|
||||||
// MsgIDKey is the context key used to store and retrieve the message ID in a request context.
|
const (
|
||||||
const MsgIDKey = contextKey("message_id")
|
// ContextKeyTxnID is the context key used to store and retrieve the transaction ID in a request context.
|
||||||
|
ContextKeyTxnID ContextKey = "transaction_id"
|
||||||
|
|
||||||
|
// ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context.
|
||||||
|
ContextKeyMsgID ContextKey = "message_id"
|
||||||
|
|
||||||
|
// ContextKeySubscriberID is the context key used to store and retrieve the subscriber ID in a request context.
|
||||||
|
ContextKeySubscriberID ContextKey = "subscriber_id"
|
||||||
|
|
||||||
|
// ContextKeyModuleID is the context key for storing and retrieving the model ID from a request context.
|
||||||
|
ContextKeyModuleID ContextKey = "module_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var contextKeys = map[string]ContextKey{
|
||||||
|
"transaction_id": ContextKeyTxnID,
|
||||||
|
"message_id": ContextKeyMsgID,
|
||||||
|
"subscriber_id": ContextKeySubscriberID,
|
||||||
|
"module_id": ContextKeyModuleID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseContextKey converts a string into a valid ContextKey.
|
||||||
|
func ParseContextKey(v string) (ContextKey, error) {
|
||||||
|
key, ok := contextKeys[v]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("invalid context key: %s", v)
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalYAML ensures that only known context keys are accepted during YAML unmarshalling.
|
||||||
|
func (k *ContextKey) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
var keyStr string
|
||||||
|
if err := unmarshal(&keyStr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedKey, err := ParseContextKey(keyStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*k = parsedKey
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Role defines the type of participant in the network.
|
// Role defines the type of participant in the network.
|
||||||
type Role string
|
type Role string
|
||||||
|
|||||||
@@ -10,19 +10,19 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/beckn/beckn-onix/pkg/log"
|
"github.com/beckn/beckn-onix/pkg/log"
|
||||||
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Config represents the configuration for the request preprocessor middleware.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Role string
|
Role string
|
||||||
|
ContextKeys []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type keyType string
|
const contextKey = "context"
|
||||||
|
|
||||||
const (
|
|
||||||
contextKey keyType = "context"
|
|
||||||
subscriberIDKey keyType = "subscriber_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
// NewPreProcessor returns a middleware that processes the incoming request,
|
||||||
|
// extracts the context field from the body, and adds relevant values (like subscriber ID).
|
||||||
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
||||||
if err := validateConfig(cfg); err != nil {
|
if err := validateConfig(cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -41,7 +41,7 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract context from request
|
// Extract context from request.
|
||||||
reqContext, ok := req["context"].(map[string]interface{})
|
reqContext, ok := req["context"].(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest)
|
http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest)
|
||||||
@@ -55,10 +55,15 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) {
|
|||||||
subID = reqContext["bpp_id"]
|
subID = reqContext["bpp_id"]
|
||||||
}
|
}
|
||||||
if subID != nil {
|
if subID != nil {
|
||||||
log.Debugf(ctx, "adding subscriberId to request:%s, %v", subscriberIDKey, subID)
|
log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID)
|
||||||
ctx = context.WithValue(ctx, subscriberIDKey, subID)
|
ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID)
|
||||||
|
}
|
||||||
|
for _, key := range cfg.ContextKeys {
|
||||||
|
ctxKey, _ := model.ParseContextKey(key)
|
||||||
|
if v, ok := reqContext[key]; ok {
|
||||||
|
ctx = context.WithValue(ctx, ctxKey, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
r.ContentLength = int64(len(body))
|
r.ContentLength = int64(len(body))
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
@@ -75,5 +80,11 @@ func validateConfig(cfg *Config) error {
|
|||||||
if cfg.Role != "bap" && cfg.Role != "bpp" {
|
if cfg.Role != "bap" && cfg.Role != "bpp" {
|
||||||
return errors.New("role must be either 'bap' or 'bpp'")
|
return errors.New("role must be either 'bap' or 'bpp'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, key := range cfg.ContextKeys {
|
||||||
|
if _, err := model.ParseContextKey(key); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ToDo Separate Middleware creation and execution.
|
// ToDo Separate Middleware creation and execution.
|
||||||
@@ -71,11 +73,11 @@ func TestNewPreProcessorSuccessCases(t *testing.T) {
|
|||||||
|
|
||||||
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
gotSubID = ctx.Value(subscriberIDKey)
|
gotSubID = ctx.Value(model.ContextKeySubscriberID)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
// Verify subscriber ID
|
// Verify subscriber ID
|
||||||
subID := ctx.Value(subscriberIDKey)
|
subID := ctx.Value(model.ContextKeySubscriberID)
|
||||||
if subID == nil {
|
if subID == nil {
|
||||||
t.Errorf("Expected subscriber ID but got none %s", ctx)
|
t.Errorf("Expected subscriber ID but got none %s", ctx)
|
||||||
return
|
return
|
||||||
@@ -230,3 +232,38 @@ func TestNewPreProcessorErrorCases(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewPreProcessorAddsSubscriberIDToContext(t *testing.T) {
|
||||||
|
cfg := &Config{Role: "bap"}
|
||||||
|
middleware, err := NewPreProcessor(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
samplePayload := map[string]interface{}{
|
||||||
|
"context": map[string]interface{}{
|
||||||
|
"bap_id": "bap.example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bodyBytes, _ := json.Marshal(samplePayload)
|
||||||
|
|
||||||
|
var receivedSubscriberID interface{}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedSubscriberID = r.Context().Value(model.ContextKeySubscriberID)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/", strings.NewReader(string(bodyBytes)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
middleware(handler).ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("Expected status 200 OK, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
if receivedSubscriberID != "bap.example.com" {
|
||||||
|
t.Errorf("Expected subscriber ID 'bap.example.com', got %v", receivedSubscriberID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -291,12 +291,12 @@ func TestValidateRulesFailure(t *testing.T) {
|
|||||||
Version: "1.0.0",
|
Version: "1.0.0",
|
||||||
TargetType: "url",
|
TargetType: "url",
|
||||||
Target: target{
|
Target: target{
|
||||||
URL: "htp://invalid-url.com", // Invalid scheme
|
URL: "htp:// invalid-url.com", // Invalid scheme
|
||||||
},
|
},
|
||||||
Endpoints: []string{"search"},
|
Endpoints: []string{"search"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: "invalid URL - htp://invalid-url.com: URL 'htp://invalid-url.com' must use https scheme",
|
wantErr: `invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Missing topic_id for targetType: publisher",
|
name: "Missing topic_id for targetType: publisher",
|
||||||
@@ -321,12 +321,12 @@ func TestValidateRulesFailure(t *testing.T) {
|
|||||||
Version: "1.0.0",
|
Version: "1.0.0",
|
||||||
TargetType: "bpp",
|
TargetType: "bpp",
|
||||||
Target: target{
|
Target: target{
|
||||||
URL: "htp://invalid-url.com", // Invalid URL
|
URL: "htp:// invalid-url.com", // Invalid URL
|
||||||
},
|
},
|
||||||
Endpoints: []string{"search"},
|
Endpoints: []string{"search"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: "invalid URL - htp://invalid-url.com defined in routing config for target type bpp",
|
wantErr: `invalid URL - htp:// invalid-url.com defined in routing config for target type bpp: parse "htp:// invalid-url.com": invalid character " " in host name`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid URL for BAP targetType",
|
name: "Invalid URL for BAP targetType",
|
||||||
@@ -336,12 +336,12 @@ func TestValidateRulesFailure(t *testing.T) {
|
|||||||
Version: "1.0.0",
|
Version: "1.0.0",
|
||||||
TargetType: "bap",
|
TargetType: "bap",
|
||||||
Target: target{
|
Target: target{
|
||||||
URL: "http://[invalid].com", // Invalid host
|
URL: "http:// [invalid].com", // Invalid host
|
||||||
},
|
},
|
||||||
Endpoints: []string{"search"},
|
Endpoints: []string{"search"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wantErr: "invalid URL - http://[invalid].com defined in routing config for target type bap",
|
wantErr: `invalid URL - http:// [invalid].com defined in routing config for target type bap: parse "http:// [invalid].com": invalid character " " in host name`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,8 +464,8 @@ func TestRouteFailure(t *testing.T) {
|
|||||||
name: "Invalid bpp_uri format in request",
|
name: "Invalid bpp_uri format in request",
|
||||||
configFile: "bap_caller.yaml",
|
configFile: "bap_caller.yaml",
|
||||||
url: "https://example.com/v1/ondc/select",
|
url: "https://example.com/v1/ondc/select",
|
||||||
body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp://invalid-url"}}`, // Invalid scheme (htp instead of http)
|
body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp:// invalid-url"}}`, // Invalid scheme (htp instead of http)
|
||||||
wantErr: "invalid BPP URI - htp://invalid-url in request body for select: URL 'htp://invalid-url' must use https scheme",
|
wantErr: `invalid BPP URI - htp:// invalid-url in request body for select: parse "htp:// invalid-url": invalid character " " in host name`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/beckn/beckn-onix/pkg/log"
|
||||||
"github.com/beckn/beckn-onix/pkg/model"
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,8 +48,8 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
|
|||||||
w.WriteHeader(status)
|
w.WriteHeader(status)
|
||||||
_, er := w.Write(data)
|
_, er := w.Write(data)
|
||||||
if er != nil {
|
if er != nil {
|
||||||
fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey))
|
log.Debugf(ctx, "Error writing response: %v, MessageID: %s", er, ctx.Value(model.ContextKeyMsgID))
|
||||||
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -57,7 +58,7 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i
|
|||||||
func internalServerError(ctx context.Context) *model.Error {
|
func internalServerError(ctx context.Context) *model.Error {
|
||||||
return &model.Error{
|
return &model.Error{
|
||||||
Code: http.StatusText(http.StatusInternalServerError),
|
Code: http.StatusText(http.StatusInternalServerError),
|
||||||
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)),
|
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func TestSendAck(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSendNack(t *testing.T) {
|
func TestSendNack(t *testing.T) {
|
||||||
ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456")
|
ctx := context.WithValue(context.Background(), model.ContextKeyMsgID, "123456")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -197,7 +197,7 @@ func TestNack_1(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345")
|
ctx := context.WithValue(req.Context(), model.ContextKeyMsgID, "12345")
|
||||||
|
|
||||||
var w http.ResponseWriter
|
var w http.ResponseWriter
|
||||||
if tt.useBadWrite {
|
if tt.useBadWrite {
|
||||||
|
|||||||
Reference in New Issue
Block a user