diff --git a/go.mod b/go.mod index fc5402a..dd6f0e8 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,21 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require ( + cloud.google.com/go v0.119.0 // indirect + cloud.google.com/go/auth v0.15.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/iam v1.4.1 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.5 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect + github.com/zenazn/pkcs7pad v0.0.0-20170308005700-253a5b1f0e03 +) + require golang.org/x/text v0.23.0 // indirect require golang.org/x/sys v0.31.0 // indirect @@ -33,3 +48,7 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect golang.org/x/sys v0.31.0 // indirect ) +require ( + cloud.google.com/go/pubsub v1.48.0 + gopkg.in/yaml.v2 v2.4.0 +) diff --git a/go.sum b/go.sum index ec6094b..7df543a 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,23 @@ +cloud.google.com/go v0.119.0/go.mod h1:fwB8QLzTcNevxqi8dcpR+hoMIs3jBherGS9VUBDAW08= +cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= +cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/iam v1.4.1/go.mod h1:2vUEJpUG3Q9p2UdsyksaKpDzlwOrnMzS30isdReIcLM= +cloud.google.com/go/pubsub v1.48.0/go.mod h1:AAtyjyIT/+zaY1ERKFJbefOvkUxRDNp3nD6TdfdqUZk= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.5/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -47,6 +60,8 @@ golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/pkg/model/error.go b/pkg/model/error.go new file mode 100644 index 0000000..5cdee1d --- /dev/null +++ b/pkg/model/error.go @@ -0,0 +1,113 @@ +package model + +import ( + "fmt" + "net/http" + "strings" +) + +// Error represents a standard error response. +type Error struct { + Code string `json:"code"` + Paths string `json:"paths,omitempty"` + Message string `json:"message"` +} + +// This implements the error interface for the Error struct. +func (e *Error) Error() string { + return fmt.Sprintf("Error: Code=%s, Path=%s, Message=%s", e.Code, e.Paths, e.Message) +} + +// SchemaValidationErr occurs when schema validation errors are encountered. +type SchemaValidationErr struct { + Errors []Error +} + +// This implements the error interface for SchemaValidationErr. +func (e *SchemaValidationErr) Error() string { + var errorMessages []string + for _, err := range e.Errors { + errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", err.Paths, err.Message)) + } + return strings.Join(errorMessages, "; ") +} + +// BecknError converts the SchemaValidationErr to an instance of Error. +func (e *SchemaValidationErr) BecknError() *Error { + if len(e.Errors) == 0 { + return &Error{ + Code: http.StatusText(http.StatusBadRequest), + Message: "Schema validation error.", + } + } + + // Collect all error paths and messages + var paths []string + var messages []string + for _, err := range e.Errors { + if err.Paths != "" { + paths = append(paths, err.Paths) + } + messages = append(messages, err.Message) + } + + return &Error{ + Code: http.StatusText(http.StatusBadRequest), + Paths: strings.Join(paths, ";"), + Message: strings.Join(messages, "; "), + } +} + +// SignValidationErr occurs when signature validation fails. +type SignValidationErr struct { + error +} + +// NewSignValidationErr creates a new instance of SignValidationErr from an error. +func NewSignValidationErr(e error) *SignValidationErr { + return &SignValidationErr{e} +} + +// BecknError converts the SignValidationErr to an instance of Error. +func (e *SignValidationErr) BecknError() *Error { + return &Error{ + Code: http.StatusText(http.StatusUnauthorized), + Message: "Signature Validation Error: " + e.Error(), + } +} + +// BadReqErr occurs when a bad request is encountered. +type BadReqErr struct { + error +} + +// NewBadReqErr creates a new instance of BadReqErr from an error. +func NewBadReqErr(err error) *BadReqErr { + return &BadReqErr{err} +} + +// BecknError converts the BadReqErr to an instance of Error. +func (e *BadReqErr) BecknError() *Error { + return &Error{ + Code: http.StatusText(http.StatusBadRequest), + Message: "BAD Request: " + e.Error(), + } +} + +// NotFoundErr occurs when a requested endpoint is not found. +type NotFoundErr struct { + error +} + +// NewNotFoundErr creates a new instance of NotFoundErr from an error. +func NewNotFoundErr(err error) *NotFoundErr { + return &NotFoundErr{err} +} + +// BecknError converts the NotFoundErr to an instance of Error. +func (e *NotFoundErr) BecknError() *Error { + return &Error{ + Code: http.StatusText(http.StatusNotFound), + Message: "Endpoint not found: " + e.Error(), + } +} diff --git a/pkg/model/error_test.go b/pkg/model/error_test.go new file mode 100644 index 0000000..ee295e6 --- /dev/null +++ b/pkg/model/error_test.go @@ -0,0 +1,200 @@ +package model + +import ( + "errors" + "fmt" + "testing" + "net/http" + + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v2" +) + +// NewSignValidationErrf creates a new SignValidationErr with a formatted error message. +func NewSignValidationErrf(format string, a ...any) *SignValidationErr { + return &SignValidationErr{fmt.Errorf(format, a...)} +} + +// NewNotFoundErrf creates a new NotFoundErr with a formatted error message. +func NewNotFoundErrf(format string, a ...any) *NotFoundErr { + return &NotFoundErr{fmt.Errorf(format, a...)} +} + +// NewBadReqErrf creates a new BadReqErr with a formatted error message. +func NewBadReqErrf(format string, a ...any) *BadReqErr { + return &BadReqErr{fmt.Errorf(format, a...)} +} + +func TestError_Error(t *testing.T) { + err := &Error{ + Code: "404", + Paths: "/api/v1/user", + Message: "User not found", + } + + expected := "Error: Code=404, Path=/api/v1/user, Message=User not found" + actual := err.Error() + + if actual != expected { + t.Errorf("err.Error() = %s, want %s", + actual, expected) + } + +} + +func TestSchemaValidationErr_Error(t *testing.T) { + schemaErr := &SchemaValidationErr{ + Errors: []Error{ + {Paths: "/user", Message: "Field required"}, + {Paths: "/email", Message: "Invalid format"}, + }, + } + + expected := "/user: Field required; /email: Invalid format" + actual := schemaErr.Error() + + if actual != expected { + t.Errorf("err.Error() = %s, want %s", + actual, expected) + } +} + +func TestSchemaValidationErr_BecknError(t *testing.T) { + schemaErr := &SchemaValidationErr{ + Errors: []Error{ + {Paths: "/user", Message: "Field required"}, + }, + } + + beErr := schemaErr.BecknError() + expected := "Bad Request" + if beErr.Code != expected { + t.Errorf("err.Error() = %s, want %s", + beErr.Code, expected) + } +} + +func TestSignValidationErr_BecknError(t *testing.T) { + signErr := NewSignValidationErr(errors.New("signature failed")) + beErr := signErr.BecknError() + + expectedMsg := "Signature Validation Error: signature failed" + if beErr.Message != expectedMsg { + t.Errorf("err.Error() = %s, want %s", + beErr.Message, expectedMsg) + } + +} + +func TestNewSignValidationErrf(t *testing.T) { + signErr := NewSignValidationErrf("error %s", "signature failed") + expected := "error signature failed" + if signErr.Error() != expected { + t.Errorf("err.Error() = %s, want %s", + signErr.Error(), expected) + } +} + +func TestNewSignValidationErr(t *testing.T) { + err := errors.New("signature error") + signErr := NewSignValidationErr(err) + + if signErr.Error() != err.Error() { + t.Errorf("err.Error() = %s, want %s", err.Error(), + signErr.Error()) + } +} + +func TestBadReqErr_BecknError(t *testing.T) { + badReqErr := NewBadReqErr(errors.New("invalid input")) + beErr := badReqErr.BecknError() + + expectedMsg := "BAD Request: invalid input" + if beErr.Message != expectedMsg { + t.Errorf("err.Error() = %s, want %s", + beErr.Message, expectedMsg) + } +} + +func TestNewBadReqErrf(t *testing.T) { + badReqErr := NewBadReqErrf("invalid field %s", "name") + expected := "invalid field name" + if badReqErr.Error() != expected { + t.Errorf("err.Error() = %s, want %s", + badReqErr, expected) + } +} + +func TestNewBadReqErr(t *testing.T) { + err := errors.New("bad request") + badReqErr := NewBadReqErr(err) + + if badReqErr.Error() != err.Error() { + t.Errorf("err.Error() = %s, want %s", + badReqErr.Error(), err.Error()) + } + +} + +func TestNotFoundErr_BecknError(t *testing.T) { + notFoundErr := NewNotFoundErr(errors.New("resource not found")) + beErr := notFoundErr.BecknError() + + expectedMsg := "Endpoint not found: resource not found" + if beErr.Message != expectedMsg { + t.Errorf("err.Error() = %s, want %s", + beErr.Message, expectedMsg) + } +} + +func TestNewNotFoundErrf(t *testing.T) { + notFoundErr := NewNotFoundErrf("resource %s not found", "user") + expected := "resource user not found" + if notFoundErr.Error() != expected { + t.Errorf("err.Error() = %s, want %s", + notFoundErr.Error(), expected) + } +} + +func TestNewNotFoundErr(t *testing.T) { + err := errors.New("not found") + notFoundErr := NewNotFoundErr(err) + + if notFoundErr.Error() != err.Error() { + t.Errorf("err.Error() = %s, want %s", + notFoundErr.Error(), err.Error()) + } +} + +func TestRole_UnmarshalYAML_ValidRole(t *testing.T) { + var role Role + yamlData := []byte("bap") + + err := yaml.Unmarshal(yamlData, &role) + assert.NoError(t, err) //TODO: should replace assert here + assert.Equal(t, RoleBAP, role) +} + +func TestRole_UnmarshalYAML_InvalidRole(t *testing.T) { + var role Role + yamlData := []byte("invalid") + + err := yaml.Unmarshal(yamlData, &role) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid Role") +} + +func TestSchemaValidationErr_BecknError_NoErrors(t *testing.T) { + schemaValidationErr := &SchemaValidationErr{Errors: nil} + beErr := schemaValidationErr.BecknError() + + expectedMsg := "Schema validation error." + expectedCode := http.StatusText(http.StatusBadRequest) + + if beErr.Message != expectedMsg { + t.Errorf("beErr.Message = %s, want %s", beErr.Message, expectedMsg) + } + if beErr.Code != expectedCode { + t.Errorf("beErr.Code = %s, want %s", beErr.Code, expectedCode) + } +} diff --git a/pkg/model/model.go b/pkg/model/model.go new file mode 100644 index 0000000..fcd8c65 --- /dev/null +++ b/pkg/model/model.go @@ -0,0 +1,107 @@ +package model + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" +) + +type Subscriber struct { + SubscriberID string `json:"subscriber_id"` + URL string `json:"url" format:"uri"` + Type string `json:"type" enum:"BAP,BPP,BG"` + Domain string `json:"domain"` +} +type Subscription struct { + Subscriber `json:",inline"` + KeyID string `json:"key_id" format:"uuid"` + SigningPublicKey string `json:"signing_public_key"` + EncrPublicKey string `json:"encr_public_key"` + ValidFrom time.Time `json:"valid_from" format:"date-time"` + ValidUntil time.Time `json:"valid_until" format:"date-time"` + Status string `json:"status" enum:"INITIATED,UNDER_SUBSCRIPTION,SUBSCRIBED,EXPIRED,UNSUBSCRIBED,INVALID_SSL"` + Created time.Time `json:"created" format:"date-time"` + Updated time.Time `json:"updated" format:"date-time"` + Nonce string +} + +const ( + AuthHeaderSubscriber string = "Authorization" + AuthHeaderGateway string = "X-Gateway-Authorization" + UnaAuthorizedHeaderSubscriber string = "WWW-Authenticate" + UnaAuthorizedHeaderGateway string = "Proxy-Authenticate" +) + +type contextKey string + +const MsgIDKey = contextKey("message_id") + +type Role string + +const ( + RoleBAP Role = "bap" + RoleBPP Role = "bpp" + RoleGateway Role = "gateway" + RoleRegistery Role = "registery" +) + +var validRoles = map[Role]bool{ + RoleBAP: true, + RoleBPP: true, + RoleGateway: true, + RoleRegistery: true, +} + +func (r *Role) UnmarshalYAML(unmarshal func(interface{}) error) error { + var roleName string + if err := unmarshal(&roleName); err != nil { + return err + } + + role := Role(roleName) + if !validRoles[role] { + return fmt.Errorf("invalid Role: %s", roleName) + } + *r = role + return nil +} + +type Route struct { + Type string + URL *url.URL + Publisher string +} + +type StepContext struct { + context.Context + Request *http.Request + Body []byte + Route *Route + SubID string + Role Role + RespHeader http.Header +} + +func (ctx *StepContext) WithContext(newCtx context.Context) { + ctx.Context = newCtx +} + +type Status string + +const ( + StatusACK Status = "ACK" + StatusNACK Status = "NACK" +) + +type Ack struct { + Status Status `json:"status"` +} +type Message struct { + Ack Ack `json:"ack"` + Error *Error `json:"error,omitempty"` +} +type Response struct { + Message Message `json:"message"` +} diff --git a/pkg/response/response.go b/pkg/response/response.go index 7364495..c6d1094 100644 --- a/pkg/response/response.go +++ b/pkg/response/response.go @@ -3,20 +3,15 @@ package response import ( "context" "encoding/json" + "errors" "fmt" + "net/http" + "strings" + + "github.com/beckn/beckn-onix/pkg/model" ) -type ErrorType string - -const ( - SchemaValidationErrorType ErrorType = "SCHEMA_VALIDATION_ERROR" - InvalidRequestErrorType ErrorType = "INVALID_REQUEST" -) - -type BecknRequest struct { - Context map[string]interface{} `json:"context,omitempty"` -} type Error struct { Code string `json:"code,omitempty"` @@ -24,6 +19,7 @@ type Error struct { Paths string `json:"paths,omitempty"` } + // SchemaValidationErr represents a collection of schema validation failures. type SchemaValidationErr struct { Errors []Error @@ -45,114 +41,75 @@ type Message struct { Error *Error `json:"error,omitempty"` } -type BecknResponse struct { - Context map[string]interface{} `json:"context,omitempty"` - Message Message `json:"message,omitempty"` -} - -type ClientFailureBecknResponse struct { - Context map[string]interface{} `json:"context,omitempty"` - Error *Error `json:"error,omitempty"` -} - -var errorMap = map[ErrorType]Error{ - SchemaValidationErrorType: { - Code: "400", - Message: "Schema validation failed", - }, - InvalidRequestErrorType: { - Code: "401", - Message: "Invalid request format", - }, -} - -var DefaultError = Error{ - Code: "500", - Message: "Internal server error", -} - -func Nack(ctx context.Context, tp ErrorType, paths string, body []byte) ([]byte, error) { - var req BecknRequest - if err := json.Unmarshal(body, &req); err != nil { - return nil, fmt.Errorf("failed to parse request: %w", err) - } - - errorObj, ok := errorMap[tp] - if paths != "" { - errorObj.Paths = paths - } - - var response BecknResponse - - if !ok { - response = BecknResponse{ - Context: req.Context, - Message: Message{ - Ack: struct { - Status string `json:"status,omitempty"` - }{ - Status: "NACK", - }, - Error: &DefaultError, - }, - } - } else { - response = BecknResponse{ - Context: req.Context, - Message: Message{ - Ack: struct { - Status string `json:"status,omitempty"` - }{ - Status: "NACK", - }, - Error: &errorObj, - }, - } - } - - return json.Marshal(response) -} - -func Ack(ctx context.Context, body []byte) ([]byte, error) { - var req BecknRequest - if err := json.Unmarshal(body, &req); err != nil { - return nil, fmt.Errorf("failed to parse request: %w", err) - } - - response := BecknResponse{ - Context: req.Context, - Message: Message{ - Ack: struct { - Status string `json:"status,omitempty"` - }{ - Status: "ACK", +func SendAck(w http.ResponseWriter) { + resp := &model.Response{ + Message: model.Message{ + Ack: model.Ack{ + Status: model.StatusACK, }, }, } - return json.Marshal(response) + data, _ := json.Marshal(resp) //should not fail here + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write(data) + if err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + return + } } -func HandleClientFailure(ctx context.Context, tp ErrorType, body []byte) ([]byte, error) { - var req BecknRequest - if err := json.Unmarshal(body, &req); err != nil { - return nil, fmt.Errorf("failed to parse request: %w", err) +func nack(w http.ResponseWriter, err *model.Error, status int, ctx context.Context) { + resp := &model.Response{ + Message: model.Message{ + Ack: model.Ack{ + Status: model.StatusNACK, + }, + Error: err, + }, } + data, _ := json.Marshal(resp) //should not fail here - errorObj, ok := errorMap[tp] - var response ClientFailureBecknResponse - - if !ok { - response = ClientFailureBecknResponse{ - Context: req.Context, - Error: &DefaultError, - } - } else { - response = ClientFailureBecknResponse{ - Context: req.Context, - Error: &errorObj, - } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, er := w.Write(data) + if er != nil { + fmt.Printf("Error writing response: %v, MessageID: %s", er, ctx.Value(model.MsgIDKey)) + http.Error(w, fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), http.StatusInternalServerError) + return + } +} + +func internalServerError(ctx context.Context) *model.Error { + return &model.Error{ + Code: http.StatusText(http.StatusInternalServerError), + Message: fmt.Sprintf("Internal server error, MessageID: %s", ctx.Value(model.MsgIDKey)), + } +} + +func SendNack(ctx context.Context, w http.ResponseWriter, err error) { + var schemaErr *model.SchemaValidationErr + var signErr *model.SignValidationErr + var badReqErr *model.BadReqErr + var notFoundErr *model.NotFoundErr + + switch { + case errors.As(err, &schemaErr): + nack(w, schemaErr.BecknError(), http.StatusBadRequest, ctx) + return + case errors.As(err, &signErr): + nack(w, signErr.BecknError(), http.StatusUnauthorized, ctx) + return + case errors.As(err, &badReqErr): + nack(w, badReqErr.BecknError(), http.StatusBadRequest, ctx) + return + case errors.As(err, ¬FoundErr): + nack(w, notFoundErr.BecknError(), http.StatusNotFound, ctx) + return + default: + nack(w, internalServerError(ctx), http.StatusInternalServerError, ctx) + return } - - return json.Marshal(response) } diff --git a/pkg/response/response_test.go b/pkg/response/response_test.go index 73dcb6a..7e62aca 100644 --- a/pkg/response/response_test.go +++ b/pkg/response/response_test.go @@ -1,308 +1,133 @@ package response import ( + "bytes" "context" "encoding/json" - "reflect" + "errors" + "net/http" + "net/http/httptest" "testing" + + "github.com/beckn/beckn-onix/pkg/model" ) -func TestNack(t *testing.T) { - ctx := context.Background() +type errorResponseWriter struct{} + +// TODO: Optimize the cases by removing these +func (e *errorResponseWriter) Write([]byte) (int, error) { + return 0, errors.New("write error") +} +func (e *errorResponseWriter) WriteHeader(statusCode int) {} + +func (e *errorResponseWriter) Header() http.Header { + return http.Header{} +} + +func TestSendAck(t *testing.T) { + _, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) // For tests + } + rr := httptest.NewRecorder() + + SendAck(rr) + + if rr.Code != http.StatusOK { + t.Errorf("wanted status code %d, got %d", http.StatusOK, rr.Code) + } + + expected := `{"message":{"ack":{"status":"ACK"}}}` + if rr.Body.String() != expected { + t.Errorf("err.Error() = %s, want %s", + rr.Body.String(), expected) + + } +} + +func TestSendNack(t *testing.T) { + ctx := context.WithValue(context.Background(), model.MsgIDKey, "123456") tests := []struct { - name string - errorType ErrorType - requestBody string - wantStatus string - wantErrCode string - wantErrMsg string - wantErr bool - path string + name string + err error + expected string + status int }{ { - name: "Schema validation error", - errorType: SchemaValidationErrorType, - requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`, - wantStatus: "NACK", - wantErrCode: "400", - wantErrMsg: "Schema validation failed", - wantErr: false, - path: "test", + name: "SchemaValidationErr", + err: &model.SchemaValidationErr{ + Errors: []model.Error{ + {Paths: "/path1", Message: "Error 1"}, + {Paths: "/path2", Message: "Error 2"}, + }, + }, + status: http.StatusBadRequest, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Bad Request","paths":"/path1;/path2","message":"Error 1; Error 2"}}}`, }, { - name: "Invalid request error", - errorType: InvalidRequestErrorType, - requestBody: `{"context": {"domain": "test-domain"}}`, - wantStatus: "NACK", - wantErrCode: "401", - wantErrMsg: "Invalid request format", - wantErr: false, - path: "test", + name: "SignValidationErr", + err: model.NewSignValidationErr(errors.New("signature invalid")), + status: http.StatusUnauthorized, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Unauthorized","message":"Signature Validation Error: signature invalid"}}}`, }, { - name: "Unknown error type", - errorType: "UNKNOWN_ERROR", - requestBody: `{"context": {"domain": "test-domain"}}`, - wantStatus: "NACK", - wantErrCode: "500", - wantErrMsg: "Internal server error", - wantErr: false, - path: "test", + name: "BadReqErr", + err: model.NewBadReqErr(errors.New("bad request error")), + status: http.StatusBadRequest, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Bad Request","message":"BAD Request: bad request error"}}}`, }, { - name: "Empty request body", - errorType: SchemaValidationErrorType, - requestBody: `{}`, - wantStatus: "NACK", - wantErrCode: "400", - wantErrMsg: "Schema validation failed", - wantErr: false, - path: "test", + name: "NotFoundErr", + err: model.NewNotFoundErr(errors.New("endpoint not found")), + status: http.StatusNotFound, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Not Found","message":"Endpoint not found: endpoint not found"}}}`, }, { - name: "Invalid JSON", - errorType: SchemaValidationErrorType, - requestBody: `{invalid json}`, - wantErr: true, - path: "test", - }, - { - name: "Complex nested context", - errorType: SchemaValidationErrorType, - requestBody: `{"context": {"domain": "test-domain", "nested": {"key1": "value1", "key2": 123}}}`, - wantStatus: "NACK", - wantErrCode: "400", - wantErrMsg: "Schema validation failed", - wantErr: false, - path: "test", + name: "InternalServerError", + err: errors.New("unexpected error"), + status: http.StatusInternalServerError, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"Internal Server Error","message":"Internal server error, MessageID: 123456"}}}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp, err := Nack(ctx, tt.errorType, tt.path, []byte(tt.requestBody)) + _, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) // For tests + } + rr := httptest.NewRecorder() - if (err != nil) != tt.wantErr { - t.Errorf("Nack() error = %v, wantErr %v", err, tt.wantErr) - return + SendNack(ctx, rr, tt.err) + + if rr.Code != tt.status { + t.Errorf("wanted status code %d, got %d", tt.status, rr.Code) } - if tt.wantErr && err != nil { - return + var actual map[string]interface{} + err = json.Unmarshal(rr.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal response: %v", err) } - var becknResp BecknResponse - if err := json.Unmarshal(resp, &becknResp); err != nil { - t.Errorf("Failed to unmarshal response: %v", err) - return + var expected map[string]interface{} + err = json.Unmarshal([]byte(tt.expected), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) } - if becknResp.Message.Ack.Status != tt.wantStatus { - t.Errorf("Nack() status = %v, want %v", becknResp.Message.Ack.Status, tt.wantStatus) + if !compareJSON(expected, actual) { + t.Errorf("err.Error() = %s, want %s", + actual, expected) } - if becknResp.Message.Error.Code != tt.wantErrCode { - t.Errorf("Nack() error code = %v, want %v", becknResp.Message.Error.Code, tt.wantErrCode) - } - - if becknResp.Message.Error.Message != tt.wantErrMsg { - t.Errorf("Nack() error message = %v, want %v", becknResp.Message.Error.Message, tt.wantErrMsg) - } - - var origReq BecknRequest - if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil { - if !compareContexts(becknResp.Context, origReq.Context) { - t.Errorf("Nack() context not preserved, got = %v, want %v", becknResp.Context, origReq.Context) - } - } }) } } -func TestAck(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - requestBody string - wantStatus string - wantErr bool - }{ - { - name: "Valid request", - requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`, - wantStatus: "ACK", - wantErr: false, - }, - { - name: "Empty context", - requestBody: `{"context": {}}`, - wantStatus: "ACK", - wantErr: false, - }, - { - name: "Invalid JSON", - requestBody: `{invalid json}`, - wantErr: true, - }, - { - name: "Complex nested context", - requestBody: `{"context": {"domain": "test-domain", "nested": {"key1": "value1", "key2": 123, "array": [1,2,3]}}}`, - wantStatus: "ACK", - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resp, err := Ack(ctx, []byte(tt.requestBody)) - - if (err != nil) != tt.wantErr { - t.Errorf("Ack() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.wantErr && err != nil { - return - } - - var becknResp BecknResponse - if err := json.Unmarshal(resp, &becknResp); err != nil { - t.Errorf("Failed to unmarshal response: %v", err) - return - } - - if becknResp.Message.Ack.Status != tt.wantStatus { - t.Errorf("Ack() status = %v, want %v", becknResp.Message.Ack.Status, tt.wantStatus) - } - - if becknResp.Message.Error != nil { - t.Errorf("Ack() should not have error, got %v", becknResp.Message.Error) - } - - var origReq BecknRequest - if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil { - if !compareContexts(becknResp.Context, origReq.Context) { - t.Errorf("Ack() context not preserved, got = %v, want %v", becknResp.Context, origReq.Context) - } - } - }) - } -} - -func TestHandleClientFailure(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - errorType ErrorType - requestBody string - wantErrCode string - wantErrMsg string - wantErr bool - }{ - { - name: "Schema validation error", - errorType: SchemaValidationErrorType, - requestBody: `{"context": {"domain": "test-domain", "location": "test-location"}}`, - wantErrCode: "400", - wantErrMsg: "Schema validation failed", - wantErr: false, - }, - { - name: "Invalid request error", - errorType: InvalidRequestErrorType, - requestBody: `{"context": {"domain": "test-domain"}}`, - wantErrCode: "401", - wantErrMsg: "Invalid request format", - wantErr: false, - }, - { - name: "Unknown error type", - errorType: "UNKNOWN_ERROR", - requestBody: `{"context": {"domain": "test-domain"}}`, - wantErrCode: "500", - wantErrMsg: "Internal server error", - wantErr: false, - }, - { - name: "Invalid JSON", - errorType: SchemaValidationErrorType, - requestBody: `{invalid json}`, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resp, err := HandleClientFailure(ctx, tt.errorType, []byte(tt.requestBody)) - - if (err != nil) != tt.wantErr { - t.Errorf("HandleClientFailure() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.wantErr && err != nil { - return - } - - var failureResp ClientFailureBecknResponse - if err := json.Unmarshal(resp, &failureResp); err != nil { - t.Errorf("Failed to unmarshal response: %v", err) - return - } - - if failureResp.Error.Code != tt.wantErrCode { - t.Errorf("HandleClientFailure() error code = %v, want %v", failureResp.Error.Code, tt.wantErrCode) - } - - if failureResp.Error.Message != tt.wantErrMsg { - t.Errorf("HandleClientFailure() error message = %v, want %v", failureResp.Error.Message, tt.wantErrMsg) - } - - var origReq BecknRequest - if err := json.Unmarshal([]byte(tt.requestBody), &origReq); err == nil { - if !compareContexts(failureResp.Context, origReq.Context) { - t.Errorf("HandleClientFailure() context not preserved, got = %v, want %v", failureResp.Context, origReq.Context) - } - } - }) - } -} - -func TestErrorMap(t *testing.T) { - - expectedTypes := []ErrorType{ - SchemaValidationErrorType, - InvalidRequestErrorType, - } - - for _, tp := range expectedTypes { - if _, exists := errorMap[tp]; !exists { - t.Errorf("ErrorType %v not found in errorMap", tp) - } - } - - if DefaultError.Code != "500" || DefaultError.Message != "Internal server error" { - t.Errorf("DefaultError not set correctly, got code=%v, message=%v", DefaultError.Code, DefaultError.Message) - } -} - -func compareContexts(c1, c2 map[string]interface{}) bool { - - if c1 == nil && c2 == nil { - return true - } - - if c1 == nil && len(c2) == 0 || c2 == nil && len(c1) == 0 { - return true - } - - return reflect.DeepEqual(c1, c2) -} - func TestSchemaValidationErr_Error(t *testing.T) { + // Create sample validation errors validationErrors := []Error{ {Paths: "name", Message: "Name is required"}, {Paths: "email", Message: "Invalid email format"}, @@ -315,3 +140,117 @@ func TestSchemaValidationErr_Error(t *testing.T) { } } + +func compareJSON(expected, actual map[string]interface{}) bool { + expectedBytes, _ := json.Marshal(expected) + actualBytes, _ := json.Marshal(actual) + return bytes.Equal(expectedBytes, actualBytes) +} + +func TestSendAck_WriteError(t *testing.T) { + w := &errorResponseWriter{} + SendAck(w) +} + +// Mock struct to force JSON marshalling error +type badMessage struct{} + +func (b *badMessage) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshal error") +} + +func TestNack_1(t *testing.T) { + tests := []struct { + name string + err *model.Error + status int + expected string + useBadJSON bool + useBadWrite bool + }{ + { + name: "Schema Validation Error", + err: &model.Error{ + Code: "BAD_REQUEST", + Paths: "/test/path", + Message: "Invalid schema", + }, + status: http.StatusBadRequest, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"BAD_REQUEST","paths":"/test/path","message":"Invalid schema"}}}`, + }, + { + name: "Internal Server Error", + err: &model.Error{ + Code: "INTERNAL_SERVER_ERROR", + Message: "Something went wrong", + }, + status: http.StatusInternalServerError, + expected: `{"message":{"ack":{"status":"NACK"},"error":{"code":"INTERNAL_SERVER_ERROR","message":"Something went wrong"}}}`, + }, + { + name: "JSON Marshal Error", + err: nil, // This will be overridden to cause marshaling error + status: http.StatusInternalServerError, + expected: `Internal server error, MessageID: 12345`, + useBadJSON: true, + }, + { + name: "Write Error", + err: &model.Error{ + Code: "WRITE_ERROR", + Message: "Failed to write response", + }, + status: http.StatusInternalServerError, + expected: `Internal server error, MessageID: 12345`, + useBadWrite: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(req.Context(), model.MsgIDKey, "12345") + + var w http.ResponseWriter + if tt.useBadWrite { + w = &errorResponseWriter{} // Simulate write error + } else { + w = httptest.NewRecorder() + } + + // TODO: Fix this approach , should not be used like this. + if tt.useBadJSON { + data, _ := json.Marshal(&badMessage{}) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.status) + _, err := w.Write(data) + if err != nil { + http.Error(w, "Failed to write response", http.StatusInternalServerError) + return + } + return + } + + nack(w, tt.err, tt.status, ctx) + if !tt.useBadWrite { + recorder, ok := w.(*httptest.ResponseRecorder) + if !ok { + t.Fatal("Failed to cast response recorder") + } + + if recorder.Code != tt.status { + t.Errorf("wanted status code %d, got %d", tt.status, recorder.Code) + } + + body := recorder.Body.String() + if body != tt.expected { + t.Errorf("err.Error() = %s, want %s", + body, tt.expected) + } + } + }) + } +}