1177 lines
32 KiB
Go
1177 lines
32 KiB
Go
package opapolicychecker
|
|
|
|
import (
	"bytes"
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/open-policy-agent/opa/v1/bundle"

	"github.com/beckn-one/beckn-onix/pkg/model"
)
|
|
|
|
// Helper: create a StepContext with the given action path and JSON body.
|
|
func makeStepCtx(action string, body string) *model.StepContext {
|
|
req, _ := http.NewRequest("POST", "/bpp/caller/"+action, nil)
|
|
return &model.StepContext{
|
|
Context: context.Background(),
|
|
Request: req,
|
|
Body: []byte(body),
|
|
}
|
|
}
|
|
|
|
// writePolicyDir creates a fresh per-test temporary directory holding a single
// file named filename with the given content, and returns the directory path.
// A write failure aborts the test. The directory comes from t.TempDir, so it is
// removed automatically when the test ends.
func writePolicyDir(t *testing.T, filename, content string) string {
	t.Helper()

	tmp := t.TempDir()
	target := filepath.Join(tmp, filename)
	if werr := os.WriteFile(target, []byte(content), 0o644); werr != nil {
		t.Fatalf("failed to write policy file: %v", werr)
	}
	return tmp
}
|
|
|
|
// --- Config Tests ---
|
|
|
|
func TestParseConfig_RequiresPolicySource(t *testing.T) {
|
|
_, err := ParseConfig(map[string]string{})
|
|
if err == nil {
|
|
t.Fatal("expected error when no policy source given")
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_Defaults(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(cfg.Actions) != 0 {
|
|
t.Errorf("expected empty default actions (all enabled), got %v", cfg.Actions)
|
|
}
|
|
if !cfg.Enabled {
|
|
t.Error("expected enabled=true by default")
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_RequiresQuery(t *testing.T) {
|
|
_, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error when no query given")
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_RuntimeConfigForwarding(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
"minDeliveryLeadHours": "6",
|
|
"customParam": "value",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cfg.RuntimeConfig["minDeliveryLeadHours"] != "6" {
|
|
t.Errorf("expected minDeliveryLeadHours=6, got %q", cfg.RuntimeConfig["minDeliveryLeadHours"])
|
|
}
|
|
if cfg.RuntimeConfig["customParam"] != "value" {
|
|
t.Errorf("expected customParam=value, got %q", cfg.RuntimeConfig["customParam"])
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_CustomActions(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
"actions": "confirm, select, init",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(cfg.Actions) != 3 {
|
|
t.Fatalf("expected 3 actions, got %d: %v", len(cfg.Actions), cfg.Actions)
|
|
}
|
|
expected := []string{"confirm", "select", "init"}
|
|
for i, want := range expected {
|
|
if cfg.Actions[i] != want {
|
|
t.Errorf("action[%d] = %q, want %q", i, cfg.Actions[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_PolicyPaths(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "url",
|
|
"location": "https://example.com/a.rego, https://example.com/b.rego",
|
|
"query": "data.policy.violations",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(cfg.PolicyPaths) != 2 {
|
|
t.Fatalf("expected 2 paths, got %d: %v", len(cfg.PolicyPaths), cfg.PolicyPaths)
|
|
}
|
|
if cfg.PolicyPaths[0] != "https://example.com/a.rego" {
|
|
t.Errorf("path[0] = %q", cfg.PolicyPaths[0])
|
|
}
|
|
}
|
|
|
|
// --- Evaluator Tests (with inline policies) ---
|
|
|
|
// TestEvaluator_NoViolations checks that a compliant input (value >= 0) yields
// an empty violation list.
func TestEvaluator_NoViolations(t *testing.T) {
	const src = `
package policy
import rego.v1
violations contains msg if {
	input.value < 0
	msg := "value is negative"
}
`
	ev, err := NewEvaluator([]string{writePolicyDir(t, "test.rego", src)}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	found, err := ev.Evaluate(context.Background(), []byte(`{"value": 10}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(found) > 0 {
		t.Errorf("expected 0 violations, got %d: %v", len(found), found)
	}
}
|
|
|
|
// TestEvaluator_WithViolation checks that a negative value yields exactly one
// violation carrying the policy's message text.
func TestEvaluator_WithViolation(t *testing.T) {
	const src = `
package policy
import rego.v1
violations contains msg if {
	input.value < 0
	msg := "value is negative"
}
`
	dir := writePolicyDir(t, "test.rego", src)
	ev, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	found, err := ev.Evaluate(context.Background(), []byte(`{"value": -5}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	switch {
	case len(found) != 1:
		t.Fatalf("expected 1 violation, got %d: %v", len(found), found)
	case found[0] != "value is negative":
		t.Errorf("unexpected violation: %q", found[0])
	}
}
|
|
|
|
// TestEvaluator_RuntimeConfig checks that the runtime-config map passed to
// NewEvaluator is readable from Rego as data.config.<key>. With maxValue=100, an
// input of 50 passes and 150 is flagged.
func TestEvaluator_RuntimeConfig(t *testing.T) {
	const src = `
package policy
import rego.v1
violations contains msg if {
	input.value > to_number(data.config.maxValue)
	msg := "value exceeds maximum"
}
`
	dir := writePolicyDir(t, "test.rego", src)
	ev, err := NewEvaluator([]string{dir}, "data.policy.violations", map[string]string{"maxValue": "100"}, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	cases := []struct {
		input  string
		want   int
		format string
	}{
		{`{"value": 50}`, 0, "expected 0 violations for value=50, got %v"},  // under limit
		{`{"value": 150}`, 1, "expected 1 violation for value=150, got %v"}, // over limit
	}
	for _, tc := range cases {
		got, err := ev.Evaluate(context.Background(), []byte(tc.input))
		if err != nil {
			t.Fatalf("Evaluate failed: %v", err)
		}
		if len(got) != tc.want {
			t.Errorf(tc.format, got)
		}
	}
}
|
|
|
|
// TestEvaluator_SkipsTestFiles checks that files ending in _test.rego inside a
// policy directory are not loaded by NewEvaluator.
//
// The _test.rego fixture deliberately sits in the same package and adds its own
// violation. If it were loaded, the result would contain "from_test_file" as
// well, and the final assertion would catch the regression. (A fixture in a
// separate package that adds nothing could not distinguish "skipped" from
// "loaded".)
func TestEvaluator_SkipsTestFiles(t *testing.T) {
	dir := t.TempDir()

	// write puts a file into dir and aborts on failure, so a missing fixture
	// cannot masquerade as a loader bug.
	write := func(name, content string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write %s: %v", name, err)
		}
	}

	write("policy.rego", `
package policy
import rego.v1
violations contains "always" if { true }
`)

	write("policy_test.rego", `
package policy
import rego.v1
violations contains "from_test_file" if { true }
test_something if { count(violations) > 0 }
`)

	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator should skip _test.rego files, but failed: %v", err)
	}

	violations, err := eval.Evaluate(context.Background(), []byte(`{}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "always" {
		t.Errorf("expected only [always] (test file must not be loaded), got %v", violations)
	}
}
|
|
|
|
// TestEvaluator_InvalidJSON checks that Evaluate returns an error, rather than
// an empty result, when the request body is not valid JSON.
func TestEvaluator_InvalidJSON(t *testing.T) {
	dir := writePolicyDir(t, "test.rego", `
package policy
import rego.v1
violations := set()
`)
	ev, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	if _, evalErr := ev.Evaluate(context.Background(), []byte(`not json`)); evalErr == nil {
		t.Error("expected error for invalid JSON")
	}
}
|
|
|
|
// --- Evaluator URL Fetch Tests ---
|
|
|
|
// TestEvaluator_FetchFromURL checks that a .rego file served over HTTP is
// downloaded and evaluated: a compliant input passes and a negative value is
// flagged.
func TestEvaluator_FetchFromURL(t *testing.T) {
	const src = `
package policy
import rego.v1
violations contains msg if {
	input.value < 0
	msg := "value is negative"
}
`
	// Serve the policy text from a local HTTP server.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/plain")
		w.Write([]byte(src))
	}))
	defer srv.Close()

	ev, err := NewEvaluator([]string{srv.URL + "/test_policy.rego"}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator with URL failed: %v", err)
	}

	steps := []struct {
		input  string
		want   int
		format string
	}{
		{`{"value": 10}`, 0, "expected 0 violations, got %v"}, // compliant
		{`{"value": -1}`, 1, "expected 1 violation, got %v"},  // non-compliant
	}
	for _, s := range steps {
		got, err := ev.Evaluate(context.Background(), []byte(s.input))
		if err != nil {
			t.Fatalf("Evaluate failed: %v", err)
		}
		if len(got) != s.want {
			t.Errorf(s.format, got)
		}
	}
}
|
|
|
|
func TestEvaluator_FetchURL_NotFound(t *testing.T) {
|
|
srv := httptest.NewServer(http.NotFoundHandler())
|
|
defer srv.Close()
|
|
|
|
_, err := NewEvaluator([]string{srv.URL + "/missing.rego"}, "data.policy.violations", nil, false, 0)
|
|
if err == nil {
|
|
t.Fatal("expected error for 404 URL")
|
|
}
|
|
}
|
|
|
|
func TestEvaluator_FetchURL_InvalidScheme(t *testing.T) {
|
|
_, err := NewEvaluator([]string{"ftp://example.com/policy.rego"}, "data.policy.violations", nil, false, 0)
|
|
if err == nil {
|
|
t.Fatal("expected error for ftp:// scheme")
|
|
}
|
|
}
|
|
|
|
// TestEvaluator_MixedLocalAndURL checks that a directory source and an HTTP
// source can be combined. Both define violations in the same package, and the
// rules from each are evaluated together.
func TestEvaluator_MixedLocalAndURL(t *testing.T) {
	// Local half of the policy, loaded from a directory.
	localDir := writePolicyDir(t, "local.rego", `
package policy
import rego.v1
violations contains "local_violation" if { input.local_bad }
`)

	// Remote half: a different rule in the same package, served over HTTP.
	const remoteSrc = `
package policy
import rego.v1
violations contains "remote_violation" if { input.remote_bad }
`
	remote := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(remoteSrc))
	}))
	defer remote.Close()

	sources := []string{localDir, remote.URL + "/remote.rego"}
	ev, err := NewEvaluator(sources, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	// Set both flags so each source contributes one violation.
	got, err := ev.Evaluate(context.Background(), []byte(`{"local_bad": true, "remote_bad": true}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(got) != 2 {
		t.Errorf("expected 2 violations (local+remote), got %d: %v", len(got), got)
	}
}
|
|
|
|
// --- Evaluator with local file path in policySources ---
|
|
|
|
// TestEvaluator_LocalFilePath checks that a policy source may name a single
// .rego file directly rather than a directory.
func TestEvaluator_LocalFilePath(t *testing.T) {
	policy := `
package policy
import rego.v1
violations contains "from_file" if { input.bad }
`
	dir := t.TempDir()
	policyPath := filepath.Join(dir, "local_policy.rego")
	// Fail fast on a write error. Otherwise NewEvaluator would report a
	// confusing load failure instead of the real cause.
	if err := os.WriteFile(policyPath, []byte(policy), 0o644); err != nil {
		t.Fatalf("failed to write policy file: %v", err)
	}

	eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator with local path failed: %v", err)
	}

	violations, err := eval.Evaluate(context.Background(), []byte(`{"bad": true}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "from_file" {
		t.Errorf("expected [from_file], got %v", violations)
	}
}
|
|
|
|
// --- Rego Modularity Tests ---
|
|
// These tests prove that rego files can reference each other, supporting
|
|
// modular policy design to avoid rule bloat.
|
|
|
|
// TestEvaluator_CrossFileModularity verifies that multiple .rego files in the
// SAME package share rules: a helper rule defined in one file is used by a rule
// in another file without any import.
func TestEvaluator_CrossFileModularity(t *testing.T) {
	dir := t.TempDir()

	// write puts a policy file into dir and aborts the test on failure, so a
	// missing file cannot masquerade as a policy-loading bug.
	write := func(name, content string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write %s: %v", name, err)
		}
	}

	// File 1: defines a helper rule.
	write("helpers.rego", `
package policy
import rego.v1
is_high_value if { input.message.order.value > 10000 }
`)

	// File 2: uses the helper from file 1 (same package, auto-merged).
	write("rules.rego", `
package policy
import rego.v1
violations contains "order too large" if { is_high_value }
`)

	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	// High value order: should trigger.
	violations, err := eval.Evaluate(context.Background(), []byte(`{"message":{"order":{"value":15000}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "order too large" {
		t.Errorf("expected [order too large], got %v", violations)
	}

	// Low value order: should not trigger.
	violations, err = eval.Evaluate(context.Background(), []byte(`{"message":{"order":{"value":500}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 0 {
		t.Errorf("expected 0 violations, got %v", violations)
	}
}
|
|
|
|
// TestEvaluator_CrossPackageImport verifies that rego files in DIFFERENT
// packages can use each other's rules via `import data.<package>`.
func TestEvaluator_CrossPackageImport(t *testing.T) {
	dir := t.TempDir()

	// write puts a policy file into dir and aborts the test on failure, so a
	// missing file cannot masquerade as a policy-loading bug.
	write := func(name, content string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write %s: %v", name, err)
		}
	}

	// File 1: utility package with reusable helpers.
	write("utils.rego", `
package utils
import rego.v1
is_confirm if { input.context.action == "confirm" }
is_high_value if { input.message.order.value > 10000 }
`)

	// File 2: policy package imports from the utils package.
	write("rules.rego", `
package policy
import rego.v1
import data.utils
violations contains "high value confirm blocked" if {
	utils.is_confirm
	utils.is_high_value
}
`)

	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	// confirm + high value: should fire, with the rule's own message.
	violations, err := eval.Evaluate(context.Background(), []byte(`{
		"context": {"action": "confirm"},
		"message": {"order": {"value": 50000}}
	}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 1 || violations[0] != "high value confirm blocked" {
		t.Errorf("expected [high value confirm blocked], got %v", violations)
	}

	// search action: should NOT fire (action filter lives in the rego).
	violations, err = eval.Evaluate(context.Background(), []byte(`{
		"context": {"action": "search"},
		"message": {"order": {"value": 50000}}
	}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 0 {
		t.Errorf("expected 0 violations for search action, got %v", violations)
	}
}
|
|
|
|
// TestEvaluator_MultiFileOrganization demonstrates a realistic modular layout
// where policies are organized by concern (a shared helpers package plus
// separate compliance and safety files) and all contribute to the same
// data.policy.violations set.
func TestEvaluator_MultiFileOrganization(t *testing.T) {
	dir := t.TempDir()

	// write puts a policy file into dir and aborts the test on failure, so a
	// missing file cannot masquerade as a policy-loading bug.
	write := func(name, content string) {
		t.Helper()
		if err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0o644); err != nil {
			t.Fatalf("failed to write %s: %v", name, err)
		}
	}

	// Shared helpers.
	write("helpers.rego", `
package helpers
import rego.v1
action_is(a) if { input.context.action == a }
value_exceeds(limit) if { input.message.order.value > limit }
`)

	// compliance.rego: compliance rules.
	write("compliance.rego", `
package policy
import rego.v1
import data.helpers
violations contains "compliance: missing provider" if {
	helpers.action_is("confirm")
	not input.message.order.provider
}
`)

	// safety.rego: safety rules.
	write("safety.rego", `
package policy
import rego.v1
import data.helpers
violations contains "safety: order value too high" if {
	helpers.action_is("confirm")
	helpers.value_exceeds(100000)
}
`)

	eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	// Input that triggers BOTH violations: confirm, no provider, and a value
	// above the 100000 safety limit.
	violations, err := eval.Evaluate(context.Background(), []byte(`{
		"context": {"action": "confirm"},
		"message": {"order": {"value": 999999}}
	}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 2 {
		t.Errorf("expected 2 violations (compliance+safety), got %d: %v", len(violations), violations)
	}
	// A count alone cannot tell which rules fired, so check that each file
	// contributed its own message. Ordering is deliberately not asserted.
	for _, want := range []string{"compliance: missing provider", "safety: order value too high"} {
		found := false
		for _, v := range violations {
			if v == want {
				found = true
				break
			}
		}
		if !found {
			t.Errorf("expected violation %q, got %v", want, violations)
		}
	}

	// Compliant confirm: provider present and value under the safety limit.
	violations, err = eval.Evaluate(context.Background(), []byte(`{
		"context": {"action": "confirm"},
		"message": {"order": {"value": 10, "provider": {"id": "p1"}}}
	}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(violations) != 0 {
		t.Errorf("expected 0 violations for compliant confirm, got %v", violations)
	}
}
|
|
|
|
// --- Enforcer Integration Tests ---
|
|
|
|
// TestEnforcer_Compliant checks that CheckPolicy returns nil when the action is
// in scope but the policy's rule body is not satisfied (block=false).
func TestEnforcer_Compliant(t *testing.T) {
	dir := writePolicyDir(t, "test.rego", `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm"; input.block }
`)
	settings := map[string]string{
		"type":     "dir",
		"location": dir,
		"query":    "data.policy.violations",
		"actions":  "confirm",
	}
	enforcer, err := New(context.Background(), settings)
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	step := makeStepCtx("confirm", `{"context": {"action": "confirm"}, "block": false}`)
	if checkErr := enforcer.CheckPolicy(step); checkErr != nil {
		t.Errorf("expected nil error for compliant message, got: %v", checkErr)
	}
}
|
|
|
|
// TestEnforcer_NonCompliant checks that a policy violation on a configured
// action is reported by CheckPolicy as a *model.BadReqErr.
func TestEnforcer_NonCompliant(t *testing.T) {
	policy := `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm" }
`
	dir := writePolicyDir(t, "test.rego", policy)

	enforcer, err := New(context.Background(), map[string]string{
		"type":     "dir",
		"location": dir,
		"query":    "data.policy.violations",
		"actions":  "confirm",
	})
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	ctx := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	err = enforcer.CheckPolicy(ctx)
	if err == nil {
		t.Fatal("expected error for non-compliant message, got nil")
	}

	// Should be a BadReqErr. errors.As (rather than a bare type assertion)
	// also accepts a BadReqErr that has been wrapped with extra context.
	var badReq *model.BadReqErr
	if !errors.As(err, &badReq) {
		t.Errorf("expected *model.BadReqErr, got %T: %v", err, err)
	}
}
|
|
|
|
// TestEnforcer_SkipsNonMatchingAction checks that a request whose action is
// not in the configured "actions" list passes, even though the policy would
// reject every input.
func TestEnforcer_SkipsNonMatchingAction(t *testing.T) {
	// This policy rejects every input unconditionally.
	dir := writePolicyDir(t, "test.rego", `
package policy
import rego.v1
violations contains "blocked" if { true }
`)
	enforcer, err := New(context.Background(), map[string]string{
		"type":     "dir",
		"location": dir,
		"query":    "data.policy.violations",
		"actions":  "confirm",
	})
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	// "search" is outside the configured action list.
	searchStep := makeStepCtx("search", `{"context": {"action": "search"}}`)
	if checkErr := enforcer.CheckPolicy(searchStep); checkErr != nil {
		t.Errorf("expected nil for non-matching action, got: %v", checkErr)
	}
}
|
|
|
|
// TestEnforcer_DisabledPlugin checks that enabled=false makes CheckPolicy pass
// requests that an always-failing policy would otherwise reject.
func TestEnforcer_DisabledPlugin(t *testing.T) {
	dir := writePolicyDir(t, "test.rego", `
package policy
import rego.v1
violations contains "blocked" if { true }
`)
	settings := map[string]string{
		"type":     "dir",
		"location": dir,
		"query":    "data.policy.violations",
		"enabled":  "false",
	}
	enforcer, err := New(context.Background(), settings)
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	step := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	if checkErr := enforcer.CheckPolicy(step); checkErr != nil {
		t.Errorf("expected nil for disabled plugin, got: %v", checkErr)
	}
}
|
|
|
|
// --- Enforcer with URL-sourced policy ---
|
|
|
|
// TestEnforcer_PolicyFromURL checks that an Enforcer configured with type=url
// loads its policy over HTTP and reports violations as *model.BadReqErr.
func TestEnforcer_PolicyFromURL(t *testing.T) {
	policy := `
package policy
import rego.v1
violations contains "blocked" if { input.context.action == "confirm" }
`
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(policy))
	}))
	defer srv.Close()

	enforcer, err := New(context.Background(), map[string]string{
		"type":     "url",
		"location": srv.URL + "/block_confirm.rego",
		"query":    "data.policy.violations",
		"actions":  "confirm",
	})
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	ctx := makeStepCtx("confirm", `{"context": {"action": "confirm"}}`)
	err = enforcer.CheckPolicy(ctx)
	if err == nil {
		t.Fatal("expected error from URL-sourced policy, got nil")
	}
	// errors.As also accepts a BadReqErr wrapped with extra context.
	var badReq *model.BadReqErr
	if !errors.As(err, &badReq) {
		t.Errorf("expected *model.BadReqErr, got %T", err)
	}
}
|
|
|
|
// --- extractAction Tests ---
|
|
|
|
func TestExtractAction_FromURL(t *testing.T) {
|
|
action := extractAction("/bpp/caller/confirm", nil)
|
|
if action != "confirm" {
|
|
t.Errorf("expected 'confirm', got %q", action)
|
|
}
|
|
}
|
|
|
|
func TestExtractAction_FromBody(t *testing.T) {
|
|
body := []byte(`{"context": {"action": "select"}}`)
|
|
action := extractAction("/x", body)
|
|
if action != "select" {
|
|
t.Errorf("expected 'select', got %q", action)
|
|
}
|
|
}
|
|
|
|
// --- Config Tests: Bundle Type ---
|
|
|
|
func TestParseConfig_BundleType(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "bundle",
|
|
"location": "https://example.com/bundle.tar.gz",
|
|
"query": "data.retail.validation.result",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !cfg.IsBundle {
|
|
t.Error("expected IsBundle=true for type=bundle")
|
|
}
|
|
if len(cfg.PolicyPaths) != 1 || cfg.PolicyPaths[0] != "https://example.com/bundle.tar.gz" {
|
|
t.Errorf("expected 1 policy path, got %v", cfg.PolicyPaths)
|
|
}
|
|
if cfg.Query != "data.retail.validation.result" {
|
|
t.Errorf("expected query 'data.retail.validation.result', got %q", cfg.Query)
|
|
}
|
|
}
|
|
|
|
// --- Structured Result Format Tests ---
|
|
|
|
// TestEvaluator_StructuredResult_Valid checks that a structured result of
// {"valid": true, "violations": []} is reported as no violations.
func TestEvaluator_StructuredResult_Valid(t *testing.T) {
	dir := writePolicyDir(t, "policy.rego", `
package retail.policy

import rego.v1

default result := {
	"valid": true,
	"violations": []
}
`)
	ev, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	out, err := ev.Evaluate(context.Background(), []byte(`{"message": {"order": {"items": []}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(out) != 0 {
		t.Errorf("expected 0 violations for valid result, got %v", out)
	}
}
|
|
|
|
// TestEvaluator_StructuredResult_WithViolations checks that the "violations"
// array inside a structured {"valid", "violations"} result is surfaced as the
// violation list. Only the item with count <= 0 is flagged, and a fully
// compliant order yields nothing.
func TestEvaluator_StructuredResult_WithViolations(t *testing.T) {
	const src = `
package retail.policy

import rego.v1

default result := {
	"valid": true,
	"violations": []
}

result := {
	"valid": count(violations) == 0,
	"violations": violations
}

violations contains msg if {
	some item in input.message.order.items
	item.quantity.count <= 0
	msg := sprintf("item %s: quantity must be > 0", [item.id])
}
`
	dir := writePolicyDir(t, "policy.rego", src)
	ev, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	// Non-compliant: item1 has count 0, item2 is fine.
	badOrder := `{
		"message": {
			"order": {
				"items": [
					{"id": "item1", "quantity": {"count": 0}},
					{"id": "item2", "quantity": {"count": 5}}
				]
			}
		}
	}`
	found, err := ev.Evaluate(context.Background(), []byte(badOrder))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(found) != 1 {
		t.Fatalf("expected 1 violation, got %d: %v", len(found), found)
	}
	if found[0] != "item item1: quantity must be > 0" {
		t.Errorf("unexpected violation: %q", found[0])
	}

	// Compliant: every item has a positive count.
	goodOrder := `{
		"message": {
			"order": {
				"items": [
					{"id": "item1", "quantity": {"count": 3}}
				]
			}
		}
	}`
	found, err = ev.Evaluate(context.Background(), []byte(goodOrder))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(found) != 0 {
		t.Errorf("expected 0 violations for compliant input, got %v", found)
	}
}
|
|
|
|
// TestEvaluator_StructuredResult_FalseNoViolations covers the edge case of
// valid=false with an empty violations array. The evaluator must still deny,
// reporting the single generic message "policy denied the request".
func TestEvaluator_StructuredResult_FalseNoViolations(t *testing.T) {
	dir := writePolicyDir(t, "policy.rego", `
package policy

import rego.v1

result := {
	"valid": false,
	"violations": []
}
`)
	ev, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	out, err := ev.Evaluate(context.Background(), []byte(`{}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if !(len(out) == 1 && out[0] == "policy denied the request") {
		t.Errorf("expected ['policy denied the request'], got %v", out)
	}
}
|
|
|
|
// TestEvaluator_NonStructuredMapResult_Ignored checks that an object result
// without the "valid"/"violations" shape produces no violations.
func TestEvaluator_NonStructuredMapResult_Ignored(t *testing.T) {
	dir := writePolicyDir(t, "policy.rego", `
package policy

import rego.v1

result := {
	"action": "confirm",
	"status": "ok"
}
`)
	ev, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false, 0)
	if err != nil {
		t.Fatalf("NewEvaluator failed: %v", err)
	}

	out, err := ev.Evaluate(context.Background(), []byte(`{}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(out) != 0 {
		t.Fatalf("expected non-structured map result to be ignored, got %v", out)
	}
}
|
|
|
|
// --- Bundle Tests ---
|
|
|
|
// buildTestBundle creates an OPA bundle .tar.gz in memory from the given modules.
|
|
func buildTestBundle(t *testing.T, modules map[string]string) []byte {
|
|
t.Helper()
|
|
b := bundle.Bundle{
|
|
Modules: make([]bundle.ModuleFile, 0, len(modules)),
|
|
Data: make(map[string]interface{}),
|
|
}
|
|
for path, content := range modules {
|
|
b.Modules = append(b.Modules, bundle.ModuleFile{
|
|
URL: path,
|
|
Path: path,
|
|
Raw: []byte(content),
|
|
Parsed: nil,
|
|
})
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := bundle.Write(&buf, b); err != nil {
|
|
t.Fatalf("failed to write test bundle: %v", err)
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// TestEvaluator_BundleFromURL checks that an OPA bundle archive served over
// HTTP is loaded when NewEvaluator's bundle flag (4th argument) is true, and
// that its structured result is evaluated correctly.
func TestEvaluator_BundleFromURL(t *testing.T) {
	const src = `
package retail.validation

import rego.v1

default result := {
	"valid": true,
	"violations": []
}

result := {
	"valid": count(violations) == 0,
	"violations": violations
}

violations contains msg if {
	some item in input.message.order.items
	item.quantity.count <= 0
	msg := sprintf("item %s: quantity must be > 0", [item.id])
}
`
	archive := buildTestBundle(t, map[string]string{"retail/validation.rego": src})

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/gzip")
		w.Write(archive)
	}))
	defer srv.Close()

	ev, err := NewEvaluator([]string{srv.URL + "/bundle.tar.gz"}, "data.retail.validation.result", nil, true, 0)
	if err != nil {
		t.Fatalf("NewEvaluator with bundle failed: %v", err)
	}

	// Non-compliant: count 0.
	found, err := ev.Evaluate(context.Background(), []byte(`{"message":{"order":{"items":[{"id":"x","quantity":{"count":0}}]}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(found) != 1 {
		t.Fatalf("expected 1 violation, got %d: %v", len(found), found)
	}

	// Compliant: count 5.
	found, err = ev.Evaluate(context.Background(), []byte(`{"message":{"order":{"items":[{"id":"x","quantity":{"count":5}}]}}}`))
	if err != nil {
		t.Fatalf("Evaluate failed: %v", err)
	}
	if len(found) != 0 {
		t.Errorf("expected 0 violations, got %v", found)
	}
}
|
|
|
|
// TestEnforcer_BundlePolicy checks an Enforcer configured with type=bundle end
// to end. A confirm without a provider is rejected with *model.BadReqErr; a
// confirm with a provider passes.
func TestEnforcer_BundlePolicy(t *testing.T) {
	policy := `
package retail.policy

import rego.v1

default result := {
	"valid": true,
	"violations": []
}

result := {
	"valid": count(violations) == 0,
	"violations": violations
}

violations contains "blocked" if {
	input.context.action == "confirm"
	not input.message.order.provider
}
`
	bundleData := buildTestBundle(t, map[string]string{
		"retail/policy.rego": policy,
	})

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/gzip")
		w.Write(bundleData)
	}))
	defer srv.Close()

	enforcer, err := New(context.Background(), map[string]string{
		"type":     "bundle",
		"location": srv.URL + "/policy-bundle.tar.gz",
		"query":    "data.retail.policy.result",
		"actions":  "confirm",
	})
	if err != nil {
		t.Fatalf("New failed: %v", err)
	}

	// Non-compliant: confirm without provider.
	ctx := makeStepCtx("confirm", `{"context": {"action": "confirm"}, "message": {"order": {}}}`)
	err = enforcer.CheckPolicy(ctx)
	if err == nil {
		t.Fatal("expected error for non-compliant message, got nil")
	}
	// errors.As also accepts a BadReqErr wrapped with extra context.
	var badReq *model.BadReqErr
	if !errors.As(err, &badReq) {
		t.Errorf("expected *model.BadReqErr, got %T: %v", err, err)
	}

	// Compliant: confirm with provider.
	ctx = makeStepCtx("confirm", `{"context": {"action": "confirm"}, "message": {"order": {"provider": {"id": "p1"}}}}`)
	if err := enforcer.CheckPolicy(ctx); err != nil {
		t.Errorf("expected nil error for compliant message, got: %v", err)
	}
}
|
|
|
|
func TestParseConfig_RefreshInterval(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
"refreshIntervalSeconds": "300",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cfg.RefreshInterval != 300*time.Second {
|
|
t.Errorf("expected 300s refresh interval, got %v", cfg.RefreshInterval)
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_RefreshInterval_Zero(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
// no refreshIntervalSeconds → disabled
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cfg.RefreshInterval != 0 {
|
|
t.Errorf("expected refresh disabled (0), got %v", cfg.RefreshInterval)
|
|
}
|
|
}
|
|
|
|
func TestParseConfig_RefreshInterval_Invalid(t *testing.T) {
|
|
_, err := ParseConfig(map[string]string{
|
|
"type": "dir",
|
|
"location": "/tmp",
|
|
"query": "data.policy.violations",
|
|
"refreshIntervalSeconds": "not-a-number",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid refreshIntervalSeconds")
|
|
}
|
|
}
|
|
|
|
// TestEnforcer_HotReload verifies that the hot-reload goroutine picks up changes
|
|
// to a local policy file within the configured refresh interval.
|
|
func TestEnforcer_HotReload(t *testing.T) {
|
|
dir := t.TempDir()
|
|
policyPath := filepath.Join(dir, "policy.rego")
|
|
|
|
// Initial policy: always blocks confirm
|
|
blockPolicy := `package policy
|
|
import rego.v1
|
|
default result := {"valid": false, "violations": ["blocked by initial policy"]}
|
|
result := {"valid": false, "violations": ["blocked by initial policy"]}
|
|
`
|
|
if err := os.WriteFile(policyPath, []byte(blockPolicy), 0644); err != nil {
|
|
t.Fatalf("failed to write initial policy: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
enforcer, err := New(ctx, map[string]string{
|
|
"type": "dir",
|
|
"location": dir,
|
|
"query": "data.policy.result",
|
|
"refreshIntervalSeconds": "1", // 1s refresh for test speed
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New failed: %v", err)
|
|
}
|
|
|
|
// Confirm is blocked with initial policy
|
|
stepCtx := makeStepCtx("confirm", `{"context":{"action":"confirm"}}`)
|
|
if err := enforcer.CheckPolicy(stepCtx); err == nil {
|
|
t.Fatal("expected block from initial policy, got nil")
|
|
}
|
|
|
|
// Swap policy on disk to allow everything
|
|
allowPolicy := `package policy
|
|
import rego.v1
|
|
default result := {"valid": true, "violations": []}
|
|
`
|
|
if err := os.WriteFile(policyPath, []byte(allowPolicy), 0644); err != nil {
|
|
t.Fatalf("failed to write updated policy: %v", err)
|
|
}
|
|
|
|
// Wait up to 5s for the reload to fire and swap the evaluator
|
|
deadline := time.Now().Add(5 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
if err := enforcer.CheckPolicy(stepCtx); err == nil {
|
|
// Reload took effect
|
|
return
|
|
}
|
|
time.Sleep(200 * time.Millisecond)
|
|
}
|
|
|
|
t.Fatal("hot-reload did not take effect within 5 seconds")
|
|
}
|
|
|
|
func TestParseConfig_FetchTimeout(t *testing.T) {
|
|
cfg, err := ParseConfig(map[string]string{
|
|
"type": "url",
|
|
"location": "https://example.com/policy.rego",
|
|
"query": "data.policy.violations",
|
|
"fetchTimeoutSeconds": "7",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if cfg.FetchTimeout != 7*time.Second {
|
|
t.Fatalf("expected fetch timeout 7s, got %s", cfg.FetchTimeout)
|
|
}
|
|
}
|
|
|
|
func TestEvaluator_FetchURL_Timeout(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(50 * time.Millisecond)
|
|
w.Write([]byte(`package policy
|
|
import rego.v1
|
|
violations := []`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
_, err := NewEvaluator([]string{srv.URL + "/slow.rego"}, "data.policy.violations", nil, false, 10*time.Millisecond)
|
|
if err == nil {
|
|
t.Fatal("expected timeout error for slow policy URL")
|
|
}
|
|
}
|
|
|
|
func TestExtractAction_NonStandardURLFallsBackToBody(t *testing.T) {
|
|
body := []byte(`{"context": {"action": "confirm"}}`)
|
|
action := extractAction("/bpp/caller/confirm/extra", body)
|
|
if action != "confirm" {
|
|
t.Fatalf("expected body fallback action 'confirm', got %q", action)
|
|
}
|
|
}
|
|
|
|
func TestEnforcer_DisabledSkipsEvaluatorInitialization(t *testing.T) {
|
|
enforcer, err := New(context.Background(), map[string]string{
|
|
"type": "url",
|
|
"location": "https://127.0.0.1:1/unreachable.rego",
|
|
"query": "data.policy.violations",
|
|
"enabled": "false",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("expected disabled enforcer to skip evaluator initialization, got %v", err)
|
|
}
|
|
if enforcer.getEvaluator() != nil {
|
|
t.Fatal("expected disabled enforcer to leave evaluator uninitialized")
|
|
}
|
|
}
|