Files
onix/pkg/plugin/implementation/opapolicychecker/enforcer_test.go

1177 lines
32 KiB
Go

package opapolicychecker
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/open-policy-agent/opa/v1/bundle"
"github.com/beckn-one/beckn-onix/pkg/model"
)
// makeStepCtx builds a *model.StepContext for tests. The request is a POST to
// "/bpp/caller/<action>" with no request body; the given JSON body is stored
// verbatim in the StepContext's Body field.
//
// It panics when the request cannot be constructed (for example, when action
// makes the URL unparsable) so that a broken fixture fails loudly at the call
// site instead of producing a StepContext whose Request is nil.
func makeStepCtx(action string, body string) *model.StepContext {
	req, err := http.NewRequest(http.MethodPost, "/bpp/caller/"+action, nil)
	if err != nil {
		panic("makeStepCtx: building request for action " + action + ": " + err.Error())
	}
	return &model.StepContext{
		Context: context.Background(),
		Request: req,
		Body:    []byte(body),
	}
}
// writePolicyDir creates a fresh temporary directory, stores content in it
// under filename, and returns the directory path. A write failure aborts the
// calling test.
func writePolicyDir(t *testing.T, filename, content string) string {
	t.Helper()
	root := t.TempDir()
	target := filepath.Join(root, filename)
	if writeErr := os.WriteFile(target, []byte(content), 0644); writeErr != nil {
		t.Fatalf("failed to write policy file: %v", writeErr)
	}
	return root
}
// --- Config Tests ---
// TestParseConfig_RequiresPolicySource checks that ParseConfig returns an
// error when handed an empty configuration map.
func TestParseConfig_RequiresPolicySource(t *testing.T) {
	if _, parseErr := ParseConfig(map[string]string{}); parseErr == nil {
		t.Fatal("expected error when no policy source given")
	}
}
// TestParseConfig_Defaults verifies that a minimal dir configuration parses
// with no explicit actions and with the plugin enabled.
func TestParseConfig_Defaults(t *testing.T) {
	minimal := map[string]string{
		"type":     "dir",
		"location": "/tmp",
		"query":    "data.policy.violations",
	}
	parsed, parseErr := ParseConfig(minimal)
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if len(parsed.Actions) != 0 {
		t.Errorf("expected empty default actions (all enabled), got %v", parsed.Actions)
	}
	if !parsed.Enabled {
		t.Error("expected enabled=true by default")
	}
}
// TestParseConfig_RequiresQuery checks that ParseConfig rejects a
// configuration that has a type and location but no query.
func TestParseConfig_RequiresQuery(t *testing.T) {
	withoutQuery := map[string]string{
		"type":     "dir",
		"location": "/tmp",
	}
	if _, parseErr := ParseConfig(withoutQuery); parseErr == nil {
		t.Fatal("expected error when no query given")
	}
}
// TestParseConfig_RuntimeConfigForwarding verifies that the extra keys
// "minDeliveryLeadHours" and "customParam" show up in cfg.RuntimeConfig with
// their original values.
func TestParseConfig_RuntimeConfigForwarding(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":                 "dir",
		"location":             "/tmp",
		"query":                "data.policy.violations",
		"minDeliveryLeadHours": "6",
		"customParam":          "value",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	forwarded := []struct{ key, want string }{
		{"minDeliveryLeadHours", "6"},
		{"customParam", "value"},
	}
	for _, kv := range forwarded {
		if parsed.RuntimeConfig[kv.key] != kv.want {
			t.Errorf("expected %s=%s, got %q", kv.key, kv.want, parsed.RuntimeConfig[kv.key])
		}
	}
}
// TestParseConfig_CustomActions verifies that the comma-separated "actions"
// value "confirm, select, init" is parsed into three entries, in order, with
// the spaces after the commas removed.
func TestParseConfig_CustomActions(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":     "dir",
		"location": "/tmp",
		"query":    "data.policy.violations",
		"actions":  "confirm, select, init",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	wantActions := []string{"confirm", "select", "init"}
	if got := len(parsed.Actions); got != 3 {
		t.Fatalf("expected 3 actions, got %d: %v", got, parsed.Actions)
	}
	for idx := range wantActions {
		if parsed.Actions[idx] != wantActions[idx] {
			t.Errorf("action[%d] = %q, want %q", idx, parsed.Actions[idx], wantActions[idx])
		}
	}
}
// TestParseConfig_PolicyPaths verifies that a comma-separated "location"
// value yields two policy paths and that the first one is the first URL.
func TestParseConfig_PolicyPaths(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":     "url",
		"location": "https://example.com/a.rego, https://example.com/b.rego",
		"query":    "data.policy.violations",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if count := len(parsed.PolicyPaths); count != 2 {
		t.Fatalf("expected 2 paths, got %d: %v", count, parsed.PolicyPaths)
	}
	if first := parsed.PolicyPaths[0]; first != "https://example.com/a.rego" {
		t.Errorf("path[0] = %q", first)
	}
}
// --- Evaluator Tests (with inline policies) ---
// TestEvaluator_NoViolations loads a policy that flags negative values and
// checks that a positive input produces no violations.
func TestEvaluator_NoViolations(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains msg if {
input.value < 0
msg := "value is negative"
}
`
	policyDir := writePolicyDir(t, "test.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.violations", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{"value": 10}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations, got %d: %v", len(got), got)
	}
}
// TestEvaluator_WithViolation loads a policy that flags negative values and
// checks that a negative input produces exactly the expected message.
func TestEvaluator_WithViolation(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains msg if {
input.value < 0
msg := "value is negative"
}
`
	policyDir := writePolicyDir(t, "test.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.violations", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{"value": -5}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1 violation, got %d: %v", len(got), got)
	}
	if first := got[0]; first != "value is negative" {
		t.Errorf("unexpected violation: %q", first)
	}
}
// TestEvaluator_RuntimeConfig passes {"maxValue": "100"} as runtime config to
// NewEvaluator and checks that the policy, which reads data.config.maxValue,
// accepts value=50 and rejects value=150.
func TestEvaluator_RuntimeConfig(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains msg if {
input.value > to_number(data.config.maxValue)
msg := "value exceeds maximum"
}
`
	policyDir := writePolicyDir(t, "test.rego", policy)
	runtimeCfg := map[string]string{"maxValue": "100"}
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.violations", runtimeCfg, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}

	// Under limit
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{"value": 50}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations for value=50, got %v", got)
	}

	// Over limit
	got, evalErr = evaluator.Evaluate(context.Background(), []byte(`{"value": 150}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 {
		t.Errorf("expected 1 violation for value=150, got %v", got)
	}
}
// TestEvaluator_SkipsTestFiles places a regular policy and a *_test.rego
// module in the same directory and expects NewEvaluator to succeed and the
// regular policy's single violation to be reported.
func TestEvaluator_SkipsTestFiles(t *testing.T) {
	dir := t.TempDir()
	policy := `
package policy
import rego.v1
violations contains "always" if { true }
`
	// Rego unit-test module (different package) that the evaluator is
	// expected to skip when loading the directory.
	testFile := `
package policy_test
import rego.v1
import data.policy
test_something if { count(policy.violations) > 0 }
`
	files := []struct{ name, content string }{
		{"policy.rego", policy},
		{"policy_test.rego", testFile},
	}
	for _, f := range files {
		// A failed fixture write must abort the test; otherwise later
		// assertions would fail for a misleading reason.
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.content), 0644); err != nil {
			t.Fatalf("failed to write %s: %v", f.name, err)
		}
	}
	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator should skip _test.rego files, but failed: %v", err)
	}
	violations, err := eval.Evaluate(context.Background(), []byte(`{}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 {
		t.Errorf("expected 1 violation, got %d", len(violations))
	}
}
// TestEvaluator_InvalidJSON checks that Evaluate returns an error when the
// request body is not valid JSON.
func TestEvaluator_InvalidJSON(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations := set()
`
	policyDir := writePolicyDir(t, "test.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.violations", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	if _, evalErr := evaluator.Evaluate(context.Background(), []byte(`not json`)); evalErr == nil {
		t.Error("expected error for invalid JSON")
	}
}
// --- Evaluator URL Fetch Tests ---
// TestEvaluator_FetchFromURL serves a policy from a local httptest server and
// checks that an evaluator built from that URL accepts a positive value and
// rejects a negative one.
func TestEvaluator_FetchFromURL(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains msg if {
input.value < 0
msg := "value is negative"
}
`
	// Serve the policy via a local HTTP server
	servePolicy := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte(policy))
	}
	server := httptest.NewServer(http.HandlerFunc(servePolicy))
	defer server.Close()

	sources := []string{server.URL + "/test_policy.rego"}
	evaluator, newErr := NewEvaluator(sources, "data.policy.violations", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator with URL failed: %v", newErr)
	}

	// Compliant
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{"value": 10}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations, got %v", got)
	}

	// Non-compliant
	got, evalErr = evaluator.Evaluate(context.Background(), []byte(`{"value": -1}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 {
		t.Errorf("expected 1 violation, got %v", got)
	}
}
// TestEvaluator_FetchURL_NotFound checks that NewEvaluator fails when the
// policy URL is answered with a 404.
func TestEvaluator_FetchURL_NotFound(t *testing.T) {
	server := httptest.NewServer(http.NotFoundHandler())
	defer server.Close()

	missing := []string{server.URL + "/missing.rego"}
	if _, newErr := NewEvaluator(missing, "data.policy.violations", nil, false, 0); newErr == nil {
		t.Fatal("expected error for 404 URL")
	}
}
// TestEvaluator_FetchURL_InvalidScheme checks that NewEvaluator rejects a
// policy source that uses the ftp:// scheme.
func TestEvaluator_FetchURL_InvalidScheme(t *testing.T) {
	ftpSource := []string{"ftp://example.com/policy.rego"}
	if _, newErr := NewEvaluator(ftpSource, "data.policy.violations", nil, false, 0); newErr == nil {
		t.Fatal("expected error for ftp:// scheme")
	}
}
// TestEvaluator_MixedLocalAndURL builds an evaluator from one local directory
// and one URL, both contributing rules to the same package, and expects an
// input that trips both rules to yield two violations.
func TestEvaluator_MixedLocalAndURL(t *testing.T) {
	const (
		// Local policy
		localPolicy = `
package policy
import rego.v1
violations contains "local_violation" if { input.local_bad }
`
		// Remote policy (different rule, same package)
		remotePolicy = `
package policy
import rego.v1
violations contains "remote_violation" if { input.remote_bad }
`
	)
	localDir := writePolicyDir(t, "local.rego", localPolicy)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(remotePolicy))
	}))
	defer server.Close()

	sources := []string{localDir, server.URL + "/remote.rego"}
	evaluator, newErr := NewEvaluator(sources, "data.policy.violations", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}

	// Trigger both violations
	payload := []byte(`{"local_bad": true, "remote_bad": true}`)
	got, evalErr := evaluator.Evaluate(context.Background(), payload)
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 2 {
		t.Errorf("expected 2 violations (local+remote), got %d: %v", len(got), got)
	}
}
// --- Evaluator with local file path in policySources ---
// TestEvaluator_LocalFilePath passes the path of a single .rego file (rather
// than a directory) as the policy source and checks that its rule fires.
func TestEvaluator_LocalFilePath(t *testing.T) {
	policy := `
package policy
import rego.v1
violations contains "from_file" if { input.bad }
`
	// writePolicyDir checks the write error, so a failed fixture write aborts
	// the test instead of surfacing later as a confusing evaluator error.
	dir := writePolicyDir(t, "local_policy.rego", policy)
	policyPath := filepath.Join(dir, "local_policy.rego")
	eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator with local path failed: %v", err)
	}
	violations, err := eval.Evaluate(context.Background(), []byte(`{"bad": true}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "from_file" {
		t.Errorf("expected [from_file], got %v", violations)
	}
}
// --- Rego Modularity Tests ---
// These tests prove that rego files can reference each other, supporting
// modular policy design to avoid rule bloat.
// TestEvaluator_CrossFileModularity verifies that multiple .rego files
// in the SAME package automatically share rules and data: rules.rego uses a
// helper rule that is defined only in helpers.rego.
func TestEvaluator_CrossFileModularity(t *testing.T) {
	dir := t.TempDir()
	// File 1: defines a helper rule
	helpers := `
package policy
import rego.v1
is_high_value if { input.message.order.value > 10000 }
`
	// File 2: uses the helper from file 1 (same package, auto-merged)
	rules := `
package policy
import rego.v1
violations contains "order too large" if { is_high_value }
`
	files := []struct{ name, content string }{
		{"helpers.rego", helpers},
		{"rules.rego", rules},
	}
	for _, f := range files {
		// Abort on a failed fixture write rather than letting the evaluator
		// fail later for an unrelated-looking reason.
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.content), 0644); err != nil {
			t.Fatalf("failed to write %s: %v", f.name, err)
		}
	}
	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}
	// High value order — should trigger
	violations, err := eval.Evaluate(context.Background(), []byte(`{"message":{"order":{"value":15000}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "order too large" {
		t.Errorf("expected [order too large], got %v", violations)
	}
	// Low value order — should not trigger
	violations, err = eval.Evaluate(context.Background(), []byte(`{"message":{"order":{"value":500}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 0 {
		t.Errorf("expected 0 violations, got %v", violations)
	}
}
// TestEvaluator_CrossPackageImport verifies that rego files in DIFFERENT
// packages can import each other using `import data.<package>`: the policy
// package calls helpers that live in the utils package.
func TestEvaluator_CrossPackageImport(t *testing.T) {
	dir := t.TempDir()
	// File 1: utility package with reusable helpers
	utils := `
package utils
import rego.v1
is_confirm if { input.context.action == "confirm" }
is_high_value if { input.message.order.value > 10000 }
`
	// File 2: policy package imports from utils package
	rules := `
package policy
import rego.v1
import data.utils
violations contains "high value confirm blocked" if {
utils.is_confirm
utils.is_high_value
}
`
	files := []struct{ name, content string }{
		{"utils.rego", utils},
		{"rules.rego", rules},
	}
	for _, f := range files {
		// Abort on a failed fixture write rather than letting the evaluator
		// fail later for an unrelated-looking reason.
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.content), 0644); err != nil {
			t.Fatalf("failed to write %s: %v", f.name, err)
		}
	}
	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}
	// confirm + high value — should fire
	violations, err := eval.Evaluate(context.Background(), []byte(`{
"context": {"action": "confirm"},
"message": {"order": {"value": 50000}}
}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 {
		t.Errorf("expected 1 violation, got %v", violations)
	}
	// search action — should NOT fire (action filter in rego)
	violations, err = eval.Evaluate(context.Background(), []byte(`{
"context": {"action": "search"},
"message": {"order": {"value": 50000}}
}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 0 {
		t.Errorf("expected 0 violations for search action, got %v", violations)
	}
}
// TestEvaluator_MultiFileOrganization demonstrates a modular layout where
// policies are organized by concern across separate files that all work
// together: a shared helpers package plus compliance and safety rule files
// that both contribute to the same violations set.
func TestEvaluator_MultiFileOrganization(t *testing.T) {
	dir := t.TempDir()
	// Shared helpers
	helpers := `
package helpers
import rego.v1
action_is(a) if { input.context.action == a }
value_exceeds(limit) if { input.message.order.value > limit }
`
	// compliance.rego — compliance rules
	compliance := `
package policy
import rego.v1
import data.helpers
violations contains "compliance: missing provider" if {
helpers.action_is("confirm")
not input.message.order.provider
}
`
	// safety.rego — safety rules
	safety := `
package policy
import rego.v1
import data.helpers
violations contains "safety: order value too high" if {
helpers.action_is("confirm")
helpers.value_exceeds(100000)
}
`
	files := []struct{ name, content string }{
		{"helpers.rego", helpers},
		{"compliance.rego", compliance},
		{"safety.rego", safety},
	}
	for _, f := range files {
		// Abort on a failed fixture write rather than letting the evaluator
		// fail later for an unrelated-looking reason.
		if err := os.WriteFile(filepath.Join(dir, f.name), []byte(f.content), 0644); err != nil {
			t.Fatalf("failed to write %s: %v", f.name, err)
		}
	}
	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}
	// Input that triggers BOTH violations
	violations, err := eval.Evaluate(context.Background(), []byte(`{
"context": {"action": "confirm"},
"message": {"order": {"value": 999999}}
}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 2 {
		t.Errorf("expected 2 violations (compliance+safety), got %d: %v", len(violations), violations)
	}
}
// --- Enforcer Integration Tests ---
// TestEnforcer_Compliant checks that CheckPolicy returns nil for a confirm
// message whose "block" field is false.
func TestEnforcer_Compliant(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm"; input.block }
`
	settings := map[string]string{
		"type":     "dir",
		"location": writePolicyDir(t, "test.rego", policy),
		"query":    "data.policy.violations",
		"actions":  "confirm",
	}
	enforcer, newErr := New(context.Background(), settings)
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}
	stepCtx := makeStepCtx("confirm", `{"context": {"action": "confirm"}, "block": false}`)
	if checkErr := enforcer.CheckPolicy(stepCtx); checkErr != nil {
		t.Errorf("expected nil error for compliant message, got: %v", checkErr)
	}
}
// TestEnforcer_NonCompliant checks that CheckPolicy rejects a confirm message
// under a policy that blocks every confirm, and that the returned error is a
// *model.BadReqErr.
func TestEnforcer_NonCompliant(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm" }
`
	settings := map[string]string{
		"type":     "dir",
		"location": writePolicyDir(t, "test.rego", policy),
		"query":    "data.policy.violations",
		"actions":  "confirm",
	}
	enforcer, newErr := New(context.Background(), settings)
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}
	stepCtx := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	checkErr := enforcer.CheckPolicy(stepCtx)
	if checkErr == nil {
		t.Fatal("expected error for non-compliant message, got nil")
	}
	// Should be a BadReqErr
	switch checkErr.(type) {
	case *model.BadReqErr:
		// expected concrete type
	default:
		t.Errorf("expected *model.BadReqErr, got %T: %v", checkErr, checkErr)
	}
}
// TestEnforcer_SkipsNonMatchingAction configures the enforcer for "confirm"
// only and checks that a "search" request passes even though the policy
// would flag any input.
func TestEnforcer_SkipsNonMatchingAction(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains "blocked" if { true }
`
	settings := map[string]string{
		"type":     "dir",
		"location": writePolicyDir(t, "test.rego", policy),
		"query":    "data.policy.violations",
		"actions":  "confirm",
	}
	enforcer, newErr := New(context.Background(), settings)
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}
	// Non-compliant body, but action is "search" — not in configured actions
	stepCtx := makeStepCtx("search", `{"context": {"action": "search"}}`)
	if checkErr := enforcer.CheckPolicy(stepCtx); checkErr != nil {
		t.Errorf("expected nil for non-matching action, got: %v", checkErr)
	}
}
// TestEnforcer_DisabledPlugin checks that with "enabled" set to "false",
// CheckPolicy returns nil even though the policy would flag any input.
func TestEnforcer_DisabledPlugin(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains "blocked" if { true }
`
	settings := map[string]string{
		"type":     "dir",
		"location": writePolicyDir(t, "test.rego", policy),
		"query":    "data.policy.violations",
		"enabled":  "false",
	}
	enforcer, newErr := New(context.Background(), settings)
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}
	stepCtx := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	if checkErr := enforcer.CheckPolicy(stepCtx); checkErr != nil {
		t.Errorf("expected nil for disabled plugin, got: %v", checkErr)
	}
}
// --- Enforcer with URL-sourced policy ---
// TestEnforcer_PolicyFromURL serves a confirm-blocking policy from a local
// httptest server, configures the enforcer with type "url", and expects
// CheckPolicy to return a *model.BadReqErr for a confirm message.
func TestEnforcer_PolicyFromURL(t *testing.T) {
	const policy = `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm" }
`
	servePolicy := func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(policy))
	}
	server := httptest.NewServer(http.HandlerFunc(servePolicy))
	defer server.Close()

	enforcer, newErr := New(context.Background(), map[string]string{
		"type":     "url",
		"location": server.URL + "/block_confirm.rego",
		"query":    "data.policy.violations",
		"actions":  "confirm",
	})
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}
	stepCtx := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	checkErr := enforcer.CheckPolicy(stepCtx)
	if checkErr == nil {
		t.Fatal("expected error from URL-sourced policy, got nil")
	}
	if _, isBadReq := checkErr.(*model.BadReqErr); !isBadReq {
		t.Errorf("expected *model.BadReqErr, got %T", checkErr)
	}
}
// --- extractAction Tests ---
// TestExtractAction_FromURL checks that extractAction returns "confirm" for
// the path "/bpp/caller/confirm" when no body is supplied.
func TestExtractAction_FromURL(t *testing.T) {
	if got := extractAction("/bpp/caller/confirm", nil); got != "confirm" {
		t.Errorf("expected 'confirm', got %q", got)
	}
}
// TestExtractAction_FromBody checks that extractAction returns the body's
// context.action ("select") when called with the path "/x".
func TestExtractAction_FromBody(t *testing.T) {
	payload := []byte(`{"context": {"action": "select"}}`)
	if got := extractAction("/x", payload); got != "select" {
		t.Errorf("expected 'select', got %q", got)
	}
}
// --- Config Tests: Bundle Type ---
// TestParseConfig_BundleType verifies that type "bundle" sets IsBundle and
// that the location and query are carried through unchanged.
func TestParseConfig_BundleType(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":     "bundle",
		"location": "https://example.com/bundle.tar.gz",
		"query":    "data.retail.validation.result",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if !parsed.IsBundle {
		t.Error("expected IsBundle=true for type=bundle")
	}
	paths := parsed.PolicyPaths
	if len(paths) != 1 || paths[0] != "https://example.com/bundle.tar.gz" {
		t.Errorf("expected 1 policy path, got %v", paths)
	}
	if query := parsed.Query; query != "data.retail.validation.result" {
		t.Errorf("expected query 'data.retail.validation.result', got %q", query)
	}
}
// --- Structured Result Format Tests ---
// TestEvaluator_StructuredResult_Valid queries a rule whose value is the
// object {"valid": true, "violations": []} and expects Evaluate to report no
// violations.
func TestEvaluator_StructuredResult_Valid(t *testing.T) {
	// Policy returns {"valid": true, "violations": []} — no violations expected
	const policy = `
package retail.policy
import rego.v1
default result := {
"valid": true,
"violations": []
}
`
	policyDir := writePolicyDir(t, "policy.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.retail.policy.result", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	payload := []byte(`{"message": {"order": {"items": []}}}`)
	got, evalErr := evaluator.Evaluate(context.Background(), payload)
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations for valid result, got %v", got)
	}
}
// TestEvaluator_StructuredResult_WithViolations queries a structured
// {"valid", "violations"} result and checks that an item with count 0 yields
// one specific message while an all-positive order yields none.
func TestEvaluator_StructuredResult_WithViolations(t *testing.T) {
	// Policy returns {"valid": false, "violations": ["msg1", "msg2"]} when items have count <= 0
	const policy = `
package retail.policy
import rego.v1
default result := {
"valid": true,
"violations": []
}
result := {
"valid": count(violations) == 0,
"violations": violations
}
violations contains msg if {
some item in input.message.order.items
item.quantity.count <= 0
msg := sprintf("item %s: quantity must be > 0", [item.id])
}
`
	const (
		// Non-compliant input
		badOrder = `{
"message": {
"order": {
"items": [
{"id": "item1", "quantity": {"count": 0}},
{"id": "item2", "quantity": {"count": 5}}
]
}
}
}`
		// Compliant input
		goodOrder = `{
"message": {
"order": {
"items": [
{"id": "item1", "quantity": {"count": 3}}
]
}
}
}`
	)
	policyDir := writePolicyDir(t, "policy.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.retail.policy.result", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}

	got, evalErr := evaluator.Evaluate(context.Background(), []byte(badOrder))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1 violation, got %d: %v", len(got), got)
	}
	if first := got[0]; first != "item item1: quantity must be > 0" {
		t.Errorf("unexpected violation: %q", first)
	}

	got, evalErr = evaluator.Evaluate(context.Background(), []byte(goodOrder))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations for compliant input, got %v", got)
	}
}
// TestEvaluator_StructuredResult_FalseNoViolations covers a structured result
// with valid=false and an empty violations list; the evaluator is expected to
// report the single generic message "policy denied the request".
func TestEvaluator_StructuredResult_FalseNoViolations(t *testing.T) {
	// Edge case: valid=false but violations is empty — should report generic denial
	const policy = `
package policy
import rego.v1
result := {
"valid": false,
"violations": []
}
`
	policyDir := writePolicyDir(t, "policy.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.result", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 || got[0] != "policy denied the request" {
		t.Errorf("expected ['policy denied the request'], got %v", got)
	}
}
// TestEvaluator_NonStructuredMapResult_Ignored queries a rule whose value is
// an object without "valid"/"violations" keys and expects no violations.
func TestEvaluator_NonStructuredMapResult_Ignored(t *testing.T) {
	const policy = `
package policy
import rego.v1
result := {
"action": "confirm",
"status": "ok"
}
`
	policyDir := writePolicyDir(t, "policy.rego", policy)
	evaluator, newErr := NewEvaluator([]string{policyDir}, "data.policy.result", nil, false, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator failed: %v", newErr)
	}
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(`{}`))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Fatalf("expected non-structured map result to be ignored, got %v", got)
	}
}
// --- Bundle Tests ---
// buildTestBundle serializes the given modules (path -> rego source) into an
// in-memory OPA bundle using bundle.Write and returns the resulting bytes.
// Each module's URL and Path are both set to its map key. A write failure
// aborts the calling test.
func buildTestBundle(t *testing.T, modules map[string]string) []byte {
	t.Helper()
	files := make([]bundle.ModuleFile, 0, len(modules))
	for modPath, source := range modules {
		files = append(files, bundle.ModuleFile{
			URL:  modPath,
			Path: modPath,
			Raw:  []byte(source),
		})
	}
	pkg := bundle.Bundle{
		Modules: files,
		Data:    make(map[string]interface{}),
	}
	archive := new(bytes.Buffer)
	if writeErr := bundle.Write(archive, pkg); writeErr != nil {
		t.Fatalf("failed to write test bundle: %v", writeErr)
	}
	return archive.Bytes()
}
// TestEvaluator_BundleFromURL serves a bundle built by buildTestBundle from a
// local httptest server, creates an evaluator in bundle mode, and checks one
// non-compliant and one compliant order.
func TestEvaluator_BundleFromURL(t *testing.T) {
	const policy = `
package retail.validation
import rego.v1
default result := {
"valid": true,
"violations": []
}
result := {
"valid": count(violations) == 0,
"violations": violations
}
violations contains msg if {
some item in input.message.order.items
item.quantity.count <= 0
msg := sprintf("item %s: quantity must be > 0", [item.id])
}
`
	archive := buildTestBundle(t, map[string]string{
		"retail/validation.rego": policy,
	})
	serveBundle := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/gzip")
		w.Write(archive)
	}
	server := httptest.NewServer(http.HandlerFunc(serveBundle))
	defer server.Close()

	sources := []string{server.URL + "/bundle.tar.gz"}
	evaluator, newErr := NewEvaluator(sources, "data.retail.validation.result", nil, true, 0)
	if newErr != nil {
		t.Fatalf("NewEvaluator with bundle failed: %v", newErr)
	}

	// Non-compliant
	payload := `{"message":{"order":{"items":[{"id":"x","quantity":{"count":0}}]}}}`
	got, evalErr := evaluator.Evaluate(context.Background(), []byte(payload))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 1 {
		t.Fatalf("expected 1 violation, got %d: %v", len(got), got)
	}

	// Compliant
	payload = `{"message":{"order":{"items":[{"id":"x","quantity":{"count":5}}]}}}`
	got, evalErr = evaluator.Evaluate(context.Background(), []byte(payload))
	if evalErr != nil {
		t.Fatalf("Evaluate failed: %v", evalErr)
	}
	if len(got) != 0 {
		t.Errorf("expected 0 violations, got %v", got)
	}
}
// TestEnforcer_BundlePolicy configures the enforcer with type "bundle"
// against a local httptest server and checks that a confirm without a
// provider is rejected with *model.BadReqErr while one with a provider passes.
func TestEnforcer_BundlePolicy(t *testing.T) {
	const policy = `
package retail.policy
import rego.v1
default result := {
"valid": true,
"violations": []
}
result := {
"valid": count(violations) == 0,
"violations": violations
}
violations contains "blocked" if {
input.context.action == "confirm"
not input.message.order.provider
}
`
	archive := buildTestBundle(t, map[string]string{
		"retail/policy.rego": policy,
	})
	serveBundle := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/gzip")
		w.Write(archive)
	}
	server := httptest.NewServer(http.HandlerFunc(serveBundle))
	defer server.Close()

	enforcer, newErr := New(context.Background(), map[string]string{
		"type":     "bundle",
		"location": server.URL + "/policy-bundle.tar.gz",
		"query":    "data.retail.policy.result",
		"actions":  "confirm",
	})
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}

	// Non-compliant: confirm without provider
	rejected := makeStepCtx("confirm", `{"context": {"action": "confirm"}, "message": {"order": {}}}`)
	checkErr := enforcer.CheckPolicy(rejected)
	if checkErr == nil {
		t.Fatal("expected error for non-compliant message, got nil")
	}
	if _, isBadReq := checkErr.(*model.BadReqErr); !isBadReq {
		t.Errorf("expected *model.BadReqErr, got %T: %v", checkErr, checkErr)
	}

	// Compliant: confirm with provider
	accepted := makeStepCtx("confirm", `{"context": {"action": "confirm"}, "message": {"order": {"provider": {"id": "p1"}}}}`)
	if checkErr = enforcer.CheckPolicy(accepted); checkErr != nil {
		t.Errorf("expected nil error for compliant message, got: %v", checkErr)
	}
}
// TestParseConfig_RefreshInterval verifies that refreshIntervalSeconds="300"
// is parsed into a RefreshInterval of 300 seconds.
func TestParseConfig_RefreshInterval(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":                   "dir",
		"location":               "/tmp",
		"query":                  "data.policy.violations",
		"refreshIntervalSeconds": "300",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if got := parsed.RefreshInterval; got != 300*time.Second {
		t.Errorf("expected 300s refresh interval, got %v", got)
	}
}
// TestParseConfig_RefreshInterval_Zero verifies that RefreshInterval is zero
// when refreshIntervalSeconds is absent from the configuration.
func TestParseConfig_RefreshInterval_Zero(t *testing.T) {
	// no refreshIntervalSeconds → disabled
	withoutRefresh := map[string]string{
		"type":     "dir",
		"location": "/tmp",
		"query":    "data.policy.violations",
	}
	parsed, parseErr := ParseConfig(withoutRefresh)
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if got := parsed.RefreshInterval; got != 0 {
		t.Errorf("expected refresh disabled (0), got %v", got)
	}
}
// TestParseConfig_RefreshInterval_Invalid checks that ParseConfig returns an
// error for a non-numeric refreshIntervalSeconds value.
func TestParseConfig_RefreshInterval_Invalid(t *testing.T) {
	badInterval := map[string]string{
		"type":                   "dir",
		"location":               "/tmp",
		"query":                  "data.policy.violations",
		"refreshIntervalSeconds": "not-a-number",
	}
	if _, parseErr := ParseConfig(badInterval); parseErr == nil {
		t.Fatal("expected error for invalid refreshIntervalSeconds")
	}
}
// TestEnforcer_HotReload starts an enforcer with a 1-second refresh interval
// on a blocking policy, rewrites the policy file on disk to an allowing one,
// and polls CheckPolicy for up to 5 seconds until the request is accepted.
func TestEnforcer_HotReload(t *testing.T) {
	const (
		// Initial policy: always blocks confirm
		blockPolicy = `package policy
import rego.v1
default result := {"valid": false, "violations": ["blocked by initial policy"]}
result := {"valid": false, "violations": ["blocked by initial policy"]}
`
		// Replacement policy: allows everything
		allowPolicy = `package policy
import rego.v1
default result := {"valid": true, "violations": []}
`
	)
	policyDir := t.TempDir()
	policyFile := filepath.Join(policyDir, "policy.rego")
	if writeErr := os.WriteFile(policyFile, []byte(blockPolicy), 0644); writeErr != nil {
		t.Fatalf("failed to write initial policy: %v", writeErr)
	}

	runCtx, stop := context.WithTimeout(context.Background(), 10*time.Second)
	defer stop()
	enforcer, newErr := New(runCtx, map[string]string{
		"type":                   "dir",
		"location":               policyDir,
		"query":                  "data.policy.result",
		"refreshIntervalSeconds": "1", // 1s refresh for test speed
	})
	if newErr != nil {
		t.Fatalf("New failed: %v", newErr)
	}

	// Confirm is blocked with initial policy
	stepCtx := makeStepCtx("confirm", `{"context":{"action":"confirm"}}`)
	if enforcer.CheckPolicy(stepCtx) == nil {
		t.Fatal("expected block from initial policy, got nil")
	}

	// Swap policy on disk to allow everything
	if writeErr := os.WriteFile(policyFile, []byte(allowPolicy), 0644); writeErr != nil {
		t.Fatalf("failed to write updated policy: %v", writeErr)
	}

	// Wait up to 5s for the reload to fire and swap the evaluator
	giveUpAt := time.Now().Add(5 * time.Second)
	for time.Now().Before(giveUpAt) {
		if enforcer.CheckPolicy(stepCtx) == nil {
			// Reload took effect
			return
		}
		time.Sleep(200 * time.Millisecond)
	}
	t.Fatal("hot-reload did not take effect within 5 seconds")
}
// TestParseConfig_FetchTimeout verifies that fetchTimeoutSeconds="7" is
// parsed into a FetchTimeout of 7 seconds.
func TestParseConfig_FetchTimeout(t *testing.T) {
	parsed, parseErr := ParseConfig(map[string]string{
		"type":                "url",
		"location":            "https://example.com/policy.rego",
		"query":               "data.policy.violations",
		"fetchTimeoutSeconds": "7",
	})
	if parseErr != nil {
		t.Fatalf("unexpected error: %v", parseErr)
	}
	if got := parsed.FetchTimeout; got != 7*time.Second {
		t.Fatalf("expected fetch timeout 7s, got %s", got)
	}
}
// TestEvaluator_FetchURL_Timeout serves a policy after a 50ms delay and
// expects NewEvaluator to fail when given a 10ms fetch timeout.
func TestEvaluator_FetchURL_Timeout(t *testing.T) {
	slowHandler := func(w http.ResponseWriter, r *http.Request) {
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte(`package policy
import rego.v1
violations := []`))
	}
	server := httptest.NewServer(http.HandlerFunc(slowHandler))
	defer server.Close()

	sources := []string{server.URL + "/slow.rego"}
	if _, newErr := NewEvaluator(sources, "data.policy.violations", nil, false, 10*time.Millisecond); newErr == nil {
		t.Fatal("expected timeout error for slow policy URL")
	}
}
// TestExtractAction_NonStandardURLFallsBackToBody checks that extractAction
// returns the body's context.action ("confirm") for the path
// "/bpp/caller/confirm/extra".
func TestExtractAction_NonStandardURLFallsBackToBody(t *testing.T) {
	payload := []byte(`{"context": {"action": "confirm"}}`)
	if got := extractAction("/bpp/caller/confirm/extra", payload); got != "confirm" {
		t.Fatalf("expected body fallback action 'confirm', got %q", got)
	}
}
// TestEnforcer_DisabledSkipsEvaluatorInitialization checks that New succeeds
// for a disabled enforcer pointed at https://127.0.0.1:1 and that
// getEvaluator returns nil afterwards.
func TestEnforcer_DisabledSkipsEvaluatorInitialization(t *testing.T) {
	settings := map[string]string{
		"type":     "url",
		"location": "https://127.0.0.1:1/unreachable.rego",
		"query":    "data.policy.violations",
		"enabled":  "false",
	}
	enforcer, newErr := New(context.Background(), settings)
	if newErr != nil {
		t.Fatalf("expected disabled enforcer to skip evaluator initialization, got %v", newErr)
	}
	if enforcer.getEvaluator() != nil {
		t.Fatal("expected disabled enforcer to leave evaluator uninitialized")
	}
}