updated code as per the review comments
This commit is contained in:
@@ -246,8 +246,6 @@ func (h *stdHandler) initSteps(ctx context.Context, mgr PluginManager, cfg *Conf
|
|||||||
s, err = newValidateSchemaStep(h.schemaValidator)
|
s, err = newValidateSchemaStep(h.schemaValidator)
|
||||||
case "addRoute":
|
case "addRoute":
|
||||||
s, err = newAddRouteStep(h.router)
|
s, err = newAddRouteStep(h.router)
|
||||||
case "broadcast":
|
|
||||||
s = &broadcastStep{}
|
|
||||||
default:
|
default:
|
||||||
if customStep, exists := steps[step]; exists {
|
if customStep, exists := steps[step]; exists {
|
||||||
s = customStep
|
s = customStep
|
||||||
|
|||||||
@@ -167,12 +167,3 @@ func (s *addRouteStep) Run(ctx *model.StepContext) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// broadcastStep is a stub implementation of a step that handles broadcasting messages.
|
|
||||||
type broadcastStep struct{}
|
|
||||||
|
|
||||||
// Run executes the broadcast step.
|
|
||||||
func (b *broadcastStep) Run(ctx *model.StepContext) error {
|
|
||||||
// TODO: Implement broadcast logic if needed
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -69,14 +69,7 @@ func (m *mockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con
|
|||||||
|
|
||||||
// TestRegisterSuccess tests scenarios where the handler registration should succeed.
|
// TestRegisterSuccess tests scenarios where the handler registration should succeed.
|
||||||
func TestRegisterSuccess(t *testing.T) {
|
func TestRegisterSuccess(t *testing.T) {
|
||||||
tests := []struct {
|
mCfgs := []Config{
|
||||||
name string
|
|
||||||
mCfgs []Config
|
|
||||||
mockManager *mockPluginManager
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful registration",
|
|
||||||
mCfgs: []Config{
|
|
||||||
{
|
{
|
||||||
Name: "test-module",
|
Name: "test-module",
|
||||||
Path: "/test",
|
Path: "/test",
|
||||||
@@ -87,8 +80,9 @@ func TestRegisterSuccess(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
mockManager: &mockPluginManager{
|
|
||||||
|
mockManager := &mockPluginManager{
|
||||||
middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) {
|
middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -96,19 +90,13 @@ func TestRegisterSuccess(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
err := Register(context.Background(), tt.mCfgs, mux, tt.mockManager)
|
err := Register(context.Background(), mCfgs, mux, mockManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRegisterFailure tests scenarios where the handler registration should fail.
|
// TestRegisterFailure tests scenarios where the handler registration should fail.
|
||||||
|
|||||||
@@ -71,30 +71,6 @@ func New(ctx context.Context, config *Config) (*Router, func() error, error) {
|
|||||||
return router, nil, nil
|
return router, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTargetURL parses a URL string into a url.URL object with strict validation
|
|
||||||
func parseTargetURL(urlStr string) (*url.URL, error) {
|
|
||||||
if urlStr == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid URL '%s': %w", urlStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enforce scheme requirement
|
|
||||||
if parsed.Scheme == "" {
|
|
||||||
return nil, fmt.Errorf("URL '%s' must include a scheme (http/https)", urlStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Optionally validate scheme is http or https
|
|
||||||
if parsed.Scheme != "https" {
|
|
||||||
return nil, fmt.Errorf("URL '%s' must use https scheme", urlStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadRules reads and parses routing rules from the YAML configuration file.
|
// LoadRules reads and parses routing rules from the YAML configuration file.
|
||||||
func (r *Router) loadRules(configPath string) error {
|
func (r *Router) loadRules(configPath string) error {
|
||||||
if configPath == "" {
|
if configPath == "" {
|
||||||
@@ -135,7 +111,7 @@ func (r *Router) loadRules(configPath string) error {
|
|||||||
PublisherID: rule.Target.PublisherID,
|
PublisherID: rule.Target.PublisherID,
|
||||||
}
|
}
|
||||||
case targetTypeURL:
|
case targetTypeURL:
|
||||||
parsedURL, err := parseTargetURL(rule.Target.URL)
|
parsedURL, err := url.Parse(rule.Target.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid URL in rule: %w", err)
|
return fmt.Errorf("invalid URL in rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -146,7 +122,7 @@ func (r *Router) loadRules(configPath string) error {
|
|||||||
case targetTypeBPP, targetTypeBAP:
|
case targetTypeBPP, targetTypeBAP:
|
||||||
var parsedURL *url.URL
|
var parsedURL *url.URL
|
||||||
if rule.Target.URL != "" {
|
if rule.Target.URL != "" {
|
||||||
parsedURL, err = parseTargetURL(rule.Target.URL)
|
parsedURL, err = url.Parse(rule.Target.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid URL in rule: %w", err)
|
return fmt.Errorf("invalid URL in rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -177,7 +153,7 @@ func validateRules(rules []routingRule) error {
|
|||||||
if rule.Target.URL == "" {
|
if rule.Target.URL == "" {
|
||||||
return fmt.Errorf("invalid rule: url is required for targetType 'url'")
|
return fmt.Errorf("invalid rule: url is required for targetType 'url'")
|
||||||
}
|
}
|
||||||
if _, err := parseTargetURL(rule.Target.URL); err != nil {
|
if _, err := url.Parse(rule.Target.URL); err != nil {
|
||||||
return fmt.Errorf("invalid URL - %s: %w", rule.Target.URL, err)
|
return fmt.Errorf("invalid URL - %s: %w", rule.Target.URL, err)
|
||||||
}
|
}
|
||||||
case targetTypePublisher:
|
case targetTypePublisher:
|
||||||
@@ -186,7 +162,7 @@ func validateRules(rules []routingRule) error {
|
|||||||
}
|
}
|
||||||
case targetTypeBPP, targetTypeBAP:
|
case targetTypeBPP, targetTypeBAP:
|
||||||
if rule.Target.URL != "" {
|
if rule.Target.URL != "" {
|
||||||
if _, err := parseTargetURL(rule.Target.URL); err != nil {
|
if _, err := url.Parse(rule.Target.URL); err != nil {
|
||||||
return fmt.Errorf("invalid URL - %s defined in routing config for target type %s: %w", rule.Target.URL, rule.TargetType, err)
|
return fmt.Errorf("invalid URL - %s defined in routing config for target type %s: %w", rule.Target.URL, rule.TargetType, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -243,32 +219,32 @@ func (r *Router) Route(ctx context.Context, url *url.URL, body []byte) (*model.R
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleProtocolMapping handles both BPP and BAP routing with proper URL construction
|
// handleProtocolMapping handles both BPP and BAP routing with proper URL construction
|
||||||
func handleProtocolMapping(route *model.Route, requestURI, endpoint string) (*model.Route, error) {
|
func handleProtocolMapping(route *model.Route, npURI, endpoint string) (*model.Route, error) {
|
||||||
uri := strings.TrimSpace(requestURI)
|
target := strings.TrimSpace(npURI)
|
||||||
var targetURL *url.URL
|
if len(target) == 0 {
|
||||||
if len(uri) != 0 {
|
|
||||||
parsedURL, err := parseTargetURL(uri)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), uri, endpoint, err)
|
|
||||||
}
|
|
||||||
targetURL = parsedURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no request URI, fall back to configured URL with endpoint appended
|
|
||||||
if targetURL == nil {
|
|
||||||
if route.URL == nil {
|
if route.URL == nil {
|
||||||
return nil, fmt.Errorf("could not determine destination for endpoint '%s': neither request contained a %s URI nor was a default URL configured in routing rules", endpoint, strings.ToUpper(route.TargetType))
|
return nil, fmt.Errorf("could not determine destination for endpoint '%s': neither request contained a %s URI nor was a default URL configured in routing rules", endpoint, strings.ToUpper(route.TargetType))
|
||||||
}
|
}
|
||||||
|
return &model.Route{
|
||||||
targetURL = &url.URL{
|
TargetType: targetTypeURL,
|
||||||
|
URL: &url.URL{
|
||||||
Scheme: route.URL.Scheme,
|
Scheme: route.URL.Scheme,
|
||||||
Host: route.URL.Host,
|
Host: route.URL.Host,
|
||||||
Path: path.Join(route.URL.Path, endpoint),
|
Path: path.Join(route.URL.Path, endpoint),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
targetURL, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), target, endpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &model.Route{
|
return &model.Route{
|
||||||
TargetType: targetTypeURL,
|
TargetType: targetTypeURL,
|
||||||
URL: targetURL,
|
URL: &url.URL{
|
||||||
|
Scheme: targetURL.Scheme,
|
||||||
|
Host: targetURL.Host,
|
||||||
|
Path: path.Join(targetURL.Path, endpoint),
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user