diff --git a/core/module/handler/stdHandler.go b/core/module/handler/stdHandler.go index 251102e..3a3caaa 100644 --- a/core/module/handler/stdHandler.go +++ b/core/module/handler/stdHandler.go @@ -101,7 +101,7 @@ func (h *stdHandler) stepCtx(r *http.Request, rh http.Header) (*model.StepContex // subID retrieves the subscriber ID from the request context. func (h *stdHandler) subID(ctx context.Context) string { - rSubID, ok := ctx.Value("subscriber_id").(string) + rSubID, ok := ctx.Value(model.ContextKeySubscriberID).(string) if ok { return rSubID } diff --git a/core/module/module.go b/core/module/module.go index 4641fcb..2a28e62 100644 --- a/core/module/module.go +++ b/core/module/module.go @@ -76,7 +76,7 @@ func addMiddleware(ctx context.Context, mgr handler.PluginManager, handler http. func moduleCtxMiddleware(moduleName string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), model.ContextKeyModuleId, moduleName) + ctx := context.WithValue(r.Context(), model.ContextKeyModuleID, moduleName) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/core/module/module_test.go b/core/module/module_test.go index 6091fc1..ffeaafe 100644 --- a/core/module/module_test.go +++ b/core/module/module_test.go @@ -107,7 +107,7 @@ func TestRegisterSuccess(t *testing.T) { // Create a handler that extracts context var capturedModuleName any testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedModuleName = r.Context().Value(model.ContextKeyModuleId) + capturedModuleName = r.Context().Value(model.ContextKeyModuleID) w.WriteHeader(http.StatusOK) }) diff --git a/pkg/log/log.go b/pkg/log/log.go index eabd9f0..1f5f59f 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/beckn/beckn-onix/pkg/model" "github.com/rs/zerolog" "gopkg.in/natefinch/lumberjack.v2" ) @@ -52,9 +53,9 @@ var logLevels = map[level]zerolog.Level{ // Config represents the configuration for logging. type Config struct { - Level level `yaml:"level"` - Destinations []destination `yaml:"destinations"` - ContextKeys []string `yaml:"contextKeys"` + Level level `yaml:"level"` + Destinations []destination `yaml:"destinations"` + ContextKeys []model.ContextKey `yaml:"contextKeys"` } var ( @@ -277,7 +278,7 @@ func addCtx(ctx context.Context, event *zerolog.Event) { if !ok { continue } - keyStr := key + keyStr := string(key) event.Any(keyStr, val) } } diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go index 2e874ae..2b93d73 100644 --- a/pkg/log/log_test.go +++ b/pkg/log/log_test.go @@ -13,12 +13,13 @@ import ( "strings" "testing" "time" + + "github.com/beckn/beckn-onix/pkg/model" ) type ctxKey any var requestID ctxKey = "requestID" -var userID ctxKey = "userID" const testLogFilePath = "./test_logs/test.log" @@ -63,7 +64,12 @@ func setupLogger(t *testing.T, l level) string { }, }, }, - ContextKeys: []string{"userID", "requestID"}, + ContextKeys: []model.ContextKey{ + model.ContextKeyTxnID, + model.ContextKeyMsgID, + model.ContextKeySubscriberID, + model.ContextKeyModuleID, + }, } // Initialize logger with the given config @@ -97,16 +103,16 @@ func parseLogLine(t *testing.T, line string) map[string]interface{} { func TestDebug(t *testing.T) { t.Helper() logPath := setupLogger(t, DebugLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Debug(ctx, "Debug message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "debug", - "userID": "12345", - "message": "Debug message", + "level": "debug", + "subscriber_id": "12345", + "message": "Debug message", } var found bool @@ -129,16 +135,16 @@ func TestDebug(t *testing.T) { func TestInfo(t *testing.T) { logPath := setupLogger(t, InfoLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Info(ctx, "Info message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "info", - "userID": "12345", - "message": "Info message", + "level": "info", + "subscriber_id": "12345", + "message": "Info message", } var found bool @@ -161,16 +167,16 @@ func TestInfo(t *testing.T) { func TestWarn(t *testing.T) { logPath := setupLogger(t, WarnLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Warn(ctx, "Warning message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "warn", - "userID": "12345", - "message": "Warning message", + "level": "warn", + "subscriber_id": "12345", + "message": "Warning message", } var found bool @@ -189,17 +195,17 @@ func TestWarn(t *testing.T) { func TestError(t *testing.T) { logPath := setupLogger(t, ErrorLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Error(ctx, fmt.Errorf("test error"), "Error message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "error", - "userID": "12345", - "message": "Error message", - "error": "test error", + "level": "error", + "subscriber_id": "12345", + "message": "Error message", + "error": "test error", } var found bool @@ -277,17 +283,17 @@ func TestResponse(t *testing.T) { func TestFatal(t *testing.T) { logPath := setupLogger(t, FatalLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Fatal(ctx, fmt.Errorf("fatal error"), "Fatal message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "fatal", - "userID": "12345", - "message": "Fatal message", - "error": "fatal error", + "level": "fatal", + "subscriber_id": "12345", + "message": "Fatal message", + "error": "fatal error", } var found bool @@ -308,17 +314,17 @@ func TestFatal(t *testing.T) { func TestPanic(t *testing.T) { logPath := setupLogger(t, PanicLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Panic(ctx, fmt.Errorf("panic error"), "Panic message") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "panic", - "userID": "12345", - "message": "Panic message", - "error": "panic error", + "level": "panic", + "subscriber_id": "12345", + "message": "Panic message", + "error": "panic error", } var found bool @@ -339,16 +345,16 @@ func TestPanic(t *testing.T) { func TestDebugf(t *testing.T) { logPath := setupLogger(t, DebugLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Debugf(ctx, "Debugf message: %s", "test") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "debug", - "userID": "12345", - "message": "Debugf message: test", + "level": "debug", + "subscriber_id": "12345", + "message": "Debugf message: test", } var found bool @@ -370,16 +376,16 @@ func TestDebugf(t *testing.T) { func TestInfof(t *testing.T) { logPath := setupLogger(t, InfoLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Infof(ctx, "Infof message: %s", "test") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "info", - "userID": "12345", - "message": "Infof message: test", + "level": "info", + "subscriber_id": "12345", + "message": "Infof message: test", } var found bool @@ -400,16 +406,16 @@ func TestInfof(t *testing.T) { func TestWarnf(t *testing.T) { logPath := setupLogger(t, WarnLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") Warnf(ctx, "Warnf message: %s", "test") lines := readLogFile(t, logPath) if len(lines) == 0 { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "warn", - "userID": "12345", - "message": "Warnf message: test", + "level": "warn", + "subscriber_id": "12345", + "message": "Warnf message: test", } var found bool @@ -430,7 +436,7 @@ func TestWarnf(t *testing.T) { func TestErrorf(t *testing.T) { logPath := setupLogger(t, ErrorLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") err := fmt.Errorf("error message") Errorf(ctx, err, "Errorf message: %s", "test") lines := readLogFile(t, logPath) @@ -438,10 +444,10 @@ func TestErrorf(t *testing.T) { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "error", - "userID": "12345", - "message": "Errorf message: test", - "error": "error message", + "level": "error", + "subscriber_id": "12345", + "message": "Errorf message: test", + "error": "error message", } var found bool @@ -462,7 +468,7 @@ func TestErrorf(t *testing.T) { func TestFatalf(t *testing.T) { logPath := setupLogger(t, FatalLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") err := fmt.Errorf("fatal error") Fatalf(ctx, err, "Fatalf message: %s", "test") lines := readLogFile(t, logPath) @@ -470,10 +476,10 @@ func TestFatalf(t *testing.T) { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "fatal", - "userID": "12345", - "message": "Fatalf message: test", - "error": "fatal error", + "level": "fatal", + "subscriber_id": "12345", + "message": "Fatalf message: test", + "error": "fatal error", } var found bool @@ -494,7 +500,7 @@ func TestFatalf(t *testing.T) { func TestPanicf(t *testing.T) { logPath := setupLogger(t, PanicLevel) - ctx := context.WithValue(context.Background(), userID, "12345") + ctx := context.WithValue(context.Background(), model.ContextKeySubscriberID, "12345") err := fmt.Errorf("panic error") Panicf(ctx, err, "Panicf message: %s", "test") lines := readLogFile(t, logPath) @@ -503,10 +509,10 @@ func TestPanicf(t *testing.T) { t.Fatal("No logs were written.") } expected := map[string]interface{}{ - "level": "panic", - "userID": "12345", - "message": "Panicf message: test", - "error": "panic error", + "level": "panic", + "subscriber_id": "12345", + "message": "Panicf message: test", + "error": "panic error", } var found bool diff --git a/pkg/model/error_test.go b/pkg/model/error_test.go index 7b9fd95..1ac952e 100644 --- a/pkg/model/error_test.go +++ b/pkg/model/error_test.go @@ -198,3 +198,55 @@ func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) { t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode) } } + +func TestParseContextKey_ValidKeys(t *testing.T) { + tests := []struct { + input string + expected ContextKey + }{ + {"transaction_id", ContextKeyTxnID}, + {"message_id", ContextKeyMsgID}, + {"subscriber_id", ContextKeySubscriberID}, + {"module_id", ContextKeyModuleID}, + } + + for _, tt := range tests { + key, err := ParseContextKey(tt.input) + if err != nil { + t.Errorf("unexpected error for input %s: %v", tt.input, err) + } + if key != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, key) + } + } +} + +func TestParseContextKey_InvalidKey(t *testing.T) { + _, err := ParseContextKey("invalid_key") + if err == nil { + t.Error("expected error for invalid context key, got nil") + } +} + +func TestContextKey_UnmarshalYAML_Valid(t *testing.T) { + yamlData := []byte("message_id") + var key ContextKey + + err := yaml.Unmarshal(yamlData, &key) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if key != ContextKeyMsgID { + t.Errorf("expected %s, got %s", ContextKeyMsgID, key) + } +} + +func TestContextKey_UnmarshalYAML_Invalid(t *testing.T) { + yamlData := []byte("invalid_key") + var key ContextKey + + err := yaml.Unmarshal(yamlData, &key) + if err == nil { + t.Error("expected error for invalid context key, got nil") + } +} diff --git a/pkg/model/model.go b/pkg/model/model.go index 536a9ad..a91e2a3 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -16,10 +16,6 @@ type Subscriber struct { Domain string `json:"domain"` } -type ContextKey string - -const ContextKeyModuleId ContextKey = "module_id" - // Subscription represents subscription details of a network participant. type Subscription struct { Subscriber `json:",inline"` @@ -42,10 +38,55 @@ const ( UnaAuthorizedHeaderGateway string = "Proxy-Authenticate" ) -type contextKey string +// ContextKey is a custom type used as a key for storing and retrieving values in a context. +type ContextKey string -// MsgIDKey is the context key used to store and retrieve the message ID in a request context. -const MsgIDKey = contextKey("message_id") +const ( + // ContextKeyTxnID is the context key used to store and retrieve the transaction ID in a request context. + ContextKeyTxnID ContextKey = "transaction_id" + + // ContextKeyMsgID is the context key used to store and retrieve the message ID in a request context. + ContextKeyMsgID ContextKey = "message_id" + + // ContextKeySubscriberID is the context key used to store and retrieve the subscriber ID in a request context. + ContextKeySubscriberID ContextKey = "subscriber_id" + + // ContextKeyModuleID is the context key for storing and retrieving the model ID from a request context. + ContextKeyModuleID ContextKey = "module_id" +) + +var contextKeys = map[string]ContextKey{ + "transaction_id": ContextKeyTxnID, + "message_id": ContextKeyMsgID, + "subscriber_id": ContextKeySubscriberID, + "module_id": ContextKeyModuleID, +} + +// ParseContextKey converts a string into a valid ContextKey. +func ParseContextKey(v string) (ContextKey, error) { + key, ok := contextKeys[v] + if !ok { + return "", fmt.Errorf("invalid context key: %s", v) + } + return key, nil +} + +// UnmarshalYAML ensures that only known context keys are accepted during YAML unmarshalling. +func (k *ContextKey) UnmarshalYAML(unmarshal func(interface{}) error) error { + var keyStr string + if err := unmarshal(&keyStr); err != nil { + return err + } + + parsedKey, err := ParseContextKey(keyStr) + if err != nil { + return err + } + + *k = parsedKey + return nil + +} // Role defines the type of participant in the network. type Role string diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go index 6d30c38..7ca2901 100644 --- a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor.go @@ -10,19 +10,19 @@ import ( "net/http" "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/model" ) +// Config represents the configuration for the request preprocessor middleware. type Config struct { - Role string + Role string + ContextKeys []string } -type keyType string - -const ( - contextKey keyType = "context" - subscriberIDKey keyType = "subscriber_id" -) +const contextKey = "context" +// NewPreProcessor returns a middleware that processes the incoming request, +// extracts the context field from the body, and adds relevant values (like subscriber ID). func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { if err := validateConfig(cfg); err != nil { return nil, err @@ -41,7 +41,7 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { return } - // Extract context from request + // Extract context from request. reqContext, ok := req["context"].(map[string]interface{}) if !ok { http.Error(w, fmt.Sprintf("%s field not found or invalid.", contextKey), http.StatusBadRequest) @@ -55,10 +55,15 @@ func NewPreProcessor(cfg *Config) (func(http.Handler) http.Handler, error) { subID = reqContext["bpp_id"] } if subID != nil { - log.Debugf(ctx, "adding subscriberId to request:%s, %v", subscriberIDKey, subID) - ctx = context.WithValue(ctx, subscriberIDKey, subID) + log.Debugf(ctx, "adding subscriberId to request:%s, %v", model.ContextKeySubscriberID, subID) + ctx = context.WithValue(ctx, model.ContextKeySubscriberID, subID) + } + for _, key := range cfg.ContextKeys { + ctxKey, _ := model.ParseContextKey(key) + if v, ok := reqContext[key]; ok { + ctx = context.WithValue(ctx, ctxKey, v) + } } - r.Body = io.NopCloser(bytes.NewBuffer(body)) r.ContentLength = int64(len(body)) r = r.WithContext(ctx) @@ -75,5 +80,11 @@ func validateConfig(cfg *Config) error { if cfg.Role != "bap" && cfg.Role != "bpp" { return errors.New("role must be either 'bap' or 'bpp'") } + + for _, key := range cfg.ContextKeys { + if _, err := model.ParseContextKey(key); err != nil { + return err + } + } return nil } diff --git a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go index 960d7ea..ed78d2d 100644 --- a/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go +++ b/pkg/plugin/implementation/reqpreprocessor/reqpreprocessor_test.go @@ -7,6 +7,8 @@ import ( "net/http/httptest" "strings" "testing" + + "github.com/beckn/beckn-onix/pkg/model" ) // ToDo Separate Middleware creation and execution. @@ -71,11 +73,11 @@ func TestNewPreProcessorSuccessCases(t *testing.T) { dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - gotSubID = ctx.Value(subscriberIDKey) + gotSubID = ctx.Value(model.ContextKeySubscriberID) w.WriteHeader(http.StatusOK) // Verify subscriber ID - subID := ctx.Value(subscriberIDKey) + subID := ctx.Value(model.ContextKeySubscriberID) if subID == nil { t.Errorf("Expected subscriber ID but got none %s", ctx) return @@ -230,3 +232,38 @@ func TestNewPreProcessorErrorCases(t *testing.T) { }) } } + +func TestNewPreProcessorAddsSubscriberIDToContext(t *testing.T) { + cfg := &Config{Role: "bap"} + middleware, err := NewPreProcessor(cfg) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + samplePayload := map[string]interface{}{ + "context": map[string]interface{}{ + "bap_id": "bap.example.com", + }, + } + bodyBytes, _ := json.Marshal(samplePayload) + + var receivedSubscriberID interface{} + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedSubscriberID = r.Context().Value(model.ContextKeySubscriberID) + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("POST", "/", strings.NewReader(string(bodyBytes))) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + middleware(handler).ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("Expected status 200 OK, got %d", rr.Code) + } + if receivedSubscriberID != "bap.example.com" { + t.Errorf("Expected subscriber ID 'bap.example.com', got %v", receivedSubscriberID) + } +} diff --git a/pkg/plugin/implementation/router/router_test.go b/pkg/plugin/implementation/router/router_test.go index 7937ca5..d0bb271 100644 --- a/pkg/plugin/implementation/router/router_test.go +++ b/pkg/plugin/implementation/router/router_test.go @@ -291,12 +291,12 @@ func TestValidateRulesFailure(t *testing.T) { Version: "1.0.0", TargetType: "url", Target: target{ - URL: "htp://invalid-url.com", // Invalid scheme + URL: "htp:// invalid-url.com", // Invalid scheme }, Endpoints: []string{"search"}, }, }, - wantErr: "invalid URL - htp://invalid-url.com: URL 'htp://invalid-url.com' must use https scheme", + wantErr: `invalid URL - htp:// invalid-url.com: parse "htp:// invalid-url.com": invalid character " " in host name`, }, { name: "Missing topic_id for targetType: publisher", @@ -321,12 +321,12 @@ func TestValidateRulesFailure(t *testing.T) { Version: "1.0.0", TargetType: "bpp", Target: target{ - URL: "htp://invalid-url.com", // Invalid URL + URL: "htp:// invalid-url.com", // Invalid URL }, Endpoints: []string{"search"}, }, }, - wantErr: "invalid URL - htp://invalid-url.com defined in routing config for target type bpp", + wantErr: `invalid URL - htp:// invalid-url.com defined in routing config for target type bpp: parse "htp:// invalid-url.com": invalid character " " in host name`, }, { name: "Invalid URL for BAP targetType", @@ -336,12 +336,12 @@ func TestValidateRulesFailure(t *testing.T) { Version: "1.0.0", TargetType: "bap", Target: target{ - URL: "http://[invalid].com", // Invalid host + URL: "http:// [invalid].com", // Invalid host }, Endpoints: []string{"search"}, }, }, - wantErr: "invalid URL - http://[invalid].com defined in routing config for target type bap", + wantErr: `invalid URL - http:// [invalid].com defined in routing config for target type bap: parse "http:// [invalid].com": invalid character " " in host name`, }, } @@ -464,8 +464,8 @@ func TestRouteFailure(t *testing.T) { name: "Invalid bpp_uri format in request", configFile: "bap_caller.yaml", url: "https://example.com/v1/ondc/select", - body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp://invalid-url"}}`, // Invalid scheme (htp instead of http) - wantErr: "invalid BPP URI - htp://invalid-url in request body for select: URL 'htp://invalid-url' must use https scheme", + body: `{"context": {"domain": "ONDC:TRV10", "version": "2.0.0", "bpp_uri": "htp:// invalid-url"}}`, // Invalid scheme (htp instead of http) + wantErr: `invalid BPP URI - htp:// invalid-url in request body for select: parse "htp:// invalid-url": invalid character " " in host name`, }, } diff --git a/pkg/response/response.go b/pkg/response/response.go index a5ab0c4..0ced3de 100644 --- a/pkg/response/response.go +++ b/pkg/response/response.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" + "github.com/beckn/beckn-onix/pkg/log" "github.com/beckn/beckn-onix/pkg/model" ) @@ -47,8 +48,8 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i w.WriteHeader(status) _, er := w.Write(data) if er != nil { - fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey)) - http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError) + log.Debugf(ctx, "Error writing response: %v, MessageID: %s", er, ctx.Value(model.ContextKeyMsgID)) + http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), http.StatusInternalServerError) return } } @@ -57,7 +58,7 @@ func nack(ctx context.Context, w http.ResponseWriter, err *model.Error, status i func internalServerError(ctx context.Context) *model.Error { return &model.Error{ Code: http.StatusText(http.StatusInternalServerError), - Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), + Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.ContextKeyMsgID)), } } diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go index 96f1caa..f0c1c6d 100644 --- a/pkg/response/response_test.go +++ b/pkg/response/response_test.go @@ -46,7 +46,7 @@ func TestSendAck(t *testing.T) { } func TestSendNack(t *testing.T) { - ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456") + ctx := context.WithValue(context.Background(), model.ContextKeyMsgID, "123456") tests := []struct { name string @@ -197,7 +197,7 @@ func TestNack_1(t *testing.T) { if err != nil { t.Fatal(err) } - ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345") + ctx := context.WithValue(req.Context(), model.ContextKeyMsgID, "12345") var w http.ResponseWriter if tt.useBadWrite {