Files
onix/pkg/plugin/implementation/opapolicychecker/enforcer.go
2026-03-24 17:59:48 +05:30

322 lines
8.6 KiB
Go

package opapolicychecker
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/beckn-one/beckn-onix/pkg/log"
"github.com/beckn-one/beckn-onix/pkg/model"
)
// Config holds the configuration for the OPA Policy Checker plugin.
// It is normally produced by ParseConfig from a flat map[string]string.
type Config struct {
// Type is the raw "type" setting: "url", "file", "dir" or "bundle".
Type string
// Location is the raw "location" setting exactly as supplied.
Location string
// PolicyPaths is derived from Location: for type "url" the location is split
// on commas and each entry trimmed; for every other type it holds the
// location as a single entry.
PolicyPaths []string
// Query is the OPA query to evaluate (e.g., data.policy.violations).
Query string
// Actions is the optional list of actions to enforce; when empty,
// IsActionEnabled reports true for every action.
Actions []string
// Enabled turns enforcement on or off; DefaultConfig sets it to true.
Enabled bool
// DebugLogging enables the extra per-request debug log lines in CheckPolicy.
DebugLogging bool
// FetchTimeout is handed to NewEvaluator when policies are loaded.
FetchTimeout time.Duration
// IsBundle is set when Type is "bundle".
IsBundle bool
// RefreshInterval is the period of the background policy reload loop.
RefreshInterval time.Duration // 0 = disabled
// RuntimeConfig collects every configuration key that is not in knownKeys
// and is passed through to NewEvaluator.
RuntimeConfig map[string]string
}
// knownKeys is the set of configuration keys that ParseConfig interprets
// itself. Any key absent from this set is copied into Config.RuntimeConfig.
var knownKeys = map[string]bool{
"type": true,
"location": true,
"query": true,
"actions": true,
"enabled": true,
"debugLogging": true,
"fetchTimeoutSeconds": true,
"refreshIntervalSeconds": true,
}
// DefaultConfig returns a fresh Config with enforcement enabled, the fetch
// timeout set to defaultPolicyFetchTimeout, and an empty (non-nil)
// RuntimeConfig map. All other fields are left at their zero values.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.Enabled = true
	cfg.FetchTimeout = defaultPolicyFetchTimeout
	cfg.RuntimeConfig = map[string]string{}
	return cfg
}
// ParseConfig parses the plugin configuration map into a Config struct.
// Uses type + location pattern (matches schemav2validator).
//
// Required keys:
//   - "type": one of "url", "file", "dir" or "bundle".
//   - "location": the policy source; for type "url" it may be a
//     comma-separated list, and each entry is trimmed.
//   - "query": the OPA query to evaluate (e.g., data.policy.violations).
//
// Optional keys:
//   - "actions": comma-separated list of actions to enforce.
//   - "enabled", "debugLogging": true only for the values "true" or "1".
//   - "fetchTimeoutSeconds": positive integer.
//   - "refreshIntervalSeconds": non-negative integer (0 disables refresh).
//
// Every key that is not listed in knownKeys is copied verbatim into
// Config.RuntimeConfig.
//
// An error is returned when a required key is missing or empty, when the type
// is unsupported, when a "url" location names no usable URL, or when a numeric
// setting is malformed. On error the returned *Config is nil.
func ParseConfig(cfg map[string]string) (*Config, error) {
	config := DefaultConfig()

	// A missing key and an empty value are treated the same way.
	typ := cfg["type"]
	if typ == "" {
		return nil, fmt.Errorf("'type' is required (url, file, dir, or bundle)")
	}
	config.Type = typ

	location := cfg["location"]
	if location == "" {
		return nil, fmt.Errorf("'location' is required")
	}
	config.Location = location

	switch typ {
	case "url":
		for _, u := range strings.Split(location, ",") {
			if u = strings.TrimSpace(u); u != "" {
				config.PolicyPaths = append(config.PolicyPaths, u)
			}
		}
		// A location such as "," or " , " is non-empty yet names no URL.
		// Reject it here rather than passing an empty policy list onward.
		if len(config.PolicyPaths) == 0 {
			return nil, fmt.Errorf("'location' must contain at least one URL, got %q", location)
		}
	case "file", "dir":
		config.PolicyPaths = append(config.PolicyPaths, location)
	case "bundle":
		config.IsBundle = true
		config.PolicyPaths = append(config.PolicyPaths, location)
	default:
		return nil, fmt.Errorf("unsupported type %q (expected: url, file, dir, or bundle)", typ)
	}

	query := cfg["query"]
	if query == "" {
		return nil, fmt.Errorf("'query' is required (e.g., data.policy.violations)")
	}
	config.Query = query

	if actions := cfg["actions"]; actions != "" {
		actionList := strings.Split(actions, ",")
		config.Actions = make([]string, 0, len(actionList))
		for _, action := range actionList {
			if action = strings.TrimSpace(action); action != "" {
				config.Actions = append(config.Actions, action)
			}
		}
	}

	// A present key with any value other than "true"/"1" switches the flag off.
	if enabled, ok := cfg["enabled"]; ok {
		config.Enabled = enabled == "true" || enabled == "1"
	}
	if debug, ok := cfg["debugLogging"]; ok {
		config.DebugLogging = debug == "true" || debug == "1"
	}

	if fts := cfg["fetchTimeoutSeconds"]; fts != "" {
		secs, err := strconv.Atoi(fts)
		if err != nil || secs <= 0 {
			return nil, fmt.Errorf("'fetchTimeoutSeconds' must be a positive integer, got %q", fts)
		}
		config.FetchTimeout = time.Duration(secs) * time.Second
	}

	if ris := cfg["refreshIntervalSeconds"]; ris != "" {
		secs, err := strconv.Atoi(ris)
		if err != nil || secs < 0 {
			return nil, fmt.Errorf("'refreshIntervalSeconds' must be a non-negative integer, got %q", ris)
		}
		config.RefreshInterval = time.Duration(secs) * time.Second
	}

	// Anything the plugin does not interpret itself is forwarded untouched.
	for k, v := range cfg {
		if !knownKeys[k] {
			config.RuntimeConfig[k] = v
		}
	}

	return config, nil
}
// IsActionEnabled reports whether policy checks apply to the given action.
// With no configured actions every action is enabled; otherwise the action
// must match one of the configured entries exactly.
func (c *Config) IsActionEnabled(action string) bool {
	allowed := c.Actions
	if len(allowed) == 0 {
		return true
	}
	for i := range allowed {
		if allowed[i] == action {
			return true
		}
	}
	return false
}
// PolicyEnforcer evaluates beckn messages against OPA policies and NACKs non-compliant messages.
type PolicyEnforcer struct {
// config is the parsed plugin configuration set by New.
config *Config
// evaluator is the currently active evaluator; it is replaced by
// reloadPolicies and stays nil when the plugin is disabled.
evaluator *Evaluator
// evaluatorMu guards evaluator (see getEvaluator / setEvaluator).
evaluatorMu sync.RWMutex
// closeOnce makes Close safe to call more than once.
closeOnce sync.Once
// done is closed by Close to stop the refresh loop.
done chan struct{}
}
// getEvaluator returns the evaluator currently in use, reading it while
// holding the read lock so it cannot race with setEvaluator.
func (e *PolicyEnforcer) getEvaluator() *Evaluator {
	e.evaluatorMu.RLock()
	defer e.evaluatorMu.RUnlock()
	return e.evaluator
}
// setEvaluator replaces the active evaluator with ev while holding the
// write lock.
func (e *PolicyEnforcer) setEvaluator(ev *Evaluator) {
	e.evaluatorMu.Lock()
	defer e.evaluatorMu.Unlock()
	e.evaluator = ev
}
// New builds a PolicyEnforcer from the raw plugin configuration map.
//
// The map is parsed with ParseConfig. When the plugin is disabled, an enforcer
// without an evaluator is returned and nothing is loaded. Otherwise the
// policies are loaded through NewEvaluator, and a background refresh loop
// bound to ctx is started if a positive refresh interval is configured.
// An error is returned when the configuration is invalid or the evaluator
// cannot be created.
func New(ctx context.Context, cfg map[string]string) (*PolicyEnforcer, error) {
	parsed, err := ParseConfig(cfg)
	if err != nil {
		return nil, fmt.Errorf("opapolicychecker: config error: %w", err)
	}

	pe := &PolicyEnforcer{config: parsed, done: make(chan struct{})}
	if !parsed.Enabled {
		log.Warnf(ctx, "OPAPolicyChecker is disabled via config; policy enforcement will be skipped")
		return pe, nil
	}

	ev, err := NewEvaluator(parsed.PolicyPaths, parsed.Query, parsed.RuntimeConfig, parsed.IsBundle, parsed.FetchTimeout)
	if err != nil {
		return nil, fmt.Errorf("opapolicychecker: failed to initialize OPA evaluator: %w", err)
	}
	pe.evaluator = ev

	log.Infof(ctx, "OPAPolicyChecker initialized (actions=%v, query=%s, policies=%v, isBundle=%v, debugLogging=%v, fetchTimeout=%s, refreshInterval=%s)",
		parsed.Actions, parsed.Query, ev.ModuleNames(), parsed.IsBundle, parsed.DebugLogging, parsed.FetchTimeout, parsed.RefreshInterval)

	// Periodic reload is opt-in: an interval of zero leaves the loop off.
	if parsed.RefreshInterval > 0 {
		go pe.refreshLoop(ctx)
	}
	return pe, nil
}
// refreshLoop periodically reloads and recompiles OPA policies.
// Follows the schemav2validator pattern: driven by context cancellation.
// It fires every config.RefreshInterval and returns as soon as ctx is
// cancelled or Close() has closed the done channel.
func (e *PolicyEnforcer) refreshLoop(ctx context.Context) {
	tick := time.NewTicker(e.config.RefreshInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			e.reloadPolicies(ctx)
		case <-e.done:
			log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped by Close()")
			return
		case <-ctx.Done():
			log.Infof(ctx, "OPAPolicyChecker: refresh loop stopped")
			return
		}
	}
}
// reloadPolicies builds a new evaluator from the stored configuration and,
// on success, swaps it in through setEvaluator.
// A failed reload is only logged; the previous evaluator stays active.
func (e *PolicyEnforcer) reloadPolicies(ctx context.Context) {
	began := time.Now()
	cfg := e.config

	fresh, err := NewEvaluator(cfg.PolicyPaths, cfg.Query, cfg.RuntimeConfig, cfg.IsBundle, cfg.FetchTimeout)
	if err != nil {
		log.Errorf(ctx, err, "OPAPolicyChecker: policy reload failed (keeping previous policies): %v", err)
		return
	}

	e.setEvaluator(fresh)
	log.Infof(ctx, "OPAPolicyChecker: policies reloaded in %s (modules=%v)", time.Since(began), fresh.ModuleNames())
}
// CheckPolicy evaluates the message body against loaded OPA policies.
//
// It returns nil when the plugin is disabled, when the request's action is not
// among the configured actions, or when the evaluation reports no violations.
// It returns a BadReqErr (causing NACK) when violations are found, when no
// evaluator is available, or when evaluation itself fails (fail closed).
func (e *PolicyEnforcer) CheckPolicy(ctx *model.StepContext) error {
	cfg := e.config
	if !cfg.Enabled {
		log.Debug(ctx, "OPAPolicyChecker: plugin disabled, skipping")
		return nil
	}
	verbose := cfg.DebugLogging

	action := extractAction(ctx.Request.URL.Path, ctx.Body)
	if !cfg.IsActionEnabled(action) {
		if verbose {
			log.Debugf(ctx, "OPAPolicyChecker: action %q not in configured actions %v, skipping", action, cfg.Actions)
		}
		return nil
	}

	// Take one snapshot of the evaluator so a concurrent reload cannot swap it
	// out in the middle of this request.
	ev := e.getEvaluator()
	if ev == nil {
		return model.NewBadReqErr(fmt.Errorf("policy evaluator is not initialized"))
	}
	if verbose {
		log.Debugf(ctx, "OPAPolicyChecker: evaluating policies for action %q (modules=%v)", action, ev.ModuleNames())
	}

	violations, err := ev.Evaluate(ctx, ctx.Body)
	if err != nil {
		log.Errorf(ctx, err, "OPAPolicyChecker: policy evaluation failed: %v", err)
		return model.NewBadReqErr(fmt.Errorf("policy evaluation error: %w", err))
	}

	if len(violations) > 0 {
		msg := "policy violation(s): " + strings.Join(violations, "; ")
		log.Warnf(ctx, "OPAPolicyChecker: %s", msg)
		return model.NewBadReqErr(fmt.Errorf("%s", msg))
	}

	if verbose {
		log.Debugf(ctx, "OPAPolicyChecker: message compliant for action %q", action)
	}
	return nil
}
// Close signals the refresh loop to stop by closing the done channel.
// Repeated calls are harmless: the channel is closed only once.
func (e *PolicyEnforcer) Close() {
	signalStop := func() { close(e.done) }
	e.closeOnce.Do(signalStop)
}
// extractAction determines the beckn action of a request.
//
// The URL path is used when it has exactly three slash-separated segments and
// the middle one is a beckn direction (see isBecknDirection); the third
// segment is then the action. Requiring exactly three segments avoids reading
// a longer path such as /bpp/caller/confirm/extra as action "extra".
// Otherwise the action is read from context.action in the JSON body.
// An empty string is returned when neither source provides an action.
func extractAction(urlPath string, body []byte) string {
	isSlash := func(r rune) bool { return r == '/' }
	segments := strings.FieldsFunc(strings.Trim(urlPath, "/"), isSlash)
	if len(segments) == 3 && isBecknDirection(segments[1]) {
		if last := segments[2]; last != "" {
			return last
		}
	}

	var envelope struct {
		Context struct {
			Action string `json:"action"`
		} `json:"context"`
	}
	if err := json.Unmarshal(body, &envelope); err != nil {
		// An unparsable body yields no action.
		return ""
	}
	return envelope.Context.Action
}
// isBecknDirection reports whether part is one of the path segments accepted
// as a direction: "caller", "receiver", or the misspelled variant "reciever".
// The comparison is exact and case-sensitive.
func isBecknDirection(part string) bool {
	return part == "caller" || part == "receiver" || part == "reciever"
}