updated code as per the review comments

This commit is contained in:
MohitKatare-protean
2025-04-01 12:31:43 +05:30
parent 2fe39a2c0a
commit 450f13cf34
4 changed files with 45 additions and 92 deletions

View File

@@ -246,8 +246,6 @@ func (h *stdHandler) initSteps(ctx context.Context, mgr PluginManager, cfg *Conf
s, err = newValidateSchemaStep(h.schemaValidator) s, err = newValidateSchemaStep(h.schemaValidator)
case "addRoute": case "addRoute":
s, err = newAddRouteStep(h.router) s, err = newAddRouteStep(h.router)
case "broadcast":
s = &broadcastStep{}
default: default:
if customStep, exists := steps[step]; exists { if customStep, exists := steps[step]; exists {
s = customStep s = customStep

View File

@@ -167,12 +167,3 @@ func (s *addRouteStep) Run(ctx *model.StepContext) error {
} }
return nil return nil
} }
// broadcastStep is a stub implementation of a step that handles broadcasting messages.
type broadcastStep struct{}
// Run executes the broadcast step.
func (b *broadcastStep) Run(ctx *model.StepContext) error {
// TODO: Implement broadcast logic if needed
return nil
}

View File

@@ -69,14 +69,7 @@ func (m *mockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con
// TestRegisterSuccess tests scenarios where the handler registration should succeed. // TestRegisterSuccess tests scenarios where the handler registration should succeed.
func TestRegisterSuccess(t *testing.T) { func TestRegisterSuccess(t *testing.T) {
tests := []struct { mCfgs := []Config{
name string
mCfgs []Config
mockManager *mockPluginManager
}{
{
name: "successful registration",
mCfgs: []Config{
{ {
Name: "test-module", Name: "test-module",
Path: "/test", Path: "/test",
@@ -87,8 +80,9 @@ func TestRegisterSuccess(t *testing.T) {
}, },
}, },
}, },
}, }
mockManager: &mockPluginManager{
mockManager := &mockPluginManager{
middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) { middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -96,19 +90,13 @@ func TestRegisterSuccess(t *testing.T) {
}) })
}, nil }, nil
}, },
},
},
} }
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mux := http.NewServeMux() mux := http.NewServeMux()
err := Register(context.Background(), tt.mCfgs, mux, tt.mockManager) err := Register(context.Background(), mCfgs, mux, mockManager)
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
})
}
} }
// TestRegisterFailure tests scenarios where the handler registration should fail. // TestRegisterFailure tests scenarios where the handler registration should fail.

View File

@@ -71,30 +71,6 @@ func New(ctx context.Context, config *Config) (*Router, func() error, error) {
return router, nil, nil return router, nil, nil
} }
// parseTargetURL parses a URL string into a url.URL object with strict validation
func parseTargetURL(urlStr string) (*url.URL, error) {
if urlStr == "" {
return nil, nil
}
parsed, err := url.Parse(urlStr)
if err != nil {
return nil, fmt.Errorf("invalid URL '%s': %w", urlStr, err)
}
// Enforce scheme requirement
if parsed.Scheme == "" {
return nil, fmt.Errorf("URL '%s' must include a scheme (http/https)", urlStr)
}
// Optionally validate scheme is http or https
if parsed.Scheme != "https" {
return nil, fmt.Errorf("URL '%s' must use https scheme", urlStr)
}
return parsed, nil
}
// LoadRules reads and parses routing rules from the YAML configuration file. // LoadRules reads and parses routing rules from the YAML configuration file.
func (r *Router) loadRules(configPath string) error { func (r *Router) loadRules(configPath string) error {
if configPath == "" { if configPath == "" {
@@ -135,7 +111,7 @@ func (r *Router) loadRules(configPath string) error {
PublisherID: rule.Target.PublisherID, PublisherID: rule.Target.PublisherID,
} }
case targetTypeURL: case targetTypeURL:
parsedURL, err := parseTargetURL(rule.Target.URL) parsedURL, err := url.Parse(rule.Target.URL)
if err != nil { if err != nil {
return fmt.Errorf("invalid URL in rule: %w", err) return fmt.Errorf("invalid URL in rule: %w", err)
} }
@@ -146,7 +122,7 @@ func (r *Router) loadRules(configPath string) error {
case targetTypeBPP, targetTypeBAP: case targetTypeBPP, targetTypeBAP:
var parsedURL *url.URL var parsedURL *url.URL
if rule.Target.URL != "" { if rule.Target.URL != "" {
parsedURL, err = parseTargetURL(rule.Target.URL) parsedURL, err = url.Parse(rule.Target.URL)
if err != nil { if err != nil {
return fmt.Errorf("invalid URL in rule: %w", err) return fmt.Errorf("invalid URL in rule: %w", err)
} }
@@ -177,7 +153,7 @@ func validateRules(rules []routingRule) error {
if rule.Target.URL == "" { if rule.Target.URL == "" {
return fmt.Errorf("invalid rule: url is required for targetType 'url'") return fmt.Errorf("invalid rule: url is required for targetType 'url'")
} }
if _, err := parseTargetURL(rule.Target.URL); err != nil { if _, err := url.Parse(rule.Target.URL); err != nil {
return fmt.Errorf("invalid URL - %s: %w", rule.Target.URL, err) return fmt.Errorf("invalid URL - %s: %w", rule.Target.URL, err)
} }
case targetTypePublisher: case targetTypePublisher:
@@ -186,7 +162,7 @@ func validateRules(rules []routingRule) error {
} }
case targetTypeBPP, targetTypeBAP: case targetTypeBPP, targetTypeBAP:
if rule.Target.URL != "" { if rule.Target.URL != "" {
if _, err := parseTargetURL(rule.Target.URL); err != nil { if _, err := url.Parse(rule.Target.URL); err != nil {
return fmt.Errorf("invalid URL - %s defined in routing config for target type %s: %w", rule.Target.URL, rule.TargetType, err) return fmt.Errorf("invalid URL - %s defined in routing config for target type %s: %w", rule.Target.URL, rule.TargetType, err)
} }
} }
@@ -243,32 +219,32 @@ func (r *Router) Route(ctx context.Context, url *url.URL, body []byte) (*model.R
} }
// handleProtocolMapping handles both BPP and BAP routing with proper URL construction // handleProtocolMapping handles both BPP and BAP routing with proper URL construction
func handleProtocolMapping(route *model.Route, requestURI, endpoint string) (*model.Route, error) { func handleProtocolMapping(route *model.Route, npURI, endpoint string) (*model.Route, error) {
uri := strings.TrimSpace(requestURI) target := strings.TrimSpace(npURI)
var targetURL *url.URL if len(target) == 0 {
if len(uri) != 0 {
parsedURL, err := parseTargetURL(uri)
if err != nil {
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), uri, endpoint, err)
}
targetURL = parsedURL
}
// If no request URI, fall back to configured URL with endpoint appended
if targetURL == nil {
if route.URL == nil { if route.URL == nil {
return nil, fmt.Errorf("could not determine destination for endpoint '%s': neither request contained a %s URI nor was a default URL configured in routing rules", endpoint, strings.ToUpper(route.TargetType)) return nil, fmt.Errorf("could not determine destination for endpoint '%s': neither request contained a %s URI nor was a default URL configured in routing rules", endpoint, strings.ToUpper(route.TargetType))
} }
return &model.Route{
targetURL = &url.URL{ TargetType: targetTypeURL,
URL: &url.URL{
Scheme: route.URL.Scheme, Scheme: route.URL.Scheme,
Host: route.URL.Host, Host: route.URL.Host,
Path: path.Join(route.URL.Path, endpoint), Path: path.Join(route.URL.Path, endpoint),
},
}, nil
} }
targetURL, err := url.Parse(target)
if err != nil {
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), target, endpoint, err)
} }
return &model.Route{ return &model.Route{
TargetType: targetTypeURL, TargetType: targetTypeURL,
URL: targetURL, URL: &url.URL{
Scheme: targetURL.Scheme,
Host: targetURL.Host,
Path: path.Join(targetURL.Path, endpoint),
},
}, nil }, nil
} }