Issue580-feat: add support for referenced schema validation and caching

This commit is contained in:
ameersohel45
2025-12-12 00:32:35 +05:30
parent 33cd3dc31f
commit 94943e63e6
2 changed files with 472 additions and 10 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"strconv"
"strings"
"github.com/beckn-one/beckn-onix/pkg/plugin/definition"
"github.com/beckn-one/beckn-onix/pkg/plugin/implementation/schemav2validator"
@@ -40,6 +41,36 @@ func (vp schemav2ValidatorProvider) New(ctx context.Context, config map[string]s
}
}
// NEW: Parse enableReferencedSchemas
if enableStr, ok := config["enableReferencedSchemas"]; ok {
cfg.EnableReferencedSchemas = enableStr == "true"
}
// NEW: Parse referencedSchemaConfig (if enabled)
if cfg.EnableReferencedSchemas {
if v, ok := config["referencedSchemaConfig.cacheTTL"]; ok {
if ttl, err := strconv.Atoi(v); err == nil && ttl > 0 {
cfg.ReferencedSchemaConfig.CacheTTL = ttl
}
}
if v, ok := config["referencedSchemaConfig.maxCacheSize"]; ok {
if size, err := strconv.Atoi(v); err == nil && size > 0 {
cfg.ReferencedSchemaConfig.MaxCacheSize = size
}
}
if v, ok := config["referencedSchemaConfig.downloadTimeout"]; ok {
if timeout, err := strconv.Atoi(v); err == nil && timeout > 0 {
cfg.ReferencedSchemaConfig.DownloadTimeout = timeout
}
}
if v, ok := config["referencedSchemaConfig.allowedDomains"]; ok && v != "" {
cfg.ReferencedSchemaConfig.AllowedDomains = strings.Split(v, ",")
}
if v, ok := config["referencedSchemaConfig.urlTransform"]; ok && v != "" {
cfg.ReferencedSchemaConfig.URLTransform = v
}
}
return schemav2validator.New(ctx, cfg)
}

View File

@@ -2,6 +2,8 @@ package schemav2validator
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
@@ -24,9 +26,10 @@ type payload struct {
// schemav2Validator implements the SchemaValidator interface.
type schemav2Validator struct {
config *Config
spec *cachedSpec
specMutex sync.RWMutex
config *Config
spec *cachedSpec
specMutex sync.RWMutex
schemaCache *schemaCache // NEW: cache for referenced schemas
}
// cachedSpec holds a cached OpenAPI spec.
@@ -41,6 +44,43 @@ type Config struct {
Type string // "url", "file", or "dir"
Location string // URL, file path, or directory path
CacheTTL int
// NEW: Referenced schema configuration
EnableReferencedSchemas bool
ReferencedSchemaConfig ReferencedSchemaConfig
}
// ReferencedSchemaConfig groups the tunables for LEVEL 2 validation, i.e.
// validation of payload objects against the schema their @context points to.
// A zero value in any numeric field means "use the built-in default".
type ReferencedSchemaConfig struct {
	// CacheTTL is the lifetime of a cached referenced schema, in seconds.
	// Default: 86400 (24h).
	CacheTTL int
	// MaxCacheSize is the maximum number of referenced schemas kept in the
	// cache. Default: 100.
	MaxCacheSize int
	// DownloadTimeout bounds the loading of a remote schema, in seconds.
	// Default: 30.
	DownloadTimeout int
	// AllowedDomains is the whitelist of domains a @context may point to.
	// An empty list allows every domain.
	AllowedDomains []string
	// URLTransform is a "from->to" rewrite rule applied to the @context URL
	// to obtain the schema location, e.g. "context.jsonld->attributes.yaml".
	URLTransform string
}
// referencedObject describes one JSON object in the request that carries both
// a string "@context" and a string "@type" and is therefore a candidate for
// LEVEL 2 validation.
type referencedObject struct {
	// Path is the dotted/indexed location of the object inside the request,
	// e.g. "message.order.items[0]"; it is used to prefix error messages.
	Path string
	// Context is the raw value of the object's "@context" field.
	Context string
	// Type is the raw value of the object's "@type" field.
	Type string
	// Data is the object itself (not a copy) as decoded from JSON.
	Data map[string]interface{}
}
// schemaCache is an in-memory, size-bounded store of loaded domain schemas.
// When it is full, the least recently accessed entry is evicted.
type schemaCache struct {
	// mu guards schemas; every method that touches the map takes it.
	mu sync.RWMutex
	// schemas maps the SHA-256 hash of a schema location to its cached entry.
	schemas map[string]*cachedDomainSchema
	// maxSize is the entry count at which an insert triggers an eviction.
	maxSize int
}
// cachedDomainSchema is a single schemaCache entry: the parsed schema document
// plus the bookkeeping used for expiry and eviction.
type cachedDomainSchema struct {
	// doc is the parsed schema document.
	doc *openapi3.T
	// loadedAt is the time the entry was stored.
	loadedAt time.Time
	// expiresAt is the instant after which the entry is treated as a miss.
	expiresAt time.Time
	// lastAccessed is refreshed on every cache hit and drives eviction.
	lastAccessed time.Time
	// accessCount counts stores and hits; in the code shown here it is only
	// written, never read.
	accessCount int64
}
// New creates a new Schemav2Validator instance.
@@ -66,6 +106,16 @@ func New(ctx context.Context, config *Config) (*schemav2Validator, func() error,
config: config,
}
// NEW: Initialize referenced schema cache if enabled
if config.EnableReferencedSchemas {
maxSize := 100
if config.ReferencedSchemaConfig.MaxCacheSize > 0 {
maxSize = config.ReferencedSchemaConfig.MaxCacheSize
}
v.schemaCache = newSchemaCache(maxSize)
log.Infof(ctx, "Initialized referenced schema cache with max size: %d", maxSize)
}
if err := v.initialise(ctx); err != nil {
return nil, nil, fmt.Errorf("failed to initialise schemav2Validator: %v", err)
}
@@ -119,6 +169,19 @@ func (v *schemav2Validator) Validate(ctx context.Context, reqURL *url.URL, data
return v.formatValidationError(err)
}
log.Debugf(ctx, "LEVEL 1 validation passed for action: %s", action)
// NEW: LEVEL 2 - Referenced schema validation (if enabled)
if v.config.EnableReferencedSchemas && v.schemaCache != nil {
log.Debugf(ctx, "Starting LEVEL 2 validation for action: %s", action)
if err := v.validateReferencedSchemas(ctx, jsonData); err != nil {
// Level 2 failure - return error (same behavior as Level 1)
log.Debugf(ctx, "LEVEL 2 validation failed for action %s: %v", action, err)
return v.formatValidationError(err)
}
log.Debugf(ctx, "LEVEL 2 validation passed for action: %s", action)
}
return nil
}
@@ -181,15 +244,42 @@ func (v *schemav2Validator) loadSpec(ctx context.Context) error {
// refreshLoop periodically reloads expired specs based on TTL.
func (v *schemav2Validator) refreshLoop(ctx context.Context) {
ticker := time.NewTicker(time.Duration(v.config.CacheTTL) * time.Second)
defer ticker.Stop()
coreTicker := time.NewTicker(time.Duration(v.config.CacheTTL) * time.Second)
defer coreTicker.Stop()
// NEW: Ticker for referenced schema cleanup
var refTicker *time.Ticker
if v.config.EnableReferencedSchemas {
ttl := v.config.ReferencedSchemaConfig.CacheTTL
if ttl <= 0 {
ttl = 86400 // Default 24 hours
}
refTicker = time.NewTicker(time.Duration(ttl) * time.Second)
defer refTicker.Stop()
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
v.reloadExpiredSpec(ctx)
if refTicker != nil {
select {
case <-ctx.Done():
return
case <-coreTicker.C:
v.reloadExpiredSpec(ctx)
case <-refTicker.C:
if v.schemaCache != nil {
count := v.schemaCache.cleanupExpired()
if count > 0 {
log.Debugf(ctx, "Cleaned up %d expired referenced schemas", count)
}
}
}
} else {
select {
case <-ctx.Done():
return
case <-coreTicker.C:
v.reloadExpiredSpec(ctx)
}
}
}
}
@@ -309,6 +399,68 @@ func (v *schemav2Validator) buildActionIndex(ctx context.Context, doc *openapi3.
return actionSchemas
}
// validateReferencedSchemas performs LEVEL 2 validation: every object under
// the request's "message" field that declares a string @context and @type is
// validated against the schema derived from that @context.
//
// body must be the decoded JSON request (a map). All objects are checked even
// after a failure; the individual failures are combined into one error. A nil
// return means every referenced object passed, or there were none.
func (v *schemav2Validator) validateReferencedSchemas(ctx context.Context, body interface{}) error {
	// Only the "message" subtree is scanned; the root object is not.
	root, isObject := body.(map[string]interface{})
	if !isObject {
		return fmt.Errorf("body is not a valid JSON object")
	}
	message, present := root["message"]
	if !present {
		return fmt.Errorf("missing 'message' field in request body")
	}

	targets := findReferencedObjects(message, "message")
	if len(targets) == 0 {
		log.Debugf(ctx, "No objects with @context found in message, skipping LEVEL 2 validation")
		return nil
	}
	log.Debugf(ctx, "Found %d objects with @context for LEVEL 2 validation", len(targets))

	// Resolve the effective settings: configured value if set, default otherwise.
	refCfg := v.config.ReferencedSchemaConfig

	urlTransform := refCfg.URLTransform
	if urlTransform == "" {
		urlTransform = "context.jsonld->attributes.yaml"
	}

	ttl := 86400 * time.Second // 24 hours
	if refCfg.CacheTTL > 0 {
		ttl = time.Duration(refCfg.CacheTTL) * time.Second
	}

	timeout := 30 * time.Second
	if refCfg.DownloadTimeout > 0 {
		timeout = time.Duration(refCfg.DownloadTimeout) * time.Second
	}

	allowedDomains := refCfg.AllowedDomains

	log.Debugf(ctx, "LEVEL 2 config: urlTransform=%s, ttl=%v, timeout=%v, allowedDomains=%v",
		urlTransform, ttl, timeout, allowedDomains)

	// Validate every object; keep going after a failure so the caller sees
	// all problems at once.
	var failures []string
	for _, target := range targets {
		log.Debugf(ctx, "Validating object at path: %s, @context: %s, @type: %s",
			target.Path, target.Context, target.Type)

		err := v.schemaCache.validateReferencedObject(ctx, target, urlTransform, ttl, timeout, allowedDomains)
		if err == nil {
			continue
		}
		failures = append(failures, err.Error())
	}

	if len(failures) == 0 {
		return nil
	}
	return fmt.Errorf("validation errors:\n - %s", strings.Join(failures, "\n - "))
}
// extractActionFromSchema extracts the action value from a schema.
func (v *schemav2Validator) extractActionFromSchema(schema *openapi3.Schema) string {
// Check direct properties
@@ -360,3 +512,282 @@ func (v *schemav2Validator) getActionValue(contextSchema *openapi3.Schema) strin
return ""
}
// newSchemaCache returns an empty schemaCache that holds at most maxSize
// entries before an insert triggers eviction.
func newSchemaCache(maxSize int) *schemaCache {
	cache := new(schemaCache)
	cache.maxSize = maxSize
	cache.schemas = make(map[string]*cachedDomainSchema)
	return cache
}
// hashURL returns the lowercase hex-encoded SHA-256 digest of urlStr. The
// result is used as the cache key for the schema loaded from that location.
func hashURL(urlStr string) string {
	hasher := sha256.New()
	hasher.Write([]byte(urlStr)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// isValidSchemaPath reports whether schemaPath has a form the schema loader
// supports: an http://, https:// or file:// URL, or a scheme-less local path.
//
// The empty string is rejected. (url.Parse accepts "" and yields an empty
// scheme, so without the explicit check an empty path would be reported as a
// valid local path.) A non-empty string that cannot be parsed as a URL is
// accepted and treated as a plain file path.
func isValidSchemaPath(schemaPath string) bool {
	if schemaPath == "" {
		return false
	}

	u, err := url.Parse(schemaPath)
	if err != nil {
		// Not a well-formed URL; could still be a simple file path.
		return true
	}

	switch u.Scheme {
	case "http", "https", "file", "":
		return true
	default:
		return false
	}
}
// get returns the cached schema stored under urlHash and true, or nil and
// false when there is no entry or the entry has expired.
//
// A hit refreshes the entry's lastAccessed time and increments its access
// count, which is why the write lock is taken even for a lookup. An expired
// entry is removed on the spot so it no longer occupies a cache slot until
// the next periodic cleanup.
func (c *schemaCache) get(urlHash string) (*openapi3.T, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry, ok := c.schemas[urlHash]
	if !ok {
		return nil, false
	}

	now := time.Now()
	if now.After(entry.expiresAt) {
		delete(c.schemas, urlHash)
		return nil, false
	}

	// Update access tracking.
	entry.lastAccessed = now
	entry.accessCount++
	return entry.doc, true
}
// set stores doc under urlHash with the given time-to-live.
//
// If the key is new and the cache already holds maxSize entries, one entry is
// evicted first: an expired entry if one is found, otherwise the least
// recently accessed one. Overwriting an existing key never evicts, because
// the entry count does not grow (this is the common case when an expired
// schema is reloaded).
func (c *schemaCache) set(urlHash string, doc *openapi3.T, ttl time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// One timestamp for the whole operation keeps the entry's fields
	// mutually consistent.
	now := time.Now()

	if _, replacing := c.schemas[urlHash]; !replacing && len(c.schemas) >= c.maxSize {
		var (
			victim       string
			victimAccess time.Time
			found        bool
		)
		for key, entry := range c.schemas {
			if now.After(entry.expiresAt) {
				// Expired entries are dead weight; evict one of those
				// rather than a live schema.
				victim, found = key, true
				break
			}
			if !found || entry.lastAccessed.Before(victimAccess) {
				victim, victimAccess, found = key, entry.lastAccessed, true
			}
		}
		if found {
			delete(c.schemas, victim)
		}
	}

	c.schemas[urlHash] = &cachedDomainSchema{
		doc:          doc,
		loadedAt:     now,
		expiresAt:    now.Add(ttl),
		lastAccessed: now,
		accessCount:  1,
	}
}
// cleanupExpired drops every entry whose expiry time has passed and returns
// how many entries were removed.
func (c *schemaCache) cleanupExpired() int {
	c.mu.Lock()
	defer c.mu.Unlock()

	cutoff := time.Now()
	removed := 0
	// Deleting from a map while ranging over it is permitted in Go.
	for key, entry := range c.schemas {
		if !cutoff.After(entry.expiresAt) {
			continue
		}
		delete(c.schemas, key)
		removed++
	}
	return removed
}
// loadSchemaFromPath returns the schema document for schemaPath, serving it
// from the cache when a live entry exists and loading it otherwise.
//
// schemaPath may be an http(s) URL, a file:// URL or a plain local path. A
// freshly loaded document is cached for ttl. For http(s) locations a context
// with the given timeout is attached to the loader; local files are loaded
// without a timeout. Problems reported by the document's own Validate call
// are only logged and do not fail the load.
//
// NOTE(review): the timeout is applied solely through loader.Context; confirm
// that the openapi3 loader's HTTP reader honours that context, otherwise the
// download is not actually bounded.
// NOTE(review): schemaPath is derived from request data and external refs are
// enabled on the loader, so this can reach arbitrary URLs and local files
// unless the caller's domain whitelist rules them out.
func (c *schemaCache) loadSchemaFromPath(ctx context.Context, schemaPath string, ttl, timeout time.Duration) (*openapi3.T, error) {
	key := hashURL(schemaPath)

	// Fast path: live cache entry.
	if cached, hit := c.get(key); hit {
		log.Debugf(ctx, "Schema cache hit for: %s", schemaPath)
		return cached, nil
	}
	log.Debugf(ctx, "Schema cache miss, loading from: %s", schemaPath)

	if !isValidSchemaPath(schemaPath) {
		return nil, fmt.Errorf("invalid schema path: %s", schemaPath)
	}

	loader := openapi3.NewLoader()
	loader.IsExternalRefsAllowed = true

	var (
		doc     *openapi3.T
		loadErr error
	)
	parsed, parseErr := url.Parse(schemaPath)
	switch {
	case parseErr == nil && (parsed.Scheme == "http" || parsed.Scheme == "https"):
		// Remote schema: attach a deadline to the loader.
		loadCtx, cancel := context.WithTimeout(ctx, timeout)
		defer cancel()
		loader.Context = loadCtx
		doc, loadErr = loader.LoadFromURI(parsed)
	case parsed != nil && parsed.Scheme == "file":
		// file:// URL: the loader wants the bare filesystem path.
		doc, loadErr = loader.LoadFromFile(parsed.Path)
	default:
		// Plain local path (or a string that did not parse as a URL).
		doc, loadErr = loader.LoadFromFile(schemaPath)
	}

	if loadErr != nil {
		log.Errorf(ctx, loadErr, "Failed to load schema from: %s", schemaPath)
		return nil, fmt.Errorf("failed to load schema from %s: %w", schemaPath, loadErr)
	}

	// Structural problems in the schema are non-blocking: log and carry on.
	if validateErr := doc.Validate(ctx); validateErr != nil {
		log.Debugf(ctx, "Schema validation warnings for %s: %v", schemaPath, validateErr)
	}

	c.set(key, doc, ttl)
	log.Debugf(ctx, "Loaded and cached schema from: %s", schemaPath)
	return doc, nil
}
// findReferencedObjects walks data (a decoded JSON value) depth-first and
// returns every object that has both a string "@context" and a string
// "@type". path is the location of data itself and is extended with ".key"
// for object members and "[i]" for array elements.
//
// A matching object is reported and its children are still visited, so nested
// referenced objects are returned as well. Because Go map iteration order is
// unspecified, the order of results from sibling keys is not stable.
func findReferencedObjects(data interface{}, path string) []referencedObject {
	switch node := data.(type) {
	case map[string]interface{}:
		var found []referencedObject

		ctxValue, ctxIsString := node["@context"].(string)
		typeValue, typeIsString := node["@type"].(string)
		if ctxIsString && typeIsString {
			found = append(found, referencedObject{
				Path:    path,
				Context: ctxValue,
				Type:    typeValue,
				Data:    node,
			})
		}

		// Descend into every member, including those of a matched object.
		for key, child := range node {
			childPath := path + "." + key
			if path == "" {
				childPath = key
			}
			found = append(found, findReferencedObjects(child, childPath)...)
		}
		return found

	case []interface{}:
		var found []referencedObject
		for idx := range node {
			elemPath := fmt.Sprintf("%s[%d]", path, idx)
			found = append(found, findReferencedObjects(node[idx], elemPath)...)
		}
		return found

	default:
		// Scalars and nil cannot contain referenced objects.
		return nil
	}
}
// transformContextToSchemaURL derives the schema location from an @context
// URL by replacing the first occurrence of one substring with another.
//
// transform has the form "from->to" (surrounding whitespace is ignored). If it
// does not split into exactly two parts, or its "from" part is empty, the
// default rule "context.jsonld" -> "attributes.yaml" is applied instead. The
// empty-"from" case matters because strings.Replace with an empty search
// string would insert the replacement at the very start of the URL.
// If "from" does not occur in contextURL, contextURL is returned unchanged.
func transformContextToSchemaURL(contextURL, transform string) string {
	from, to := "context.jsonld", "attributes.yaml"

	if parts := strings.Split(transform, "->"); len(parts) == 2 {
		if src := strings.TrimSpace(parts[0]); src != "" {
			from, to = src, strings.TrimSpace(parts[1])
		}
	}

	return strings.Replace(contextURL, from, to, 1)
}
// findSchemaByType looks up the component schema that describes the given
// @type value.
//
// Resolution order:
//  1. a component schema whose name equals typeName;
//  2. any component schema whose "x-jsonld" extension has "@type" equal to
//     typeName.
//
// Only schema refs with a non-nil Value are returned, so the caller can
// dereference the result safely. An error is returned when the document has
// no component schemas or no schema matches. If several schemas share the
// same x-jsonld @type, which one is returned is unspecified (map order).
func findSchemaByType(doc *openapi3.T, typeName string) (*openapi3.SchemaRef, error) {
	if doc == nil || doc.Components == nil || doc.Components.Schemas == nil {
		return nil, fmt.Errorf("no schemas found in document")
	}
	schemas := doc.Components.Schemas

	// Direct match by schema name. An unresolved ref (nil Value) is skipped
	// here just as it is in the fallback below.
	if ref, ok := schemas[typeName]; ok && ref != nil && ref.Value != nil {
		return ref, nil
	}

	// Fallback: match on the x-jsonld.@type extension.
	for _, ref := range schemas {
		if ref == nil || ref.Value == nil {
			continue
		}
		jsonld, ok := ref.Value.Extensions["x-jsonld"].(map[string]interface{})
		if !ok {
			continue
		}
		if declared, ok := jsonld["@type"].(string); ok && declared == typeName {
			return ref, nil
		}
	}

	return nil, fmt.Errorf("no schema found for @type: %s", typeName)
}
// isAllowedDomain reports whether the host of schemaURL is permitted by the
// allowedDomains whitelist.
//
// An empty whitelist permits everything. Otherwise the URL's host must equal
// a whitelist entry or be a subdomain of one; the comparison is
// case-insensitive. Matching is done on the parsed host rather than on the
// raw string, so a whitelisted name that merely appears in the path, query or
// as a prefix of a longer host (e.g. "allowed.org.evil.com") does not pass.
// A URL without a host (such as a local file path) is rejected whenever a
// whitelist is configured.
//
// Whitelist entries are normalised before comparison: surrounding whitespace,
// an optional "scheme://" prefix, any path and a leading dot are removed. An
// entry may include a port ("host:8080"), in which case host and port must
// both match.
func isAllowedDomain(schemaURL string, allowedDomains []string) bool {
	if len(allowedDomains) == 0 {
		return true // No whitelist = all allowed
	}

	u, err := url.Parse(schemaURL)
	if err != nil || u.Hostname() == "" {
		return false
	}
	host := strings.ToLower(u.Hostname())
	hostPort := strings.ToLower(u.Host)

	for _, entry := range allowedDomains {
		domain := normalizeAllowedDomain(entry)
		if domain == "" {
			continue
		}
		if host == domain || hostPort == domain || strings.HasSuffix(host, "."+domain) {
			return true
		}
	}
	return false
}

// normalizeAllowedDomain reduces a whitelist entry to a bare lowercase
// host[:port] by trimming whitespace, a scheme prefix, a path and a leading dot.
func normalizeAllowedDomain(entry string) string {
	domain := strings.ToLower(strings.TrimSpace(entry))
	if i := strings.Index(domain, "://"); i >= 0 {
		domain = domain[i+len("://"):]
	}
	if i := strings.IndexByte(domain, '/'); i >= 0 {
		domain = domain[:i]
	}
	return strings.TrimPrefix(domain, ".")
}
// validateReferencedObject validates one referenced object against the schema
// its @context points to.
//
// Steps: check obj.Context against allowedDomains, rewrite it into a schema
// location using urlTransform, load that schema (cached for ttl, remote loads
// limited by timeout), pick the schema matching obj.Type, and validate
// obj.Data against it. Errors from loading, lookup and validation are
// prefixed with the object's path; a whitelist rejection names the offending
// @context instead.
func (c *schemaCache) validateReferencedObject(
	ctx context.Context,
	obj referencedObject,
	urlTransform string,
	ttl, timeout time.Duration,
	allowedDomains []string,
) error {
	// Refuse contexts outside the configured whitelist before any I/O.
	if allowed := isAllowedDomain(obj.Context, allowedDomains); !allowed {
		log.Warnf(ctx, "Domain not in whitelist: %s", obj.Context)
		return fmt.Errorf("domain not allowed: %s", obj.Context)
	}

	// Map the @context to the schema location (URL or file).
	schemaPath := transformContextToSchemaURL(obj.Context, urlTransform)
	log.Debugf(ctx, "Transformed %s -> %s", obj.Context, schemaPath)

	schemaDoc, loadErr := c.loadSchemaFromPath(ctx, schemaPath, ttl, timeout)
	if loadErr != nil {
		return fmt.Errorf("at %s: %w", obj.Path, loadErr)
	}

	schemaRef, lookupErr := findSchemaByType(schemaDoc, obj.Type)
	if lookupErr != nil {
		log.Errorf(ctx, lookupErr, "Schema not found for @type: %s at path: %s", obj.Type, obj.Path)
		return fmt.Errorf("at %s: %w", obj.Path, lookupErr)
	}

	// Same validation options as LEVEL 1: request semantics plus format checks.
	visitErr := schemaRef.Value.VisitJSON(
		obj.Data,
		openapi3.VisitAsRequest(),
		openapi3.EnableFormatValidation(),
	)
	if visitErr != nil {
		log.Debugf(ctx, "Validation failed for @type: %s at path: %s: %v", obj.Type, obj.Path, visitErr)
		return fmt.Errorf("at %s: %w", obj.Path, visitErr)
	}

	log.Debugf(ctx, "Validation passed for @type: %s at path: %s", obj.Type, obj.Path)
	return nil
}