Fix: address policy checker review feedback
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/beckn-one/beckn-onix/core/module"
|
"github.com/beckn-one/beckn-onix/core/module"
|
||||||
"github.com/beckn-one/beckn-onix/core/module/handler"
|
"github.com/beckn-one/beckn-onix/core/module/handler"
|
||||||
|
"github.com/beckn-one/beckn-onix/pkg/model"
|
||||||
"github.com/beckn-one/beckn-onix/pkg/plugin"
|
"github.com/beckn-one/beckn-onix/pkg/plugin"
|
||||||
"github.com/beckn-one/beckn-onix/pkg/plugin/definition"
|
"github.com/beckn-one/beckn-onix/pkg/plugin/definition"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
@@ -21,6 +23,15 @@ import (
|
|||||||
// MockPluginManager implements handler.PluginManager for testing.
|
// MockPluginManager implements handler.PluginManager for testing.
|
||||||
type MockPluginManager struct {
|
type MockPluginManager struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
|
policyCheckerFunc func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubPolicyChecker struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s stubPolicyChecker) CheckPolicy(*model.StepContext) error {
|
||||||
|
return s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns a middleware function based on the provided configuration.
|
// Middleware returns a middleware function based on the provided configuration.
|
||||||
@@ -85,6 +96,9 @@ func (m *MockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con
|
|||||||
|
|
||||||
// PolicyChecker returns a mock implementation of the PolicyChecker interface.
|
// PolicyChecker returns a mock implementation of the PolicyChecker interface.
|
||||||
func (m *MockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
func (m *MockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
||||||
|
if m.policyCheckerFunc != nil {
|
||||||
|
return m.policyCheckerFunc(ctx, cfg)
|
||||||
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,6 +349,49 @@ func TestNewServerSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewServerRejectsPolicyViolation(t *testing.T) {
|
||||||
|
mockMgr := &MockPluginManager{
|
||||||
|
policyCheckerFunc: func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
||||||
|
return stubPolicyChecker{err: model.NewBadReqErr(errors.New("blocked by policy"))}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &Config{
|
||||||
|
Modules: []module.Config{
|
||||||
|
{
|
||||||
|
Name: "policy-module",
|
||||||
|
Path: "/policy",
|
||||||
|
Handler: handler.Config{
|
||||||
|
Type: handler.HandlerTypeStd,
|
||||||
|
Plugins: handler.PluginCfg{
|
||||||
|
PolicyChecker: &plugin.Config{ID: "mock-policy"},
|
||||||
|
},
|
||||||
|
Steps: []string{"checkPolicy"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := newServer(context.Background(), mockMgr, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error creating server, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/policy", strings.NewReader(`{"context":{"action":"confirm"}}`))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for policy violation, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "NACK") {
|
||||||
|
t.Fatalf("expected NACK response, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "blocked by policy") {
|
||||||
|
t.Fatalf("expected policy error in response, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestNewServerFailure tests failure scenarios when creating a server.
|
// TestNewServerFailure tests failure scenarios when creating a server.
|
||||||
func TestNewServerFailure(t *testing.T) {
|
func TestNewServerFailure(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/beckn-one/beckn-onix/core/module/handler"
|
"github.com/beckn-one/beckn-onix/core/module/handler"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
// with support for dynamically setting behavior.
|
// with support for dynamically setting behavior.
|
||||||
type mockPluginManager struct {
|
type mockPluginManager struct {
|
||||||
middlewareFunc func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error)
|
middlewareFunc func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error)
|
||||||
|
policyCheckerFunc func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Middleware returns a mock middleware function based on the provided configuration.
|
// Middleware returns a mock middleware function based on the provided configuration.
|
||||||
@@ -81,10 +83,65 @@ func (m *mockPluginManager) SchemaValidator(ctx context.Context, cfg *plugin.Con
|
|||||||
|
|
||||||
// PolicyChecker returns a mock policy checker implementation.
|
// PolicyChecker returns a mock policy checker implementation.
|
||||||
func (m *mockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
func (m *mockPluginManager) PolicyChecker(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
||||||
|
if m.policyCheckerFunc != nil {
|
||||||
|
return m.policyCheckerFunc(ctx, cfg)
|
||||||
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockPolicyChecker struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockPolicyChecker) CheckPolicy(*model.StepContext) error {
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
|
||||||
// TestRegisterSuccess tests scenarios where the handler registration should succeed.
|
// TestRegisterSuccess tests scenarios where the handler registration should succeed.
|
||||||
|
func TestRegisterRejectsPolicyViolation(t *testing.T) {
|
||||||
|
mCfgs := []Config{
|
||||||
|
{
|
||||||
|
Name: "test-module",
|
||||||
|
Path: "/test",
|
||||||
|
Handler: handler.Config{
|
||||||
|
Type: handler.HandlerTypeStd,
|
||||||
|
Plugins: handler.PluginCfg{
|
||||||
|
PolicyChecker: &plugin.Config{ID: "mock-policy"},
|
||||||
|
},
|
||||||
|
Steps: []string{"checkPolicy"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockManager := &mockPluginManager{
|
||||||
|
middlewareFunc: func(ctx context.Context, cfg *plugin.Config) (func(http.Handler) http.Handler, error) {
|
||||||
|
return func(next http.Handler) http.Handler { return next }, nil
|
||||||
|
},
|
||||||
|
policyCheckerFunc: func(ctx context.Context, cfg *plugin.Config) (definition.PolicyChecker, error) {
|
||||||
|
return mockPolicyChecker{err: model.NewBadReqErr(errors.New("blocked by policy"))}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
if err := Register(context.Background(), mCfgs, mux, mockManager); err != nil {
|
||||||
|
t.Fatalf("unexpected register error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(`{"context":{"action":"confirm"}}`))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
mux.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400 for policy violation, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "NACK") {
|
||||||
|
t.Fatalf("expected NACK response, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(rec.Body.String(), "blocked by policy") {
|
||||||
|
t.Fatalf("expected policy error in response, got %s", rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRegisterSuccess(t *testing.T) {
|
func TestRegisterSuccess(t *testing.T) {
|
||||||
mCfgs := []Config{
|
mCfgs := []Config{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ Validates incoming Beckn messages against network-defined business rules using [
|
|||||||
- Fail-closed on empty/undefined query results — misconfigured policies are treated as violations
|
- Fail-closed on empty/undefined query results — misconfigured policies are treated as violations
|
||||||
- Runtime config forwarding: adapter config values are accessible in Rego as `data.config.<key>`
|
- Runtime config forwarding: adapter config values are accessible in Rego as `data.config.<key>`
|
||||||
- Action-based enforcement: apply policies only to specific beckn actions (e.g., `confirm`, `search`)
|
- Action-based enforcement: apply policies only to specific beckn actions (e.g., `confirm`, `search`)
|
||||||
|
- Configurable fetch timeout for remote policy and bundle sources
|
||||||
|
- Warns at startup when policy enforcement is explicitly disabled
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -36,6 +38,7 @@ steps:
|
|||||||
| `actions` | string | No | *(all)* | Comma-separated beckn actions to enforce |
|
| `actions` | string | No | *(all)* | Comma-separated beckn actions to enforce |
|
||||||
| `enabled` | string | No | `"true"` | Enable or disable the plugin |
|
| `enabled` | string | No | `"true"` | Enable or disable the plugin |
|
||||||
| `debugLogging` | string | No | `"false"` | Enable verbose OPA evaluation logging |
|
| `debugLogging` | string | No | `"false"` | Enable verbose OPA evaluation logging |
|
||||||
|
| `fetchTimeoutSeconds` | string | No | `"30"` | Timeout in seconds for fetching remote `.rego` files or bundles |
|
||||||
| `refreshIntervalSeconds` | string | No | - | Reload policies every N seconds (0 or omit = disabled) |
|
| `refreshIntervalSeconds` | string | No | - | Reload policies every N seconds (0 or omit = disabled) |
|
||||||
| *any other key* | string | No | - | Forwarded to Rego as `data.config.<key>` |
|
| *any other key* | string | No | - | Forwarded to Rego as `data.config.<key>` |
|
||||||
|
|
||||||
@@ -47,7 +50,7 @@ When `refreshIntervalSeconds` is set, a background goroutine periodically re-fet
|
|||||||
|
|
||||||
- **Atomic swap**: the old evaluator stays fully active until the new one is compiled — no gap in enforcement
|
- **Atomic swap**: the old evaluator stays fully active until the new one is compiled — no gap in enforcement
|
||||||
- **Non-fatal errors**: if the reload fails (e.g., file temporarily unreachable or parse error), the error is logged and the previous policy stays active
|
- **Non-fatal errors**: if the reload fails (e.g., file temporarily unreachable or parse error), the error is logged and the previous policy stays active
|
||||||
- **Goroutine lifecycle**: the reload loop is tied to the adapter context and stops cleanly on shutdown
|
- **Goroutine lifecycle**: the reload loop stops when the adapter context is cancelled or when plugin `Close()` is invoked during shutdown
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
config:
|
config:
|
||||||
@@ -67,7 +70,7 @@ config:
|
|||||||
|
|
||||||
### Request Evaluation (Runtime)
|
### Request Evaluation (Runtime)
|
||||||
|
|
||||||
1. **Check Action Match**: If `actions` is configured, skip evaluation for non-matching actions
|
1. **Check Action Match**: If `actions` is configured, skip evaluation for non-matching actions. The plugin assumes standard adapter routes look like `/{participant}/{direction}/{action}` such as `/bpp/caller/confirm`; non-standard paths fall back to `context.action` from the JSON body.
|
||||||
2. **Evaluate OPA Query**: Run the prepared query with the full beckn message as `input`
|
2. **Evaluate OPA Query**: Run the prepared query with the full beckn message as `input`
|
||||||
3. **Handle Result**:
|
3. **Handle Result**:
|
||||||
- If the query returns no result (undefined) → **violation** (fail-closed)
|
- If the query returns no result (undefined) → **violation** (fail-closed)
|
||||||
@@ -109,6 +112,7 @@ checkPolicy:
|
|||||||
type: url
|
type: url
|
||||||
location: https://policies.example.com/compliance.rego
|
location: https://policies.example.com/compliance.rego
|
||||||
query: "data.policy.result"
|
query: "data.policy.result"
|
||||||
|
fetchTimeoutSeconds: "10"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Local Directory (multiple `.rego` files)
|
### Local Directory (multiple `.rego` files)
|
||||||
@@ -192,3 +196,4 @@ Configure them side-by-side in your adapter steps as needed.
|
|||||||
|
|
||||||
- **No bundle signature verification**: When using `type: bundle`, bundle signature verification is skipped. This is planned for a future enhancement.
|
- **No bundle signature verification**: When using `type: bundle`, bundle signature verification is skipped. This is planned for a future enhancement.
|
||||||
- **Network-level scoping**: Policies apply to all messages handled by the adapter instance. Per-network policy mapping (by `networkId`) is tracked for follow-up.
|
- **Network-level scoping**: Policies apply to all messages handled by the adapter instance. Per-network policy mapping (by `networkId`) is tracked for follow-up.
|
||||||
|
- **Non-standard route shapes**: URL-based action extraction assumes the standard Beckn adapter route shape `/{participant}/{direction}/{action}` and falls back to `context.action` for other path layouts.
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func BenchmarkEvaluate_MostlyInactive(b *testing.B) {
|
|||||||
dir := b.TempDir()
|
dir := b.TempDir()
|
||||||
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("NewEvaluator failed: %v", err)
|
b.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -135,7 +135,7 @@ func BenchmarkEvaluate_AllActive(b *testing.B) {
|
|||||||
dir := b.TempDir()
|
dir := b.TempDir()
|
||||||
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644)
|
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("NewEvaluator failed: %v", err)
|
b.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -172,7 +172,7 @@ func BenchmarkCompilation(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
_, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("NewEvaluator failed: %v", err)
|
b.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -206,7 +206,7 @@ func TestBenchmarkReport(t *testing.T) {
|
|||||||
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
_, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
_, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
elapsed := time.Since(start)
|
elapsed := time.Since(start)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
||||||
@@ -227,7 +227,7 @@ func TestBenchmarkReport(t *testing.T) {
|
|||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateDummyRules(n)), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
||||||
}
|
}
|
||||||
@@ -263,7 +263,7 @@ func TestBenchmarkReport(t *testing.T) {
|
|||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644)
|
os.WriteFile(filepath.Join(dir, "policy.rego"), []byte(generateActiveRules(n)), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
t.Fatalf("NewEvaluator(%d rules) failed: %v", n, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProviderNewSuccess(t *testing.T) {
|
||||||
|
provider := provider{}
|
||||||
|
config := map[string]string{
|
||||||
|
"type": "file",
|
||||||
|
"location": filepath.Join("..", "testdata", "example.rego"),
|
||||||
|
"query": "data.policy.result",
|
||||||
|
}
|
||||||
|
|
||||||
|
checker, closer, err := provider.New(context.Background(), config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if checker == nil {
|
||||||
|
t.Fatal("New() returned nil checker")
|
||||||
|
}
|
||||||
|
if closer == nil {
|
||||||
|
t.Fatal("New() returned nil closer")
|
||||||
|
}
|
||||||
|
|
||||||
|
closer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProviderNewFailure(t *testing.T) {
|
||||||
|
provider := provider{}
|
||||||
|
|
||||||
|
_, _, err := provider.New(context.Background(), map[string]string{
|
||||||
|
"type": "file",
|
||||||
|
"query": "data.policy.result",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when required config is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ type Config struct {
|
|||||||
Actions []string
|
Actions []string
|
||||||
Enabled bool
|
Enabled bool
|
||||||
DebugLogging bool
|
DebugLogging bool
|
||||||
|
FetchTimeout time.Duration
|
||||||
IsBundle bool
|
IsBundle bool
|
||||||
RefreshInterval time.Duration // 0 = disabled
|
RefreshInterval time.Duration // 0 = disabled
|
||||||
RuntimeConfig map[string]string
|
RuntimeConfig map[string]string
|
||||||
@@ -34,12 +35,14 @@ var knownKeys = map[string]bool{
|
|||||||
"actions": true,
|
"actions": true,
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"debugLogging": true,
|
"debugLogging": true,
|
||||||
|
"fetchTimeoutSeconds": true,
|
||||||
"refreshIntervalSeconds": true,
|
"refreshIntervalSeconds": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultConfig() *Config {
|
func DefaultConfig() *Config {
|
||||||
return &Config{
|
return &Config{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
FetchTimeout: defaultPolicyFetchTimeout,
|
||||||
RuntimeConfig: make(map[string]string),
|
RuntimeConfig: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -105,6 +108,14 @@ func ParseConfig(cfg map[string]string) (*Config, error) {
|
|||||||
config.DebugLogging = debug == "true" || debug == "1"
|
config.DebugLogging = debug == "true" || debug == "1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if fts, ok := cfg["fetchTimeoutSeconds"]; ok && fts != "" {
|
||||||
|
secs, err := strconv.Atoi(fts)
|
||||||
|
if err != nil || secs <= 0 {
|
||||||
|
return nil, fmt.Errorf("'fetchTimeoutSeconds' must be a positive integer, got %q", fts)
|
||||||
|
}
|
||||||
|
config.FetchTimeout = time.Duration(secs) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
if ris, ok := cfg["refreshIntervalSeconds"]; ok && ris != "" {
|
if ris, ok := cfg["refreshIntervalSeconds"]; ok && ris != "" {
|
||||||
secs, err := strconv.Atoi(ris)
|
secs, err := strconv.Atoi(ris)
|
||||||
if err != nil || secs < 0 {
|
if err != nil || secs < 0 {
|
||||||
@@ -139,6 +150,8 @@ type PolicyEnforcer struct {
|
|||||||
config *Config
|
config *Config
|
||||||
evaluator *Evaluator
|
evaluator *Evaluator
|
||||||
evaluatorMu sync.RWMutex
|
evaluatorMu sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getEvaluator safely returns the current evaluator under a read lock.
|
// getEvaluator safely returns the current evaluator under a read lock.
|
||||||
@@ -162,18 +175,24 @@ func New(ctx context.Context, cfg map[string]string) (*PolicyEnforcer, error) {
|
|||||||
return nil, fmt.Errorf("opapolicychecker: config error: %w", err)
|
return nil, fmt.Errorf("opapolicychecker: config error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
evaluator, err := NewEvaluator(config.PolicyPaths, config.Query, config.RuntimeConfig, config.IsBundle)
|
enforcer := &PolicyEnforcer{
|
||||||
|
config: config,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !config.Enabled {
|
||||||
|
log.Warnf(ctx, "OPAPolicyChecker is disabled via config; policy enforcement will be skipped")
|
||||||
|
return enforcer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
evaluator, err := NewEvaluator(config.PolicyPaths, config.Query, config.RuntimeConfig, config.IsBundle, config.FetchTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("opapolicychecker: failed to initialize OPA evaluator: %w", err)
|
return nil, fmt.Errorf("opapolicychecker: failed to initialize OPA evaluator: %w", err)
|
||||||
}
|
}
|
||||||
|
enforcer.evaluator = evaluator
|
||||||
|
|
||||||
log.Infof(ctx, "OPAPolicyChecker initialized (actions=%v, query=%s, policies=%v, isBundle=%v, debugLogging=%v, refreshInterval=%s)",
|
log.Infof(ctx, "OPAPolicyChecker initialized (actions=%v, query=%s, policies=%v, isBundle=%v, debugLogging=%v, fetchTimeout=%s, refreshInterval=%s)",
|
||||||
config.Actions, config.Query, evaluator.ModuleNames(), config.IsBundle, config.DebugLogging, config.RefreshInterval)
|
config.Actions, config.Query, evaluator.ModuleNames(), config.IsBundle, config.DebugLogging, config.FetchTimeout, config.RefreshInterval)
|
||||||
|
|
||||||
enforcer := &PolicyEnforcer{
|
|
||||||
config: config,
|
|
||||||
evaluator: evaluator,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.RefreshInterval > 0 {
|
if config.RefreshInterval > 0 {
|
||||||
go enforcer.refreshLoop(ctx)
|
go enforcer.refreshLoop(ctx)
|
||||||
@@ -193,6 +212,9 @@ func (e *PolicyEnforcer) refreshLoop(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped")
|
log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped")
|
||||||
return
|
return
|
||||||
|
case <-e.done:
|
||||||
|
log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped by Close()")
|
||||||
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
e.reloadPolicies(ctx)
|
e.reloadPolicies(ctx)
|
||||||
}
|
}
|
||||||
@@ -208,6 +230,7 @@ func (e *PolicyEnforcer) reloadPolicies(ctx context.Context) {
|
|||||||
e.config.Query,
|
e.config.Query,
|
||||||
e.config.RuntimeConfig,
|
e.config.RuntimeConfig,
|
||||||
e.config.IsBundle,
|
e.config.IsBundle,
|
||||||
|
e.config.FetchTimeout,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(ctx, err, "OPAPolicyChecker: policy reload failed (keeping previous policies): %v", err)
|
log.Errorf(ctx, err, "OPAPolicyChecker: policy reload failed (keeping previous policies): %v", err)
|
||||||
@@ -237,6 +260,9 @@ func (e *PolicyEnforcer) CheckPolicy(ctx *model.StepContext) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ev := e.getEvaluator()
|
ev := e.getEvaluator()
|
||||||
|
if ev == nil {
|
||||||
|
return model.NewBadReqErr(fmt.Errorf("policy evaluator is not initialized"))
|
||||||
|
}
|
||||||
|
|
||||||
if e.config.DebugLogging {
|
if e.config.DebugLogging {
|
||||||
log.Debugf(ctx, "OPAPolicyChecker: evaluating policies for action %q (modules=%v)", action, ev.ModuleNames())
|
log.Debugf(ctx, "OPAPolicyChecker: evaluating policies for action %q (modules=%v)", action, ev.ModuleNames())
|
||||||
@@ -260,12 +286,17 @@ func (e *PolicyEnforcer) CheckPolicy(ctx *model.StepContext) error {
|
|||||||
return model.NewBadReqErr(fmt.Errorf("%s", msg))
|
return model.NewBadReqErr(fmt.Errorf("%s", msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *PolicyEnforcer) Close() {}
|
func (e *PolicyEnforcer) Close() {
|
||||||
|
e.closeOnce.Do(func() {
|
||||||
|
close(e.done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func extractAction(urlPath string, body []byte) string {
|
func extractAction(urlPath string, body []byte) string {
|
||||||
parts := strings.Split(strings.Trim(urlPath, "/"), "/")
|
// /bpp/caller/confirm/extra as action "extra".
|
||||||
if len(parts) >= 3 {
|
parts := strings.FieldsFunc(strings.Trim(urlPath, "/"), func(r rune) bool { return r == '/' })
|
||||||
return parts[len(parts)-1]
|
if len(parts) == 3 && isBecknDirection(parts[1]) && parts[2] != "" {
|
||||||
|
return parts[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
var payload struct {
|
var payload struct {
|
||||||
@@ -279,3 +310,12 @@ func extractAction(urlPath string, body []byte) string {
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isBecknDirection(part string) bool {
|
||||||
|
switch part {
|
||||||
|
case "caller", "receiver", "reciever":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ violations contains msg if {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "test.rego", policy)
|
dir := writePolicyDir(t, "test.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ violations contains msg if {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "test.rego", policy)
|
dir := writePolicyDir(t, "test.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -192,7 +192,7 @@ violations contains msg if {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "test.rego", policy)
|
dir := writePolicyDir(t, "test.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", map[string]string{"maxValue": "100"}, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", map[string]string{"maxValue": "100"}, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -235,7 +235,7 @@ test_something if { count(policy.violations) > 0 }
|
|||||||
`
|
`
|
||||||
os.WriteFile(filepath.Join(dir, "policy_test.rego"), []byte(testFile), 0644)
|
os.WriteFile(filepath.Join(dir, "policy_test.rego"), []byte(testFile), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator should skip _test.rego files, but failed: %v", err)
|
t.Fatalf("NewEvaluator should skip _test.rego files, but failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -256,7 +256,7 @@ import rego.v1
|
|||||||
violations := set()
|
violations := set()
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "test.rego", policy)
|
dir := writePolicyDir(t, "test.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -285,7 +285,7 @@ violations contains msg if {
|
|||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{srv.URL + "/test_policy.rego"}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{srv.URL + "/test_policy.rego"}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator with URL failed: %v", err)
|
t.Fatalf("NewEvaluator with URL failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -313,14 +313,14 @@ func TestEvaluator_FetchURL_NotFound(t *testing.T) {
|
|||||||
srv := httptest.NewServer(http.NotFoundHandler())
|
srv := httptest.NewServer(http.NotFoundHandler())
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
_, err := NewEvaluator([]string{srv.URL + "/missing.rego"}, "data.policy.violations", nil, false)
|
_, err := NewEvaluator([]string{srv.URL + "/missing.rego"}, "data.policy.violations", nil, false, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for 404 URL")
|
t.Fatal("expected error for 404 URL")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEvaluator_FetchURL_InvalidScheme(t *testing.T) {
|
func TestEvaluator_FetchURL_InvalidScheme(t *testing.T) {
|
||||||
_, err := NewEvaluator([]string{"ftp://example.com/policy.rego"}, "data.policy.violations", nil, false)
|
_, err := NewEvaluator([]string{"ftp://example.com/policy.rego"}, "data.policy.violations", nil, false, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for ftp:// scheme")
|
t.Fatal("expected error for ftp:// scheme")
|
||||||
}
|
}
|
||||||
@@ -346,7 +346,7 @@ violations contains "remote_violation" if { input.remote_bad }
|
|||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir, srv.URL + "/remote.rego"}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir, srv.URL + "/remote.rego"}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -373,7 +373,7 @@ violations contains "from_file" if { input.bad }
|
|||||||
policyPath := filepath.Join(dir, "local_policy.rego")
|
policyPath := filepath.Join(dir, "local_policy.rego")
|
||||||
os.WriteFile(policyPath, []byte(policy), 0644)
|
os.WriteFile(policyPath, []byte(policy), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{policyPath}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator with local path failed: %v", err)
|
t.Fatalf("NewEvaluator with local path failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -412,7 +412,7 @@ violations contains "order too large" if { is_high_value }
|
|||||||
`
|
`
|
||||||
os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644)
|
os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -462,7 +462,7 @@ violations contains "high value confirm blocked" if {
|
|||||||
`
|
`
|
||||||
os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644)
|
os.WriteFile(filepath.Join(dir, "rules.rego"), []byte(rules), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -531,7 +531,7 @@ violations contains "safety: order value too high" if {
|
|||||||
`
|
`
|
||||||
os.WriteFile(filepath.Join(dir, "safety.rego"), []byte(safety), 0644)
|
os.WriteFile(filepath.Join(dir, "safety.rego"), []byte(safety), 0644)
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.violations", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -744,7 +744,7 @@ default result := {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "policy.rego", policy)
|
dir := writePolicyDir(t, "policy.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -782,7 +782,7 @@ violations contains msg if {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "policy.rego", policy)
|
dir := writePolicyDir(t, "policy.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.retail.policy.result", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -841,7 +841,7 @@ result := {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
dir := writePolicyDir(t, "policy.rego", policy)
|
dir := writePolicyDir(t, "policy.rego", policy)
|
||||||
eval, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false)
|
eval, err := NewEvaluator([]string{dir}, "data.policy.result", nil, false, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator failed: %v", err)
|
t.Fatalf("NewEvaluator failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -912,7 +912,7 @@ violations contains msg if {
|
|||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
|
||||||
eval, err := NewEvaluator([]string{srv.URL + "/bundle.tar.gz"}, "data.retail.validation.result", nil, true)
|
eval, err := NewEvaluator([]string{srv.URL + "/bundle.tar.gz"}, "data.retail.validation.result", nil, true, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewEvaluator with bundle failed: %v", err)
|
t.Fatalf("NewEvaluator with bundle failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1096,3 +1096,55 @@ default result := {"valid": true, "violations": []}
|
|||||||
t.Fatal("hot-reload did not take effect within 5 seconds")
|
t.Fatal("hot-reload did not take effect within 5 seconds")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseConfig_FetchTimeout(t *testing.T) {
|
||||||
|
cfg, err := ParseConfig(map[string]string{
|
||||||
|
"type": "url",
|
||||||
|
"location": "https://example.com/policy.rego",
|
||||||
|
"query": "data.policy.violations",
|
||||||
|
"fetchTimeoutSeconds": "7",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.FetchTimeout != 7*time.Second {
|
||||||
|
t.Fatalf("expected fetch timeout 7s, got %s", cfg.FetchTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEvaluator_FetchURL_Timeout(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
w.Write([]byte(`package policy
|
||||||
|
import rego.v1
|
||||||
|
violations := []`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
_, err := NewEvaluator([]string{srv.URL + "/slow.rego"}, "data.policy.violations", nil, false, 10*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout error for slow policy URL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractAction_NonStandardURLFallsBackToBody(t *testing.T) {
|
||||||
|
body := []byte(`{"context": {"action": "confirm"}}`)
|
||||||
|
action := extractAction("/bpp/caller/confirm/extra", body)
|
||||||
|
if action != "confirm" {
|
||||||
|
t.Fatalf("expected body fallback action 'confirm', got %q", action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforcer_DisabledSkipsEvaluatorInitialization(t *testing.T) {
|
||||||
|
enforcer, err := New(context.Background(), map[string]string{
|
||||||
|
"type": "url",
|
||||||
|
"location": "https://127.0.0.1:1/unreachable.rego",
|
||||||
|
"query": "data.policy.violations",
|
||||||
|
"enabled": "false",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected disabled enforcer to skip evaluator initialization, got %v", err)
|
||||||
|
}
|
||||||
|
if enforcer.getEvaluator() != nil {
|
||||||
|
t.Fatal("expected disabled enforcer to leave evaluator uninitialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,8 +35,9 @@ func (e *Evaluator) ModuleNames() []string {
|
|||||||
return e.moduleNames
|
return e.moduleNames
|
||||||
}
|
}
|
||||||
|
|
||||||
// policyFetchTimeout is the HTTP timeout for fetching remote .rego files.
|
// defaultPolicyFetchTimeout bounds remote policy and bundle fetches during startup
|
||||||
const policyFetchTimeout = 30 * time.Second
|
// and refresh. This can be overridden via config.fetchTimeoutSeconds.
|
||||||
|
const defaultPolicyFetchTimeout = 30 * time.Second
|
||||||
|
|
||||||
// maxPolicySize is the maximum size of a single .rego file fetched from a URL (1 MB).
|
// maxPolicySize is the maximum size of a single .rego file fetched from a URL (1 MB).
|
||||||
const maxPolicySize = 1 << 20
|
const maxPolicySize = 1 << 20
|
||||||
@@ -47,21 +48,24 @@ const maxBundleSize = 10 << 20
|
|||||||
// NewEvaluator creates an Evaluator by loading .rego files from local paths
|
// NewEvaluator creates an Evaluator by loading .rego files from local paths
|
||||||
// and/or URLs, then compiling them. runtimeConfig is passed to Rego as data.config.
|
// and/or URLs, then compiling them. runtimeConfig is passed to Rego as data.config.
|
||||||
// When isBundle is true, the first policyPath is treated as a URL to an OPA bundle (.tar.gz).
|
// When isBundle is true, the first policyPath is treated as a URL to an OPA bundle (.tar.gz).
|
||||||
func NewEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, isBundle bool) (*Evaluator, error) {
|
func NewEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, isBundle bool, fetchTimeout time.Duration) (*Evaluator, error) {
|
||||||
if isBundle {
|
if fetchTimeout <= 0 {
|
||||||
return newBundleEvaluator(policyPaths, query, runtimeConfig)
|
fetchTimeout = defaultPolicyFetchTimeout
|
||||||
}
|
}
|
||||||
return newRegoEvaluator(policyPaths, query, runtimeConfig)
|
if isBundle {
|
||||||
|
return newBundleEvaluator(policyPaths, query, runtimeConfig, fetchTimeout)
|
||||||
|
}
|
||||||
|
return newRegoEvaluator(policyPaths, query, runtimeConfig, fetchTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newRegoEvaluator loads raw .rego files from local paths and/or URLs.
|
// newRegoEvaluator loads raw .rego files from local paths and/or URLs.
|
||||||
func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[string]string) (*Evaluator, error) {
|
func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, fetchTimeout time.Duration) (*Evaluator, error) {
|
||||||
modules := make(map[string]string)
|
modules := make(map[string]string)
|
||||||
|
|
||||||
// Load from policyPaths (resolved locations based on config Type)
|
// Load from policyPaths (resolved locations based on config Type)
|
||||||
for _, source := range policyPaths {
|
for _, source := range policyPaths {
|
||||||
if isURL(source) {
|
if isURL(source) {
|
||||||
name, content, err := fetchPolicy(source)
|
name, content, err := fetchPolicy(source, fetchTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to fetch policy from %s: %w", source, err)
|
return nil, fmt.Errorf("failed to fetch policy from %s: %w", source, err)
|
||||||
}
|
}
|
||||||
@@ -101,13 +105,13 @@ func newRegoEvaluator(policyPaths []string, query string, runtimeConfig map[stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newBundleEvaluator loads an OPA bundle (.tar.gz) from a URL and compiles it.
|
// newBundleEvaluator loads an OPA bundle (.tar.gz) from a URL and compiles it.
|
||||||
func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[string]string) (*Evaluator, error) {
|
func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[string]string, fetchTimeout time.Duration) (*Evaluator, error) {
|
||||||
if len(policyPaths) == 0 {
|
if len(policyPaths) == 0 {
|
||||||
return nil, fmt.Errorf("bundle source URL is required")
|
return nil, fmt.Errorf("bundle source URL is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
bundleURL := policyPaths[0]
|
bundleURL := policyPaths[0]
|
||||||
modules, bundleData, err := loadBundle(bundleURL)
|
modules, bundleData, err := loadBundle(bundleURL, fetchTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load bundle from %s: %w", bundleURL, err)
|
return nil, fmt.Errorf("failed to load bundle from %s: %w", bundleURL, err)
|
||||||
}
|
}
|
||||||
@@ -121,8 +125,8 @@ func newBundleEvaluator(policyPaths []string, query string, runtimeConfig map[st
|
|||||||
|
|
||||||
// loadBundle downloads a .tar.gz OPA bundle from a URL, parses it using OPA's
|
// loadBundle downloads a .tar.gz OPA bundle from a URL, parses it using OPA's
|
||||||
// bundle reader, and returns the modules and data from the bundle.
|
// bundle reader, and returns the modules and data from the bundle.
|
||||||
func loadBundle(bundleURL string) (map[string]string, map[string]interface{}, error) {
|
func loadBundle(bundleURL string, fetchTimeout time.Duration) (map[string]string, map[string]interface{}, error) {
|
||||||
data, err := fetchBundleArchive(bundleURL)
|
data, err := fetchBundleArchive(bundleURL, fetchTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -131,7 +135,7 @@ func loadBundle(bundleURL string) (map[string]string, map[string]interface{}, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fetchBundleArchive downloads a bundle .tar.gz from a URL.
|
// fetchBundleArchive downloads a bundle .tar.gz from a URL.
|
||||||
func fetchBundleArchive(rawURL string) ([]byte, error) {
|
func fetchBundleArchive(rawURL string, fetchTimeout time.Duration) ([]byte, error) {
|
||||||
parsed, err := url.Parse(rawURL)
|
parsed, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||||
@@ -141,7 +145,7 @@ func fetchBundleArchive(rawURL string) ([]byte, error) {
|
|||||||
return nil, fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme)
|
return nil, fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{Timeout: policyFetchTimeout}
|
client := &http.Client{Timeout: fetchTimeout}
|
||||||
resp, err := client.Get(rawURL)
|
resp, err := client.Get(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("HTTP request failed: %w", err)
|
return nil, fmt.Errorf("HTTP request failed: %w", err)
|
||||||
@@ -229,7 +233,7 @@ func isURL(source string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// fetchPolicy downloads a .rego file from a URL and returns (filename, content, error).
|
// fetchPolicy downloads a .rego file from a URL and returns (filename, content, error).
|
||||||
func fetchPolicy(rawURL string) (string, string, error) {
|
func fetchPolicy(rawURL string, fetchTimeout time.Duration) (string, string, error) {
|
||||||
parsed, err := url.Parse(rawURL)
|
parsed, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("invalid URL: %w", err)
|
return "", "", fmt.Errorf("invalid URL: %w", err)
|
||||||
@@ -239,7 +243,7 @@ func fetchPolicy(rawURL string) (string, string, error) {
|
|||||||
return "", "", fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme)
|
return "", "", fmt.Errorf("unsupported URL scheme %q (only http and https are supported)", parsed.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{Timeout: policyFetchTimeout}
|
client := &http.Client{Timeout: fetchTimeout}
|
||||||
resp, err := client.Get(rawURL)
|
resp, err := client.Get(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("HTTP request failed: %w", err)
|
return "", "", fmt.Errorf("HTTP request failed: %w", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user