diff --git a/shared/plugin/implementation/requestPreProcessor/reqpreprocessor.go b/shared/plugin/implementation/requestPreProcessor/reqpreprocessor.go new file mode 100644 index 0000000..177aecc --- /dev/null +++ b/shared/plugin/implementation/requestPreProcessor/reqpreprocessor.go @@ -0,0 +1,100 @@ +package reqpreprocessor + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/google/uuid" +) + +type Config struct { + checkKeys []string +} + +const contextKey = "context" + +func NewUUIDSetter(cfg *Config) (func(http.Handler) http.Handler, error) { + if err := validateConfig(cfg); err != nil { + return nil, err + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + var data map[string]any + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusInternalServerError) + return + } + if err := json.Unmarshal(body, &data); err != nil { + http.Error(w, "Failed to decode request body", http.StatusBadRequest) + return + } + contextRaw := data[contextKey] + if contextRaw == nil { + http.Error(w, fmt.Sprintf("%s field not found.", contextKey), http.StatusBadRequest) + return + } + contextData, ok := contextRaw.(map[string]any) + if !ok { + http.Error(w, fmt.Sprintf("%s field is not a map.", contextKey), http.StatusBadRequest) + return + } + ctx := r.Context() + for _, key := range cfg.checkKeys { + value := uuid.NewString() + updatedValue := update(contextData, key, value) + ctx = context.WithValue(ctx, key, updatedValue) + } + data[contextKey] = contextData + updatedBody, err := json.Marshal(data) + if err != nil { + http.Error(w, "Failed to marshal updated JSON", http.StatusInternalServerError) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(updatedBody)) + r.ContentLength = int64(len(updatedBody)) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + }, nil +} + +func update(wrapper map[string]any, key string, value any) any { + field, exists := wrapper[key] + if !exists || isEmpty(field) { + wrapper[key] = value + return value + } + return field +} +func isEmpty(v any) bool { + switch v := v.(type) { + case string: + return v == "" + case nil: + return true + default: + return false + } +} + +func validateConfig(cfg *Config) error { + if cfg == nil { + return errors.New("config cannot be nil") + } + if len(cfg.checkKeys) == 0 { + return errors.New("checkKeys cannot be empty") + } + for _, key := range cfg.checkKeys { + if key == "" { + return errors.New("checkKeys cannot contain empty strings") + } + } + return nil +} diff --git a/shared/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go b/shared/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go new file mode 100644 index 0000000..47e46e0 --- /dev/null +++ b/shared/plugin/implementation/requestPreProcessor/reqpreprocessor_test.go @@ -0,0 +1,143 @@ +package reqpreprocessor + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewUUIDSetter(t *testing.T) { + tests := []struct { + name string + config *Config + requestBody map[string]any + expectedCode int + expectedKeys []string + }{ + { + name: "Valid keys, update missing keys", + config: &Config{ + checkKeys: []string{"transaction_id", "message_id"}, + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "", + "message_id": nil, + }, + }, + expectedCode: http.StatusOK, + expectedKeys: []string{"transaction_id", "message_id"}, + }, + { + name: "Valid keys, do not update existing keys", + config: &Config{ + checkKeys: []string{"transaction_id", "message_id"}, + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "existing-transaction", + "message_id": "existing-message", + }, + }, + expectedCode: http.StatusOK, + expectedKeys: []string{"transaction_id", "message_id"}, + }, + { + name: "Missing context key", + config: &Config{ + checkKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "otherKey": "value", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Invalid context type", + config: &Config{ + checkKeys: []string{"transaction_id"}, + }, + requestBody: map[string]any{ + "context": "not-a-map", + }, + expectedCode: http.StatusBadRequest, + }, + { + name: "Empty checkKeys in config", + config: &Config{ + checkKeys: []string{}, + }, + requestBody: map[string]any{ + "context": map[string]any{ + "transaction_id": "", + }, + }, + expectedCode: http.StatusInternalServerError, + }, + { + name: "Nil config", + config: nil, + requestBody: map[string]any{}, + expectedCode: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware, err := NewUUIDSetter(tt.config) + if tt.config == nil || len(tt.config.checkKeys) == 0 { + if err == nil { + t.Fatal("Expected an error, but got none") + } + return + } + if err != nil { + t.Fatalf("Unexpected error while creating middleware: %v", err) + } + + // Prepare request + bodyBytes, _ := json.Marshal(tt.requestBody) + req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(bodyBytes)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + // Define a dummy handler + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + io.Copy(w, r.Body) + }) + + // Apply middleware + middleware(dummyHandler).ServeHTTP(rec, req) + + // Check status code + if rec.Code != tt.expectedCode { + t.Errorf("Expected status code %d, but got %d", tt.expectedCode, rec.Code) + } + + // If success, check updated keys + if rec.Code == http.StatusOK { + var responseBody map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &responseBody); err != nil { + t.Fatal("Failed to unmarshal response body:", err) + } + + // Validate updated keys + contextData, ok := responseBody[contextKey].(map[string]any) + if !ok { + t.Fatalf("Expected context to be a map, got %T", responseBody[contextKey]) + } + + for _, key := range tt.expectedKeys { + value, exists := contextData[key] + if !exists || isEmpty(value) { + t.Errorf("Expected key %s to be set, but it's missing or empty", key) + } + } + } + }) + } +}