fix: rebase
This commit is contained in:
19
go.mod
19
go.mod
@@ -25,6 +25,21 @@ require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.119.0 // indirect
|
||||
cloud.google.com/go/auth v0.15.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/iam v1.4.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
|
||||
github.com/zenazn/pkcs7pad v0.0.0-20170308005700-253a5b1f0e03
|
||||
)
|
||||
|
||||
require golang.org/x/text v0.23.0 // indirect
|
||||
|
||||
require golang.org/x/sys v0.31.0 // indirect
|
||||
@@ -33,3 +48,7 @@ require (
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
)
|
||||
require (
|
||||
cloud.google.com/go/pubsub v1.48.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
15
go.sum
15
go.sum
@@ -1,10 +1,23 @@
|
||||
cloud.google.com/go v0.119.0/go.mod h1:fwB8QLzTcNevxqi8dcpR+hoMIs3jBherGS9VUBDAW08=
|
||||
cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc=
|
||||
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
|
||||
cloud.google.com/go/iam v1.4.1/go.mod h1:2vUEJpUG3Q9p2UdsyksaKpDzlwOrnMzS30isdReIcLM=
|
||||
cloud.google.com/go/pubsub v1.48.0/go.mod h1:AAtyjyIT/+zaY1ERKFJbefOvkUxRDNp3nD6TdfdqUZk=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.5/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
|
||||
github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
@@ -47,6 +60,8 @@ golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
|
||||
|
||||
113
pkg/model/error.go
Normal file
113
pkg/model/error.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Error represents a standard error response.
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Paths string `json:"paths,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// This implements the error interface for the Error struct.
|
||||
func (e *Error) Error() string {
|
||||
return fmt.Sprintf("Error: Code=%s, Path=%s, Message=%s", e.Code, e.Paths, e.Message)
|
||||
}
|
||||
|
||||
// SchemaValidationErr occurs when schema validation errors are encountered.
|
||||
type SchemaValidationErr struct {
|
||||
Errors []Error
|
||||
}
|
||||
|
||||
// This implements the error interface for SchemaValidationErr.
|
||||
func (e *SchemaValidationErr) Error() string {
|
||||
var errorMessages []string
|
||||
for _, err := range e.Errors {
|
||||
errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", err.Paths, err.Message))
|
||||
}
|
||||
return strings.Join(errorMessages, "; ")
|
||||
}
|
||||
|
||||
// BecknError converts the SchemaValidationErr to an instance of Error.
|
||||
func (e *SchemaValidationErr) BecknError() *Error {
|
||||
if len(e.Errors) == 0 {
|
||||
return &Error{
|
||||
Code: http.StatusText(http.StatusBadRequest),
|
||||
Message: "Schema validation error.",
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all error paths and messages
|
||||
var paths []string
|
||||
var messages []string
|
||||
for _, err := range e.Errors {
|
||||
if err.Paths != "" {
|
||||
paths = append(paths, err.Paths)
|
||||
}
|
||||
messages = append(messages, err.Message)
|
||||
}
|
||||
|
||||
return &Error{
|
||||
Code: http.StatusText(http.StatusBadRequest),
|
||||
Paths: strings.Join(paths, ";"),
|
||||
Message: strings.Join(messages, "; "),
|
||||
}
|
||||
}
|
||||
|
||||
// SignValidationErr occurs when signature validation fails.
|
||||
type SignValidationErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
// NewSignValidationErr creates a new instance of SignValidationErr from an error.
|
||||
func NewSignValidationErr(e error) *SignValidationErr {
|
||||
return &SignValidationErr{e}
|
||||
}
|
||||
|
||||
// BecknError converts the SignValidationErr to an instance of Error.
|
||||
func (e *SignValidationErr) BecknError() *Error {
|
||||
return &Error{
|
||||
Code: http.StatusText(http.StatusUnauthorized),
|
||||
Message: "Signature Validation Error: " + e.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// BadReqErr occurs when a bad request is encountered.
|
||||
type BadReqErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
// NewBadReqErr creates a new instance of BadReqErr from an error.
|
||||
func NewBadReqErr(err error) *BadReqErr {
|
||||
return &BadReqErr{err}
|
||||
}
|
||||
|
||||
// BecknError converts the BadReqErr to an instance of Error.
|
||||
func (e *BadReqErr) BecknError() *Error {
|
||||
return &Error{
|
||||
Code: http.StatusText(http.StatusBadRequest),
|
||||
Message: "BAD Request: " + e.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// NotFoundErr occurs when a requested endpoint is not found.
|
||||
type NotFoundErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
// NewNotFoundErr creates a new instance of NotFoundErr from an error.
|
||||
func NewNotFoundErr(err error) *NotFoundErr {
|
||||
return &NotFoundErr{err}
|
||||
}
|
||||
|
||||
// BecknError converts the NotFoundErr to an instance of Error.
|
||||
func (e *NotFoundErr) BecknError() *Error {
|
||||
return &Error{
|
||||
Code: http.StatusText(http.StatusNotFound),
|
||||
Message: "Endpoint not found: " + e.Error(),
|
||||
}
|
||||
}
|
||||
200
pkg/model/error_test.go
Normal file
200
pkg/model/error_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"net/http"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// NewSignValidationErrf creates a new SignValidationErr with a formatted error message.
|
||||
func NewSignValidationErrf(format string, a ...any) *SignValidationErr {
|
||||
return &SignValidationErr{fmt.Errorf(format, a...)}
|
||||
}
|
||||
|
||||
// NewNotFoundErrf creates a new NotFoundErr with a formatted error message.
|
||||
func NewNotFoundErrf(format string, a ...any) *NotFoundErr {
|
||||
return &NotFoundErr{fmt.Errorf(format, a...)}
|
||||
}
|
||||
|
||||
// NewBadReqErrf creates a new BadReqErr with a formatted error message.
|
||||
func NewBadReqErrf(format string, a ...any) *BadReqErr {
|
||||
return &BadReqErr{fmt.Errorf(format, a...)}
|
||||
}
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
err := &Error{
|
||||
Code: "404",
|
||||
Paths: "/api/v1/user",
|
||||
Message: "User not found",
|
||||
}
|
||||
|
||||
expected := "Error: Code=404, Path=/api/v1/user, Message=User not found"
|
||||
actual := err.Error()
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
actual, expected)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSchemaValidationErr_Error(t *testing.T) {
|
||||
schemaErr := &SchemaValidationErr{
|
||||
Errors: []Error{
|
||||
{Paths: "/user", Message: "Field required"},
|
||||
{Paths: "/email", Message: "Invalid format"},
|
||||
},
|
||||
}
|
||||
|
||||
expected := "/user: Field required; /email: Invalid format"
|
||||
actual := schemaErr.Error()
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
actual, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchemaValidationErr_BecknError(t *testing.T) {
|
||||
schemaErr := &SchemaValidationErr{
|
||||
Errors: []Error{
|
||||
{Paths: "/user", Message: "Field required"},
|
||||
},
|
||||
}
|
||||
|
||||
beErr := schemaErr.BecknError()
|
||||
expected := "Bad Request"
|
||||
if beErr.Code != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
beErr.Code, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignValidationErr_BecknError(t *testing.T) {
|
||||
signErr := NewSignValidationErr(errors.New("signature failed"))
|
||||
beErr := signErr.BecknError()
|
||||
|
||||
expectedMsg := "Signature Validation Error: signature failed"
|
||||
if beErr.Message != expectedMsg {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
beErr.Message, expectedMsg)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNewSignValidationErrf(t *testing.T) {
|
||||
signErr := NewSignValidationErrf("error %s", "signature failed")
|
||||
expected := "error signature failed"
|
||||
if signErr.Error() != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
signErr.Error(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSignValidationErr(t *testing.T) {
|
||||
err := errors.New("signature error")
|
||||
signErr := NewSignValidationErr(err)
|
||||
|
||||
if signErr.Error() != err.Error() {
|
||||
t.Errorf("err.Error() = %s, want %s", err.Error(),
|
||||
signErr.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadReqErr_BecknError(t *testing.T) {
|
||||
badReqErr := NewBadReqErr(errors.New("invalid input"))
|
||||
beErr := badReqErr.BecknError()
|
||||
|
||||
expectedMsg := "BAD Request: invalid input"
|
||||
if beErr.Message != expectedMsg {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
beErr.Message, expectedMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadReqErrf(t *testing.T) {
|
||||
badReqErr := NewBadReqErrf("invalid field %s", "name")
|
||||
expected := "invalid field name"
|
||||
if badReqErr.Error() != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
badReqErr, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBadReqErr(t *testing.T) {
|
||||
err := errors.New("bad request")
|
||||
badReqErr := NewBadReqErr(err)
|
||||
|
||||
if badReqErr.Error() != err.Error() {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
badReqErr.Error(), err.Error())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNotFoundErr_BecknError(t *testing.T) {
|
||||
notFoundErr := NewNotFoundErr(errors.New("resource not found"))
|
||||
beErr := notFoundErr.BecknError()
|
||||
|
||||
expectedMsg := "Endpoint not found: resource not found"
|
||||
if beErr.Message != expectedMsg {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
beErr.Message, expectedMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundErrf(t *testing.T) {
|
||||
notFoundErr := NewNotFoundErrf("resource %s not found", "user")
|
||||
expected := "resource user not found"
|
||||
if notFoundErr.Error() != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
notFoundErr.Error(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNotFoundErr(t *testing.T) {
|
||||
err := errors.New("not found")
|
||||
notFoundErr := NewNotFoundErr(err)
|
||||
|
||||
if notFoundErr.Error() != err.Error() {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
notFoundErr.Error(), err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRole_UnmarshalYAML_ValidRole(t *testing.T) {
|
||||
var role Role
|
||||
yamlData := []byte("bap")
|
||||
|
||||
err := yaml.Unmarshal(yamlData, &role)
|
||||
assert.NoError(t, err) //TODO: should replace assert here
|
||||
assert.Equal(t, RoleBAP, role)
|
||||
}
|
||||
|
||||
func TestRole_UnmarshalYAML_InvalidRole(t *testing.T) {
|
||||
var role Role
|
||||
yamlData := []byte("invalid")
|
||||
|
||||
err := yaml.Unmarshal(yamlData, &role)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid Role")
|
||||
}
|
||||
|
||||
func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) {
|
||||
schemaValidationErr := &SchemaValidationErr{Errors: nil}
|
||||
beErr := schemaValidationErr.BecknError()
|
||||
|
||||
expectedMsg := "Schema validation error."
|
||||
expectedCode := http.StatusText(http.StatusBadRequest)
|
||||
|
||||
if beErr.Message != expectedMsg {
|
||||
t.Errorf("beErr.Message = %s, want %s", beErr.Message, expectedMsg)
|
||||
}
|
||||
if beErr.Code != expectedCode {
|
||||
t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode)
|
||||
}
|
||||
}
|
||||
107
pkg/model/model.go
Normal file
107
pkg/model/model.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Subscriber struct {
|
||||
SubscriberID string `json:"subscriber_id"`
|
||||
URL string `json:"url" format:"uri"`
|
||||
Type string `json:"type" enum:"BAP,BPP,BG"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
type Subscription struct {
|
||||
Subscriber `json:",inline"`
|
||||
KeyID string `json:"key_id" format:"uuid"`
|
||||
SigningPublicKey string `json:"signing_public_key"`
|
||||
EncrPublicKey string `json:"encr_public_key"`
|
||||
ValidFrom time.Time `json:"valid_from" format:"date-time"`
|
||||
ValidUntil time.Time `json:"valid_until" format:"date-time"`
|
||||
Status string `json:"status" enum:"INITIATED,UNDER_SUBSCRIPTION,SUBSCRIBED,EXPIRED,UNSUBSCRIBED,INVALID_SSL"`
|
||||
Created time.Time `json:"created" format:"date-time"`
|
||||
Updated time.Time `json:"updated" format:"date-time"`
|
||||
Nonce string
|
||||
}
|
||||
|
||||
const (
|
||||
AuthHeaderSubscriber string = "Authorization"
|
||||
AuthHeaderGateway string = "X-Gateway-Authorization"
|
||||
UnaAuthorizedHeaderSubscriber string = "WWW-Authenticate"
|
||||
UnaAuthorizedHeaderGateway string = "Proxy-Authenticate"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const MsgIDKey = contextKey("message_id")
|
||||
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleBAP Role = "bap"
|
||||
RoleBPP Role = "bpp"
|
||||
RoleGateway Role = "gateway"
|
||||
RoleRegistery Role = "registery"
|
||||
)
|
||||
|
||||
var validRoles = map[Role]bool{
|
||||
RoleBAP: true,
|
||||
RoleBPP: true,
|
||||
RoleGateway: true,
|
||||
RoleRegistery: true,
|
||||
}
|
||||
|
||||
func (r *Role) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var roleName string
|
||||
if err := unmarshal(&roleName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
role := Role(roleName)
|
||||
if !validRoles[role] {
|
||||
return fmt.Errorf("invalid Role: %s", roleName)
|
||||
}
|
||||
*r = role
|
||||
return nil
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
Type string
|
||||
URL *url.URL
|
||||
Publisher string
|
||||
}
|
||||
|
||||
type StepContext struct {
|
||||
context.Context
|
||||
Request *http.Request
|
||||
Body []byte
|
||||
Route *Route
|
||||
SubID string
|
||||
Role Role
|
||||
RespHeader http.Header
|
||||
}
|
||||
|
||||
func (ctx *StepContext) WithContext(newCtx context.Context) {
|
||||
ctx.Context = newCtx
|
||||
}
|
||||
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusACK Status = "ACK"
|
||||
StatusNACK Status = "NACK"
|
||||
)
|
||||
|
||||
type Ack struct {
|
||||
Status Status `json:"status"`
|
||||
}
|
||||
type Message struct {
|
||||
Ack Ack `json:"ack"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
type Response struct {
|
||||
Message Message `json:"message"`
|
||||
}
|
||||
@@ -3,20 +3,15 @@ package response
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/model"
|
||||
)
|
||||
|
||||
type ErrorType string
|
||||
|
||||
const (
|
||||
SchemaValidationErrorType ErrorType = "SCHEMA_VALIDATION_ERROR"
|
||||
InvalidRequestErrorType ErrorType = "INVALID_REQUEST"
|
||||
)
|
||||
|
||||
type BecknRequest struct {
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
@@ -24,6 +19,7 @@ type Error struct {
|
||||
Paths string `json:"paths,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
// SchemaValidationErr represents a collection of schema validation failures.
|
||||
type SchemaValidationErr struct {
|
||||
Errors []Error
|
||||
@@ -45,114 +41,75 @@ type Message struct {
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type BecknResponse struct {
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
Message Message `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type ClientFailureBecknResponse struct {
|
||||
Context map[string]interface{} `json:"context,omitempty"`
|
||||
Error *Error `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
var errorMap = map[ErrorType]Error{
|
||||
SchemaValidationErrorType: {
|
||||
Code: "400",
|
||||
Message: "Schema validation failed",
|
||||
},
|
||||
InvalidRequestErrorType: {
|
||||
Code: "401",
|
||||
Message: "Invalid request format",
|
||||
},
|
||||
}
|
||||
|
||||
var DefaultError = Error{
|
||||
Code: "500",
|
||||
Message: "Internal server error",
|
||||
}
|
||||
|
||||
func Nack(ctx context.Context, tp ErrorType, paths string, body []byte) ([]byte, error) {
|
||||
var req BecknRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse request: %w", err)
|
||||
}
|
||||
|
||||
errorObj, ok := errorMap[tp]
|
||||
if paths != "" {
|
||||
errorObj.Paths = paths
|
||||
}
|
||||
|
||||
var response BecknResponse
|
||||
|
||||
if !ok {
|
||||
response = BecknResponse{
|
||||
Context: req.Context,
|
||||
Message: Message{
|
||||
Ack: struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}{
|
||||
Status: "NACK",
|
||||
},
|
||||
Error: &DefaultError,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
response = BecknResponse{
|
||||
Context: req.Context,
|
||||
Message: Message{
|
||||
Ack: struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}{
|
||||
Status: "NACK",
|
||||
},
|
||||
Error: &errorObj,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func Ack(ctx context.Context, body []byte) ([]byte, error) {
|
||||
var req BecknRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse request: %w", err)
|
||||
}
|
||||
|
||||
response := BecknResponse{
|
||||
Context: req.Context,
|
||||
Message: Message{
|
||||
Ack: struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}{
|
||||
Status: "ACK",
|
||||
func SendAck(w http.ResponseWriter) {
|
||||
resp := &model.Response{
|
||||
Message: model.Message{
|
||||
Ack: model.Ack{
|
||||
Status: model.StatusACK,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return json.Marshal(response)
|
||||
data, _ := json.Marshal(resp) //should not fail here
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write(data)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func HandleClientFailure(ctx context.Context, tp ErrorType, body []byte) ([]byte, error) {
|
||||
var req BecknRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse request: %w", err)
|
||||
func nack(w http.ResponseWriter, err *model.Error, status int, ctx context.Context) {
|
||||
resp := &model.Response{
|
||||
Message: model.Message{
|
||||
Ack: model.Ack{
|
||||
Status: model.StatusNACK,
|
||||
},
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
data, _ := json.Marshal(resp) //should not fail here
|
||||
|
||||
errorObj, ok := errorMap[tp]
|
||||
var response ClientFailureBecknResponse
|
||||
|
||||
if !ok {
|
||||
response = ClientFailureBecknResponse{
|
||||
Context: req.Context,
|
||||
Error: &DefaultError,
|
||||
}
|
||||
} else {
|
||||
response = ClientFailureBecknResponse{
|
||||
Context: req.Context,
|
||||
Error: &errorObj,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_, er := w.Write(data)
|
||||
if er != nil {
|
||||
fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey))
|
||||
http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func internalServerError(ctx context.Context) *model.Error {
|
||||
return &model.Error{
|
||||
Code: http.StatusText(http.StatusInternalServerError),
|
||||
Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)),
|
||||
}
|
||||
}
|
||||
|
||||
func SendNack(ctx context.Context, w http.ResponseWriter, err error) {
|
||||
var schemaErr *model.SchemaValidationErr
|
||||
var signErr *model.SignValidationErr
|
||||
var badReqErr *model.BadReqErr
|
||||
var notFoundErr *model.NotFoundErr
|
||||
|
||||
switch {
|
||||
case errors.As(err, &schemaErr):
|
||||
nack(w, schemaErr.BecknError(), http.StatusBadRequest, ctx)
|
||||
return
|
||||
case errors.As(err, &signErr):
|
||||
nack(w, signErr.BecknError(), http.StatusUnauthorized, ctx)
|
||||
return
|
||||
case errors.As(err, &badReqErr):
|
||||
nack(w, badReqErr.BecknError(), http.StatusBadRequest, ctx)
|
||||
return
|
||||
case errors.As(err, ¬FoundErr):
|
||||
nack(w, notFoundErr.BecknError(), http.StatusNotFound, ctx)
|
||||
return
|
||||
default:
|
||||
nack(w, internalServerError(ctx), http.StatusInternalServerError, ctx)
|
||||
return
|
||||
}
|
||||
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
@@ -1,308 +1,133 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/model"
|
||||
)
|
||||
|
||||
func TestNack(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type errorResponseWriter struct{}
|
||||
|
||||
// TODO: Optimize the cases by removing these
|
||||
func (e *errorResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
func (e *errorResponseWriter) WriteHeader(statusCode int) {}
|
||||
|
||||
func (e *errorResponseWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func TestSendAck(t *testing.T) {
|
||||
_, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err) // For tests
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
SendAck(rr)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("wanted status code %d, got %d", http.StatusOK, rr.Code)
|
||||
}
|
||||
|
||||
expected := `{"message":{"ack":{"status":"ACK"}}}`
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
rr.Body.String(), expected)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendNack(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorType ErrorType
|
||||
requestBody string
|
||||
wantStatus string
|
||||
wantErrCode string
|
||||
wantErrMsg string
|
||||
wantErr bool
|
||||
path string
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
status int
|
||||
}{
|
||||
{
|
||||
name: "Schema validation error",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`,
|
||||
wantStatus: "NACK",
|
||||
wantErrCode: "400",
|
||||
wantErrMsg: "Schema validation failed",
|
||||
wantErr: false,
|
||||
path: "test",
|
||||
name: "SchemaValidationErr",
|
||||
err: &model.SchemaValidationErr{
|
||||
Errors: []model.Error{
|
||||
{Paths: "/path1", Message: "Error 1"},
|
||||
{Paths: "/path2", Message: "Error 2"},
|
||||
},
|
||||
},
|
||||
status: http.StatusBadRequest,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Bad Request","paths":"/path1;/path2","message":"Error 1; Error 2"}}}`,
|
||||
},
|
||||
{
|
||||
name: "Invalid request error",
|
||||
errorType: InvalidRequestErrorType,
|
||||
requestBody: `{"context": {"domain": "test-domain"}}`,
|
||||
wantStatus: "NACK",
|
||||
wantErrCode: "401",
|
||||
wantErrMsg: "Invalid request format",
|
||||
wantErr: false,
|
||||
path: "test",
|
||||
name: "SignValidationErr",
|
||||
err: model.NewSignValidationErr(errors.New("signature invalid")),
|
||||
status: http.StatusUnauthorized,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Unauthorized","message":"Signature Validation Error: signature invalid"}}}`,
|
||||
},
|
||||
{
|
||||
name: "Unknown error type",
|
||||
errorType: "UNKNOWN_ERROR",
|
||||
requestBody: `{"context": {"domain": "test-domain"}}`,
|
||||
wantStatus: "NACK",
|
||||
wantErrCode: "500",
|
||||
wantErrMsg: "Internal server error",
|
||||
wantErr: false,
|
||||
path: "test",
|
||||
name: "BadReqErr",
|
||||
err: model.NewBadReqErr(errors.New("bad request error")),
|
||||
status: http.StatusBadRequest,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Bad Request","message":"BAD Request: bad request error"}}}`,
|
||||
},
|
||||
{
|
||||
name: "Empty request body",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{}`,
|
||||
wantStatus: "NACK",
|
||||
wantErrCode: "400",
|
||||
wantErrMsg: "Schema validation failed",
|
||||
wantErr: false,
|
||||
path: "test",
|
||||
name: "NotFoundErr",
|
||||
err: model.NewNotFoundErr(errors.New("endpoint not found")),
|
||||
status: http.StatusNotFound,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Not Found","message":"Endpoint not found: endpoint not found"}}}`,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{invalid json}`,
|
||||
wantErr: true,
|
||||
path: "test",
|
||||
},
|
||||
{
|
||||
name: "Complex nested context",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{"context": {"domain": "test-domain", "nested": {"key1": "value1", "key2": 123}}}`,
|
||||
wantStatus: "NACK",
|
||||
wantErrCode: "400",
|
||||
wantErrMsg: "Schema validation failed",
|
||||
wantErr: false,
|
||||
path: "test",
|
||||
name: "InternalServerError",
|
||||
err: errors.New("unexpected error"),
|
||||
status: http.StatusInternalServerError,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Internal Server Error","message":"Internal server error, MessageID: 123456"}}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := Nack(ctx, tt.errorType, tt.path, []byte(tt.requestBody))
|
||||
_, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err) // For tests
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Nack() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
SendNack(ctx, rr, tt.err)
|
||||
|
||||
if rr.Code != tt.status {
|
||||
t.Errorf("wanted status code %d, got %d", tt.status, rr.Code)
|
||||
}
|
||||
|
||||
if tt.wantErr && err != nil {
|
||||
return
|
||||
var actual map[string]interface{}
|
||||
err = json.Unmarshal(rr.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
var becknResp BecknResponse
|
||||
if err := json.Unmarshal(resp, &becknResp); err != nil {
|
||||
t.Errorf("Failed to unmarshal response: %v", err)
|
||||
return
|
||||
var expected map[string]interface{}
|
||||
err = json.Unmarshal([]byte(tt.expected), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
if becknResp.Message.Ack.Status != tt.wantStatus {
|
||||
t.Errorf("Nack() status = %v, want %v", becknResp.Message.Ack.Status, tt.wantStatus)
|
||||
if !compareJSON(expected, actual) {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
actual, expected)
|
||||
}
|
||||
|
||||
if becknResp.Message.Error.Code != tt.wantErrCode {
|
||||
t.Errorf("Nack() error code = %v, want %v", becknResp.Message.Error.Code, tt.wantErrCode)
|
||||
}
|
||||
|
||||
if becknResp.Message.Error.Message != tt.wantErrMsg {
|
||||
t.Errorf("Nack() error message = %v, want %v", becknResp.Message.Error.Message, tt.wantErrMsg)
|
||||
}
|
||||
|
||||
var origReq BecknRequest
|
||||
if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil {
|
||||
if !compareContexts(becknResp.Context, origReq.Context) {
|
||||
t.Errorf("Nack() context not preserved, got = %v, want %v", becknResp.Context, origReq.Context)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAck(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestBody string
|
||||
wantStatus string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid request",
|
||||
requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`,
|
||||
wantStatus: "ACK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty context",
|
||||
requestBody: `{"context": {}}`,
|
||||
wantStatus: "ACK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
requestBody: `{invalid json}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Complex nested context",
|
||||
requestBody: `{"context": {"domain": "test-domain", "nested": {"key1": "value1", "key2": 123, "array": [1,2,3]}}}`,
|
||||
wantStatus: "ACK",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := Ack(ctx, []byte(tt.requestBody))
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Ack() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var becknResp BecknResponse
|
||||
if err := json.Unmarshal(resp, &becknResp); err != nil {
|
||||
t.Errorf("Failed to unmarshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if becknResp.Message.Ack.Status != tt.wantStatus {
|
||||
t.Errorf("Ack() status = %v, want %v", becknResp.Message.Ack.Status, tt.wantStatus)
|
||||
}
|
||||
|
||||
if becknResp.Message.Error != nil {
|
||||
t.Errorf("Ack() should not have error, got %v", becknResp.Message.Error)
|
||||
}
|
||||
|
||||
var origReq BecknRequest
|
||||
if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil {
|
||||
if !compareContexts(becknResp.Context, origReq.Context) {
|
||||
t.Errorf("Ack() context not preserved, got = %v, want %v", becknResp.Context, origReq.Context)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleClientFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorType ErrorType
|
||||
requestBody string
|
||||
wantErrCode string
|
||||
wantErrMsg string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Schema validation error",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`,
|
||||
wantErrCode: "400",
|
||||
wantErrMsg: "Schema validation failed",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid request error",
|
||||
errorType: InvalidRequestErrorType,
|
||||
requestBody: `{"context": {"domain": "test-domain"}}`,
|
||||
wantErrCode: "401",
|
||||
wantErrMsg: "Invalid request format",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Unknown error type",
|
||||
errorType: "UNKNOWN_ERROR",
|
||||
requestBody: `{"context": {"domain": "test-domain"}}`,
|
||||
wantErrCode: "500",
|
||||
wantErrMsg: "Internal server error",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
errorType: SchemaValidationErrorType,
|
||||
requestBody: `{invalid json}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := HandleClientFailure(ctx, tt.errorType, []byte(tt.requestBody))
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("HandleClientFailure() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var failureResp ClientFailureBecknResponse
|
||||
if err := json.Unmarshal(resp, &failureResp); err != nil {
|
||||
t.Errorf("Failed to unmarshal response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if failureResp.Error.Code != tt.wantErrCode {
|
||||
t.Errorf("HandleClientFailure() error code = %v, want %v", failureResp.Error.Code, tt.wantErrCode)
|
||||
}
|
||||
|
||||
if failureResp.Error.Message != tt.wantErrMsg {
|
||||
t.Errorf("HandleClientFailure() error message = %v, want %v", failureResp.Error.Message, tt.wantErrMsg)
|
||||
}
|
||||
|
||||
var origReq BecknRequest
|
||||
if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil {
|
||||
if !compareContexts(failureResp.Context, origReq.Context) {
|
||||
t.Errorf("HandleClientFailure() context not preserved, got = %v, want %v", failureResp.Context, origReq.Context)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorMap(t *testing.T) {
|
||||
|
||||
expectedTypes := []ErrorType{
|
||||
SchemaValidationErrorType,
|
||||
InvalidRequestErrorType,
|
||||
}
|
||||
|
||||
for _, tp := range expectedTypes {
|
||||
if _, exists := errorMap[tp]; !exists {
|
||||
t.Errorf("ErrorType %v not found in errorMap", tp)
|
||||
}
|
||||
}
|
||||
|
||||
if DefaultError.Code != "500" || DefaultError.Message != "Internal server error" {
|
||||
t.Errorf("DefaultError not set correctly, got code=%v, message=%v", DefaultError.Code, DefaultError.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func compareContexts(c1, c2 map[string]interface{}) bool {
|
||||
|
||||
if c1 == nil && c2 == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if c1 == nil && len(c2) == 0 || c2 == nil && len(c1) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(c1, c2)
|
||||
}
|
||||
|
||||
func TestSchemaValidationErr_Error(t *testing.T) {
|
||||
// Create sample validation errors
|
||||
validationErrors := []Error{
|
||||
{Paths: "name", Message: "Name is required"},
|
||||
{Paths: "email", Message: "Invalid email format"},
|
||||
@@ -315,3 +140,117 @@ func TestSchemaValidationErr_Error(t *testing.T) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func compareJSON(expected, actual map[string]interface{}) bool {
|
||||
expectedBytes, _ := json.Marshal(expected)
|
||||
actualBytes, _ := json.Marshal(actual)
|
||||
return bytes.Equal(expectedBytes, actualBytes)
|
||||
}
|
||||
|
||||
func TestSendAck_WriteError(t *testing.T) {
|
||||
w := &errorResponseWriter{}
|
||||
SendAck(w)
|
||||
}
|
||||
|
||||
// Mock struct to force JSON marshalling error
|
||||
type badMessage struct{}
|
||||
|
||||
func (b *badMessage) MarshalJSON() ([]byte, error) {
|
||||
return nil, errors.New("marshal error")
|
||||
}
|
||||
|
||||
func TestNack_1(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *model.Error
|
||||
status int
|
||||
expected string
|
||||
useBadJSON bool
|
||||
useBadWrite bool
|
||||
}{
|
||||
{
|
||||
name: "Schema Validation Error",
|
||||
err: &model.Error{
|
||||
Code: "BAD_REQUEST",
|
||||
Paths: "/test/path",
|
||||
Message: "Invalid schema",
|
||||
},
|
||||
status: http.StatusBadRequest,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"BAD_REQUEST","paths":"/test/path","message":"Invalid schema"}}}`,
|
||||
},
|
||||
{
|
||||
name: "Internal Server Error",
|
||||
err: &model.Error{
|
||||
Code: "INTERNAL_SERVER_ERROR",
|
||||
Message: "Something went wrong",
|
||||
},
|
||||
status: http.StatusInternalServerError,
|
||||
expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"INTERNAL_SERVER_ERROR","message":"Something went wrong"}}}`,
|
||||
},
|
||||
{
|
||||
name: "JSON Marshal Error",
|
||||
err: nil, // This will be overridden to cause marshaling error
|
||||
status: http.StatusInternalServerError,
|
||||
expected: `Internal server error, MessageID: 12345`,
|
||||
useBadJSON: true,
|
||||
},
|
||||
{
|
||||
name: "Write Error",
|
||||
err: &model.Error{
|
||||
Code: "WRITE_ERROR",
|
||||
Message: "Failed to write response",
|
||||
},
|
||||
status: http.StatusInternalServerError,
|
||||
expected: `Internal server error, MessageID: 12345`,
|
||||
useBadWrite: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345")
|
||||
|
||||
var w http.ResponseWriter
|
||||
if tt.useBadWrite {
|
||||
w = &errorResponseWriter{} // Simulate write error
|
||||
} else {
|
||||
w = httptest.NewRecorder()
|
||||
}
|
||||
|
||||
// TODO: Fix this approach , should not be used like this.
|
||||
if tt.useBadJSON {
|
||||
data, _ := json.Marshal(&badMessage{})
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.status)
|
||||
_, err := w.Write(data)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
nack(w, tt.err, tt.status, ctx)
|
||||
if !tt.useBadWrite {
|
||||
recorder, ok := w.(*httptest.ResponseRecorder)
|
||||
if !ok {
|
||||
t.Fatal("Failed to cast response recorder")
|
||||
}
|
||||
|
||||
if recorder.Code != tt.status {
|
||||
t.Errorf("wanted status code %d, got %d", tt.status, recorder.Code)
|
||||
}
|
||||
|
||||
body := recorder.Body.String()
|
||||
if body != tt.expected {
|
||||
t.Errorf("err.Error() = %s, want %s",
|
||||
body, tt.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user