From 13a7a18e1753699db46f97c492cd424467d9cd0c Mon Sep 17 00:00:00 2001 From: Ritesh Date: Tue, 24 Mar 2026 17:59:48 +0530 Subject: [PATCH] Fix: address policy checker review feedback --- cmd/adapter/main_test.go | 57 ++++++++++++ core/module/module_test.go | 59 ++++++++++++- .../implementation/opapolicychecker/README.md | 9 +- .../opapolicychecker/benchmark_test.go | 12 +-- .../opapolicychecker/cmd/plugin_test.go | 41 +++++++++ .../opapolicychecker/enforcer.go | 70 +++++++++++---- .../opapolicychecker/enforcer_test.go | 86 +++++++++++++++---- .../opapolicychecker/evaluator.go | 36 ++++---- 8 files changed, 313 insertions(+), 57 deletions(-) create mode 100644 pkg/plugin/implementation/opapolicychecker/cmd/plugin_test.go diff --git a/cmd/adapter/main_test.go b/cmd/adapter/main_test.go index cd18015..5c003a5 100644 --- a/cmd/adapter/main_test.go +++ b/cmd/adapter/main_test.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "net/http" + "net/http/httptest" "os" "path/filepath" "strings" @@ -13,6 +14,7 @@ import ( "github.com/beckn-one/beckn-onix/core/module" "github.com/beckn-one/beckn-onix/core/module/handler" + "github.com/beckn-one/beckn-onix/pkg/model" "github.com/beckn-one/beckn-onix/pkg/plugin" "github.com/beckn-one/beckn-onix/pkg/plugin/definition" "github.com/stretchr/testify/mock" @@ -21,6 +23,15 @@ import ( // MockPluginManager implements handler.PluginManager for testing. type MockPluginManager struct { mock.Mock + policyCheckerFunc func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) +} + +type stubPolicyChecker struct { + err error +} + +func (s stubPolicyChecker) CheckPolicy(*model.StepContext) error { + return s.err } // Middleware returns a middleware function based on the provided configuration. @@ -85,6 +96,9 @@ func (m *MockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con // PolicyChecker returns a mock implementation of the PolicyChecker interface. func (m *MockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) { + if m.policyCheckerFunc != nil { + return m.policyCheckerFunc(ctx, cfg) + } return nil, nil } @@ -335,6 +349,49 @@ func TestNewServerSuccess(t *testing.T) { } } +func TestNewServerRejectsPolicyViolation(t *testing.T) { + mockMgr := &MockPluginManager{ + policyCheckerFunc: func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) { + return stubPolicyChecker{err: model.NewBadReqErr(errors.New("blocked by policy"))}, nil + }, + } + + cfg := &Config{ + Modules: []module.Config{ + { + Name: "policy-module", + Path: "/policy", + Handler: handler.Config{ + Type: handler.HandlerTypeStd, + Plugins: handler.PluginCfg{ + PolicyChecker: &plugin.Config{ID: "mock-policy"}, + }, + Steps: []string{"checkPolicy"}, + }, + }, + }, + } + + h, err := newServer(context.Background(), mockMgr, cfg) + if err != nil { + t.Fatalf("expected no error creating server, got %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/policy", strings.NewReader(`{"context":{"action":"confirm"}}`)) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for policy violation, got %d", rec.Code) + } + if !strings.Contains(rec.Body.String(), "NACK") { + t.Fatalf("expected NACK response, got %s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "blocked by policy") { + t.Fatalf("expected policy error in response, got %s", rec.Body.String()) + } +} + // TestNewServerFailure tests failure scenarios when creating a server. func TestNewServerFailure(t *testing.T) { tests := []struct { diff --git a/core/module/module_test.go b/core/module/module_test.go index 3f26c4c..c050e70 100644 --- a/core/module/module_test.go +++ b/core/module/module_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "testing" "github.com/beckn-one/beckn-onix/core/module/handler" @@ -16,7 +17,8 @@ import ( // mockPluginManager is a mock implementation of the PluginManager interface // with support for dynamically setting behavior. type mockPluginManager struct { - middlewareFunc func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) + middlewareFunc func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) + policyCheckerFunc func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) } // Middleware returns a mock middleware function based on the provided configuration. @@ -81,10 +83,65 @@ func (m *mockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con // PolicyChecker returns a mock policy checker implementation. func (m *mockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) { + if m.policyCheckerFunc != nil { + return m.policyCheckerFunc(ctx, cfg) + } return nil, nil } +type mockPolicyChecker struct { + err error +} + +func (m mockPolicyChecker) CheckPolicy(*model.StepContext) error { + return m.err +} + // TestRegisterSuccess tests scenarios where the handler registration should succeed. +func TestRegisterRejectsPolicyViolation(t *testing.T) { + mCfgs := []Config{ + { + Name: "test-module", + Path: "/test", + Handler: handler.Config{ + Type: handler.HandlerTypeStd, + Plugins: handler.PluginCfg{ + PolicyChecker: &plugin.Config{ID: "mock-policy"}, + }, + Steps: []string{"checkPolicy"}, + }, + }, + } + + mockManager := &mockPluginManager{ + middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) { + return func(next http.Handler) http.Handler { return next }, nil + }, + policyCheckerFunc: func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) { + return mockPolicyChecker{err: model.NewBadReqErr(errors.New("blocked by policy"))}, nil + }, + } + + mux := http.NewServeMux() + if err := Register(context.Background(), mCfgs, mux, mockManager); err != nil { + t.Fatalf("unexpected register error: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(`{"context":{"action":"confirm"}}`)) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for policy violation, got %d", rec.Code) + } + if !strings.Contains(rec.Body.String(), "NACK") { + t.Fatalf("expected NACK response, got %s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "blocked by policy") { + t.Fatalf("expected policy error in response, got %s", rec.Body.String()) + } +} + func TestRegisterSuccess(t *testing.T) { mCfgs := []Config{ { diff --git a/pkg/plugin/implementation/opapolicychecker/README.md b/pkg/plugin/implementation/opapolicychecker/README.md index 159fdf4..0e18c9e 100644 --- a/pkg/plugin/implementation/opapolicychecker/README.md +++ b/pkg/plugin/implementation/opapolicychecker/README.md @@ -10,6 +10,8 @@ Validates incoming Beckn messages against network-defined business rules using [ - Fail-closed on empty/undefined query results — misconfigured policies are treated as violations - Runtime config forwarding: adapter config values are accessible in Rego as `data.config.` - Action-based enforcement: apply policies only to specific beckn actions (e.g., `confirm`, `search`) +- Configurable fetch timeout for remote policy and bundle sources +- Warns at startup when policy enforcement is explicitly disabled ## Configuration @@ -36,6 +38,7 @@ steps: | `actions` | string | No | *(all)* | Comma-separated beckn actions to enforce | | `enabled` | string | No | `"true"` | Enable or disable the plugin | | `debugLogging` | string | No | `"false"` | Enable verbose OPA evaluation logging | +| `fetchTimeoutSeconds` | string | No | `"30"` | Timeout in seconds for fetching remote `.rego` files or bundles | | `refreshIntervalSeconds` | string | No | - | Reload policies every N seconds (0 or omit = disabled) | | *any other key* | string | No | - | Forwarded to Rego as `data.config.` | @@ -47,7 +50,7 @@ When `refreshIntervalSeconds` is set, a background goroutine periodically re-fet - **Atomic swap**: the old evaluator stays fully active until the new one is compiled — no gap in enforcement - **Non-fatal errors**: if the reload fails (e.g., file temporarily unreachable or parse error), the error is logged and the previous policy stays active -- **Goroutine lifecycle**: the reload loop is tied to the adapter context and stops cleanly on shutdown +- **Goroutine lifecycle**: the reload loop stops when the adapter context is cancelled or when plugin `Close()` is invoked during shutdown ```yaml config: @@ -67,7 +70,7 @@ config: ### Request Evaluation (Runtime) -1. **Check Action Match**: If `actions` is configured, skip evaluation for non-matching actions +1. **Check Action Match**: If `actions` is configured, skip evaluation for non-matching actions. The plugin assumes standard adapter routes look like `/{participant}/{direction}/{action}` such as `/bpp/caller/confirm`; non-standard paths fall back to `context.action` from the JSON body. 2. **Evaluate OPA Query**: Run the prepared query with the full beckn message as `input` 3. **Handle Result**: - If the query returns no result (undefined) → **violation** (fail-closed) @@ -109,6 +112,7 @@ checkPolicy: type: url location: https://policies.example.com/compliance.rego query: "data.policy.result" + fetchTimeoutSeconds: "10" ``` ### Local Directory (multiple `.rego` files) @@ -192,3 +196,4 @@ Configure them side-by-side in your adapter steps as needed. - **No bundle signature verification**: When using `type: bundle`, bundle signature verification is skipped. This is planned for a future enhancement. - **Network-level scoping**: Policies apply to all messages handled by the adapter instance. Per-network policy mapping (by `networkId`) is tracked for follow-up. +- **Non-standard route shapes**: URL-based action extraction assumes the standard Beckn adapter route shape `/{participant}/{direction}/{action}` and falls back to `context.action` for other path layouts. diff --git a/pkg/plugin/implementation/opapolicychecker/benchmark_test.go b/pkg/plugin/implementation/opapolicychecker/benchmark_test.go index f070905..1f10a08 100644 --- a/pkg/plugin/implementation/opapolicychecker/benchmark_test.go +++ b/pkg/plugin/implementation/opapolicychecker/benchmark_test.go @@ -101,7 +101,7 @@ func BenchmarkEvaluate_MostlyInactive(b *testing.B) { dir := b.TempDir() os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { b.Fatalf("NewEvaluator failed: %v", err) } @@ -135,7 +135,7 @@ func BenchmarkEvaluate_AllActive(b *testing.B) { dir := b.TempDir() os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { b.Fatalf("NewEvaluator failed: %v", err) } @@ -172,7 +172,7 @@ func BenchmarkCompilation(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + _, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { b.Fatalf("NewEvaluator failed: %v", err) } @@ -206,7 +206,7 @@ func TestBenchmarkReport(t *testing.T) { os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644) start := time.Now() - _, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + _, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) elapsed := time.Since(start) if err != nil { t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err) @@ -227,7 +227,7 @@ func TestBenchmarkReport(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err) } @@ -263,7 +263,7 @@ func TestBenchmarkReport(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err) } diff --git a/pkg/plugin/implementation/opapolicychecker/cmd/plugin_test.go b/pkg/plugin/implementation/opapolicychecker/cmd/plugin_test.go new file mode 100644 index 0000000..8ed77c7 --- /dev/null +++ b/pkg/plugin/implementation/opapolicychecker/cmd/plugin_test.go @@ -0,0 +1,41 @@ +package main + +import ( + "context" + "path/filepath" + "testing" +) + +func TestProviderNewSuccess(t *testing.T) { + provider := provider{} + config := map[string]string{ + "type": "file", + "location": filepath.Join("..", "testdata", "example.rego"), + "query": "data.policy.result", + } + + checker, closer, err := provider.New(context.Background(), config) + if err != nil { + t.Fatalf("New() unexpected error: %v", err) + } + if checker == nil { + t.Fatal("New() returned nil checker") + } + if closer == nil { + t.Fatal("New() returned nil closer") + } + + closer() +} + +func TestProviderNewFailure(t *testing.T) { + provider := provider{} + + _, _, err := provider.New(context.Background(), map[string]string{ + "type": "file", + "query": "data.policy.result", + }) + if err == nil { + t.Fatal("expected error when required config is missing") + } +} diff --git a/pkg/plugin/implementation/opapolicychecker/enforcer.go b/pkg/plugin/implementation/opapolicychecker/enforcer.go index 0dd631c..94fb0bb 100644 --- a/pkg/plugin/implementation/opapolicychecker/enforcer.go +++ b/pkg/plugin/implementation/opapolicychecker/enforcer.go @@ -22,6 +22,7 @@ type Config struct { Actions []string Enabled bool DebugLogging bool + FetchTimeout time.Duration IsBundle bool RefreshInterval time.Duration // 0 = disabled RuntimeConfig map[string]string @@ -34,12 +35,14 @@ var knownKeys = map[string]bool{ "actions": true, "enabled": true, "debugLogging": true, + "fetchTimeoutSeconds": true, "refreshIntervalSeconds": true, } func DefaultConfig() *Config { return &Config{ Enabled: true, + FetchTimeout: defaultPolicyFetchTimeout, RuntimeConfig: make(map[string]string), } } @@ -105,6 +108,14 @@ func ParseConfig(cfg map[string]string) (*Config, error) { config.DebugLogging = debug == "true" || debug == "1" } + if fts, ok := cfg["fetchTimeoutSeconds"]; ok && fts != "" { + secs, err := strconv.Atoi(fts) + if err != nil || secs <= 0 { + return nil, fmt.Errorf("'fetchTimeoutSeconds' must be a positive integer, got %q", fts) + } + config.FetchTimeout = time.Duration(secs) * time.Second + } + if ris, ok := cfg["refreshIntervalSeconds"]; ok && ris != "" { secs, err := strconv.Atoi(ris) if err != nil || secs < 0 { @@ -136,9 +147,11 @@ func (c *Config) IsActionEnabled(action string) bool { // PolicyEnforcer evaluates beckn messages against OPA policies and NACKs non-compliant messages. type PolicyEnforcer struct { - config *Config - evaluator *Evaluator - evaluatorMu sync.RWMutex + config *Config + evaluator *Evaluator + evaluatorMu sync.RWMutex + closeOnce sync.Once + done chan struct{} } // getEvaluator safely returns the current evaluator under a read lock. @@ -162,18 +175,24 @@ func New(ctx context.Context, cfg map[string]string) (*PolicyEnforcer, error) { return nil, fmt.Errorf("opapolicychecker: config error: %w", err) } - evaluator, err := NewEvaluator(config.PolicyPaths, config.Query, config.RuntimeConfig, config.IsBundle) + enforcer := &PolicyEnforcer{ + config: config, + done: make(chan struct{}), + } + + if !config.Enabled { + log.Warnf(ctx, "OPAPolicyChecker is disabled via config; policy enforcement will be skipped") + return enforcer, nil + } + + evaluator, err := NewEvaluator(config.PolicyPaths, config.Query, config.RuntimeConfig, config.IsBundle, config.FetchTimeout) if err != nil { return nil, fmt.Errorf("opapolicychecker: failed to initialize OPA evaluator: %w", err) } + enforcer.evaluator = evaluator - log.Infof(ctx, "OPAPolicyChecker initialized (actions=%v, query=%s, policies=%v, isBundle=%v, debugLogging=%v, refreshInterval=%s)", - config.Actions, config.Query, evaluator.ModuleNames(), config.IsBundle, config.DebugLogging, config.RefreshInterval) - - enforcer := &PolicyEnforcer{ - config: config, - evaluator: evaluator, - } + log.Infof(ctx, "OPAPolicyChecker initialized (actions=%v, query=%s, policies=%v, isBundle=%v, debugLogging=%v, fetchTimeout=%s, refreshInterval=%s)", + config.Actions, config.Query, evaluator.ModuleNames(), config.IsBundle, config.DebugLogging, config.FetchTimeout, config.RefreshInterval) if config.RefreshInterval > 0 { go enforcer.refreshLoop(ctx) @@ -193,6 +212,9 @@ func (e *PolicyEnforcer) refreshLoop(ctx context.Context) { case <-ctx.Done(): log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped") return + case <-e.done: + log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped by Close()") + return case <-ticker.C: e.reloadPolicies(ctx) } @@ -208,6 +230,7 @@ func (e *PolicyEnforcer) reloadPolicies(ctx context.Context) { e.config.Query, e.config.RuntimeConfig, e.config.IsBundle, + e.config.FetchTimeout, ) if err != nil { log.Errorf(ctx, err, "OPAPolicyChecker: policy reload failed (keeping previous policies): %v", err) @@ -237,6 +260,9 @@ func (e *PolicyEnforcer) CheckPolicy(ctx *model.StepContext) error { } ev := e.getEvaluator() + if ev == nil { + return model.NewBadReqErr(fmt.Errorf("policy evaluator is not initialized")) + } if e.config.DebugLogging { log.Debugf(ctx, "OPAPolicyChecker: evaluating policies for action %q (modules=%v)", action, ev.ModuleNames()) @@ -260,12 +286,17 @@ func (e *PolicyEnforcer) CheckPolicy(ctx *model.StepContext) error { return model.NewBadReqErr(fmt.Errorf("%s", msg)) } -func (e *PolicyEnforcer) Close() {} +func (e *PolicyEnforcer) Close() { + e.closeOnce.Do(func() { + close(e.done) + }) +} func extractAction(urlPath string, body []byte) string { - parts := strings.Split(strings.Trim(urlPath, "/"), "/") - if len(parts) >= 3 { - return parts[len(parts)-1] + // /bpp/caller/confirm/extra as action "extra". + parts := strings.FieldsFunc(strings.Trim(urlPath, "/"), func(r rune) bool { return r == '/' }) + if len(parts) == 3 && isBecknDirection(parts[1]) && parts[2] != "" { + return parts[2] } var payload struct { @@ -279,3 +310,12 @@ func extractAction(urlPath string, body []byte) string { return "" } + +func isBecknDirection(part string) bool { + switch part { + case "caller", "receiver", "reciever": + return true + default: + return false + } +} diff --git a/pkg/plugin/implementation/opapolicychecker/enforcer_test.go b/pkg/plugin/implementation/opapolicychecker/enforcer_test.go index 2aba5a2..7d0e8a9 100644 --- a/pkg/plugin/implementation/opapolicychecker/enforcer_test.go +++ b/pkg/plugin/implementation/opapolicychecker/enforcer_test.go @@ -141,7 +141,7 @@ violations contains msg if { } ` dir := writePolicyDir(t, "test.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -165,7 +165,7 @@ violations contains msg if { } ` dir := writePolicyDir(t, "test.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -192,7 +192,7 @@ violations contains msg if { } ` dir := writePolicyDir(t, "test.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", map[string]string{"maxValue": "100"}, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", map[string]string{"maxValue": "100"}, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -235,7 +235,7 @@ test_something if { count(policy.violations) > 0 } ` os.WriteFile(filepath.Join(dir, "policy_test.rego"), []byte(testFile), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator should skip _test.rego files, but failed: %v", err) } @@ -256,7 +256,7 @@ import rego.v1 violations := set() ` dir := writePolicyDir(t, "test.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -285,7 +285,7 @@ violations contains msg if { })) defer srv.Close() - eval, err := NewEvaluator([]string{srv.URL + "/test_policy.rego"}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{srv.URL + "/test_policy.rego"}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator with URL failed: %v", err) } @@ -313,14 +313,14 @@ func TestEvaluator_FetchURL_NotFound(t *testing.T) { srv := httptest.NewServer(http.NotFoundHandler()) defer srv.Close() - _, err := NewEvaluator([]string{srv.URL + "/missing.rego"}, "data.policy.violations", nil, false) + _, err := NewEvaluator([]string{srv.URL + "/missing.rego"}, "data.policy.violations", nil, false, 0) if err == nil { t.Fatal("expected error for 404 URL") } } func TestEvaluator_FetchURL_InvalidScheme(t *testing.T) { - _, err := NewEvaluator([]string{"ftp://example.com/policy.rego"}, "data.policy.violations", nil, false) + _, err := NewEvaluator([]string{"ftp://example.com/policy.rego"}, "data.policy.violations", nil, false, 0) if err == nil { t.Fatal("expected error for ftp:// scheme") } @@ -346,7 +346,7 @@ violations contains "remote_violation" if { input.remote_bad } })) defer srv.Close() - eval, err := NewEvaluator([]string{dir, srv.URL + "/remote.rego"}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir, srv.URL + "/remote.rego"}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -373,7 +373,7 @@ violations contains "from_file" if { input.bad } policyPath := filepath.Join(dir, "local_policy.rego") os.WriteFile(policyPath, []byte(policy), 0644) - eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator with local path failed: %v", err) } @@ -412,7 +412,7 @@ violations contains "order too large" if { is_high_value } ` os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -462,7 +462,7 @@ violations contains "high value confirm blocked" if { ` os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -531,7 +531,7 @@ violations contains "safety: order value too high" if { ` os.WriteFile(filepath.Join(dir, "safety.rego"), []byte(safety), 0644) - eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -744,7 +744,7 @@ default result := { } ` dir := writePolicyDir(t, "policy.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -782,7 +782,7 @@ violations contains msg if { } ` dir := writePolicyDir(t, "policy.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -841,7 +841,7 @@ result := { } ` dir := writePolicyDir(t, "policy.rego", policy) - eval, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false) + eval, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false, 0) if err != nil { t.Fatalf("NewEvaluator failed: %v", err) } @@ -912,7 +912,7 @@ violations contains msg if { })) defer srv.Close() - eval, err := NewEvaluator([]string{srv.URL + "/bundle.tar.gz"}, "data.retail.validation.result", nil, true) + eval, err := NewEvaluator([]string{srv.URL + "/bundle.tar.gz"}, "data.retail.validation.result", nil, true, 0) if err != nil { t.Fatalf("NewEvaluator with bundle failed: %v", err) } @@ -1096,3 +1096,55 @@ default result := {"valid": true, "violations": []} t.Fatal("hot-reload did not take effect within 5 seconds") } +func TestParseConfig_FetchTimeout(t *testing.T) { + cfg, err := ParseConfig(map[string]string{ + "type": "url", + "location": "https://example.com/policy.rego", + "query": "data.policy.violations", + "fetchTimeoutSeconds": "7", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.FetchTimeout != 7*time.Second { + t.Fatalf("expected fetch timeout 7s, got %s", cfg.FetchTimeout) + } +} + +func TestEvaluator_FetchURL_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(50 * time.Millisecond) + w.Write([]byte(`package policy +import rego.v1 +violations := []`)) + })) + defer srv.Close() + + _, err := NewEvaluator([]string{srv.URL + "/slow.rego"}, "data.policy.violations", nil, false, 10*time.Millisecond) + if err == nil { + t.Fatal("expected timeout error for slow policy URL") + } +} + +func TestExtractAction_NonStandardURLFallsBackToBody(t *testing.T) { + body := []byte(`{"context": {"action": "confirm"}}`) + action := extractAction("/bpp/caller/confirm/extra", body) + if action != "confirm" { + t.Fatalf("expected body fallback action 'confirm', got %q", action) + } +} + +func TestEnforcer_DisabledSkipsEvaluatorInitialization(t *testing.T) { + enforcer, err := New(context.Background(), map[string]string{ + "type": "url", + "location": "https://127.0.0.1:1/unreachable.rego", + "query": "data.policy.violations", + "enabled": "false", + }) + if err != nil { + t.Fatalf("expected disabled enforcer to skip evaluator initialization, got %v", err) + } + if enforcer.getEvaluator() != nil { + t.Fatal("expected disabled enforcer to leave evaluator uninitialized") + } +} diff --git a/pkg/plugin/implementation/opapolicychecker/evaluator.go b/pkg/plugin/implementation/opapolicychecker/evaluator.go index aff3584..dc57f46 100644 --- a/pkg/plugin/implementation/opapolicychecker/evaluator.go +++ b/pkg/plugin/implementation/opapolicychecker/evaluator.go @@ -35,8 +35,9 @@ func (e *Evaluator) ModuleNames() []string { return e.moduleNames } -// policyFetchTimeout is the HTTP timeout for fetching remote .rego files. -const policyFetchTimeout = 30 * time.Second +// defaultPolicyFetchTimeout bounds remote policy and bundle fetches during startup +// and refresh. This can be overridden via config.fetchTimeoutSeconds. +const defaultPolicyFetchTimeout = 30 * time.Second // maxPolicySize is the maximum size of a single .rego file fetched from a URL (1 MB). const maxPolicySize = 1 << 20 @@ -47,21 +48,24 @@ const maxBundleSize = 10 << 20 // NewEvaluator creates an Evaluator by loading .rego files from local paths // and/or URLs, then compiling them. runtimeConfig is passed to Rego as data.config. // When isBundle is true, the first policyPath is treated as a URL to an OPA bundle (.tar.gz). -func NewEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, isBundle bool) (*Evaluator, error) { - if isBundle { - return newBundleEvaluator(policyPaths, query, runtimeConfig) +func NewEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, isBundle bool, fetchTimeout time.Duration) (*Evaluator, error) { + if fetchTimeout <= 0 { + fetchTimeout = defaultPolicyFetchTimeout } - return newRegoEvaluator(policyPaths, query, runtimeConfig) + if isBundle { + return newBundleEvaluator(policyPaths, query, runtimeConfig, fetchTimeout) + } + return newRegoEvaluator(policyPaths, query, runtimeConfig, fetchTimeout) } // newRegoEvaluator loads raw .rego files from local paths and/or URLs. -func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[string]string) (*Evaluator, error) { +func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, fetchTimeout time.Duration) (*Evaluator, error) { modules := make(map[string]string) // Load from policyPaths (resolved locations based on config Type) for _, source := range policyPaths { if isURL(source) { - name, content, err := fetchPolicy(source) + name, content, err := fetchPolicy(source, fetchTimeout) if err != nil { return nil, fmt.Errorf("failed to fetch policy from %s: %w", source, err) } @@ -101,13 +105,13 @@ func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[stri } // newBundleEvaluator loads an OPA bundle (.tar.gz) from a URL and compiles it. -func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[string]string) (*Evaluator, error) { +func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, fetchTimeout time.Duration) (*Evaluator, error) { if len(policyPaths) == 0 { return nil, fmt.Errorf("bundle source URL is required") } bundleURL := policyPaths[0] - modules, bundleData, err := loadBundle(bundleURL) + modules, bundleData, err := loadBundle(bundleURL, fetchTimeout) if err != nil { return nil, fmt.Errorf("failed to load bundle from %s: %w", bundleURL, err) } @@ -121,8 +125,8 @@ func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[st // loadBundle downloads a .tar.gz OPA bundle from a URL, parses it using OPA's // bundle reader, and returns the modules and data from the bundle. -func loadBundle(bundleURL string) (map[string]string, map[string]interface{}, error) { - data, err := fetchBundleArchive(bundleURL) +func loadBundle(bundleURL string, fetchTimeout time.Duration) (map[string]string, map[string]interface{}, error) { + data, err := fetchBundleArchive(bundleURL, fetchTimeout) if err != nil { return nil, nil, err } @@ -131,7 +135,7 @@ func loadBundle(bundleURL string) (map[string]string, map[string]interface{}, er } // fetchBundleArchive downloads a bundle .tar.gz from a URL. -func fetchBundleArchive(rawURL string) ([]byte, error) { +func fetchBundleArchive(rawURL string, fetchTimeout time.Duration) ([]byte, error) { parsed, err := url.Parse(rawURL) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) @@ -141,7 +145,7 @@ func fetchBundleArchive(rawURL string) ([]byte, error) { return nil, fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme) } - client := &http.Client{Timeout: policyFetchTimeout} + client := &http.Client{Timeout: fetchTimeout} resp, err := client.Get(rawURL) if err != nil { return nil, fmt.Errorf("HTTP request failed: %w", err) @@ -229,7 +233,7 @@ func isURL(source string) bool { } // fetchPolicy downloads a .rego file from a URL and returns (filename, content, error). -func fetchPolicy(rawURL string) (string, string, error) { +func fetchPolicy(rawURL string, fetchTimeout time.Duration) (string, string, error) { parsed, err := url.Parse(rawURL) if err != nil { return "", "", fmt.Errorf("invalid URL: %w", err) @@ -239,7 +243,7 @@ func fetchPolicy(rawURL string) (string, string, error) { return "", "", fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme) } - client := &http.Client{Timeout: policyFetchTimeout} + client := &http.Client{Timeout: fetchTimeout} resp, err := client.Get(rawURL) if err != nil { return "", "", fmt.Errorf("HTTP request failed: %w", err)