Files
onix/pkg/plugin/implementation/schemav2validator/schemav2validator_test.go
2025-11-14 13:06:31 +05:30

301 lines
7.4 KiB
Go

package schemav2validator
import (
"context"
"net/http"
"net/http/httptest"
"os"
"testing"
)
const testSpec = `openapi: 3.1.0
info:
title: Test API
version: 1.0.0
paths:
/search:
post:
requestBody:
required: true
content:
application/json:
schema:
type: object
required: [context, message]
properties:
context:
type: object
required: [action]
properties:
action:
const: search
domain:
type: string
message:
type: object
/select:
post:
requestBody:
content:
application/json:
schema:
type: object
required: [context, message]
properties:
context:
allOf:
- type: object
properties:
action:
enum: [select]
message:
type: object
required: [order]
properties:
order:
type: object
`
func TestNew(t *testing.T) {
tests := []struct {
name string
config *Config
wantErr bool
}{
{"nil config", nil, true},
{"empty type", &Config{Type: "", Location: "http://example.com"}, true},
{"empty location", &Config{Type: "url", Location: ""}, true},
{"invalid type", &Config{Type: "invalid", Location: "http://example.com"}, true},
{"invalid URL", &Config{Type: "url", Location: "http://invalid-domain-12345.com/spec.yaml"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := New(context.Background(), tt.config)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestValidate_ActionExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testSpec))
}))
defer server.Close()
validator, _, err := New(context.Background(), &Config{Type: "url", Location: server.URL, CacheTTL: 3600})
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
tests := []struct {
name string
payload string
wantErr bool
errMsg string
}{
{
name: "valid search action",
payload: `{"context":{"action":"search","domain":"retail"},"message":{}}`,
wantErr: false,
},
{
name: "valid select action with allOf",
payload: `{"context":{"action":"select"},"message":{"order":{}}}`,
wantErr: false,
},
{
name: "missing action",
payload: `{"context":{},"message":{}}`,
wantErr: true,
errMsg: "missing field Action",
},
{
name: "unsupported action",
payload: `{"context":{"action":"unknown"},"message":{}}`,
wantErr: true,
errMsg: "unsupported action: unknown",
},
{
name: "action as number",
payload: `{"context":{"action":123},"message":{}}`,
wantErr: true,
errMsg: "failed to parse JSON payload",
},
{
name: "invalid JSON",
payload: `{invalid json}`,
wantErr: true,
errMsg: "failed to parse JSON payload",
},
{
name: "missing required field",
payload: `{"context":{"action":"search"}}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.Validate(context.Background(), nil, []byte(tt.payload))
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr && tt.errMsg != "" && err != nil {
if !contains(err.Error(), tt.errMsg) {
t.Errorf("Validate() error = %v, want error containing %v", err, tt.errMsg)
}
}
})
}
}
func TestValidate_NestedValidation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testSpec))
}))
defer server.Close()
validator, _, err := New(context.Background(), &Config{Type: "url", Location: server.URL, CacheTTL: 3600})
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
tests := []struct {
name string
payload string
wantErr bool
}{
{
name: "select missing required order",
payload: `{"context":{"action":"select"},"message":{}}`,
wantErr: true,
},
{
name: "select with order",
payload: `{"context":{"action":"select"},"message":{"order":{}}}`,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.Validate(context.Background(), nil, []byte(tt.payload))
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestLoadSpec_LocalFile(t *testing.T) {
tmpFile, err := os.CreateTemp("", "test-spec-*.yaml")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.Write([]byte(testSpec)); err != nil {
t.Fatalf("Failed to write temp file: %v", err)
}
tmpFile.Close()
validator, _, err := New(context.Background(), &Config{Type: "file", Location: tmpFile.Name(), CacheTTL: 3600})
if err != nil {
t.Fatalf("Failed to load local spec: %v", err)
}
validator.specMutex.RLock()
defer validator.specMutex.RUnlock()
if validator.spec == nil || validator.spec.doc == nil {
t.Error("Spec not loaded from local file")
}
}
func TestCacheTTL_DefaultValue(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testSpec))
}))
defer server.Close()
validator, _, err := New(context.Background(), &Config{Type: "url", Location: server.URL})
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
if validator.config.CacheTTL != 3600 {
t.Errorf("Expected default CacheTTL 3600, got %d", validator.config.CacheTTL)
}
}
func TestValidate_EdgeCases(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(testSpec))
}))
defer server.Close()
validator, _, err := New(context.Background(), &Config{Type: "url", Location: server.URL, CacheTTL: 3600})
if err != nil {
t.Fatalf("Failed to create validator: %v", err)
}
tests := []struct {
name string
payload string
wantErr bool
}{
{
name: "empty payload",
payload: `{}`,
wantErr: true,
},
{
name: "null context",
payload: `{"context":null,"message":{}}`,
wantErr: true,
},
{
name: "empty string action",
payload: `{"context":{"action":""},"message":{}}`,
wantErr: true,
},
{
name: "action with whitespace",
payload: `{"context":{"action":" search "},"message":{}}`,
wantErr: true,
},
{
name: "case sensitive action",
payload: `{"context":{"action":"Search"},"message":{}}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.Validate(context.Background(), nil, []byte(tt.payload))
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func contains(s, substr string) bool {
if len(substr) == 0 {
return true
}
if len(s) < len(substr) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}