bug fixes
This commit is contained in:
@@ -63,7 +63,7 @@ func (c *registryClient) Subscribe(ctx context.Context, subscription *model.Subs
|
|||||||
|
|
||||||
// Lookup calls the /lookup endpoint with retry and returns a slice of Subscription.
|
// Lookup calls the /lookup endpoint with retry and returns a slice of Subscription.
|
||||||
func (c *registryClient) Lookup(ctx context.Context, subscription *model.Subscription) ([]model.Subscription, error) {
|
func (c *registryClient) Lookup(ctx context.Context, subscription *model.Subscription) ([]model.Subscription, error) {
|
||||||
lookupURL := fmt.Sprintf("%s/lookUp", c.config.RegisteryURL)
|
lookupURL := fmt.Sprintf("%s/lookup", c.config.RegisteryURL)
|
||||||
|
|
||||||
jsonData, err := json.Marshal(subscription)
|
jsonData, err := json.Marshal(subscription)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
35
core/module/handler/healthcheck.go
Normal file
35
core/module/handler/healthcheck.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthCheckResponse defines the structure for our health check JSON response.
|
||||||
|
type healthCheckResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Service string `json:"service"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthHandler handles requests to the /health endpoint.
|
||||||
|
func HealthHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Ensure the request method is GET.
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
response := healthCheckResponse{
|
||||||
|
Status: "ok",
|
||||||
|
Service: "beckn-adapter",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||||
|
http.Error(w, "Error encoding response", http.StatusInternalServerError)
|
||||||
|
fmt.Printf("Error encoding health check response: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
112
core/module/handler/healthcheck_test.go
Normal file
112
core/module/handler/healthcheck_test.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHealthHandler tests the successful GET request to the /health endpoint.
|
||||||
|
func TestHealthHandler(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "/health", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
HealthHandler(rr, req)
|
||||||
|
|
||||||
|
expContentType := "application/json"
|
||||||
|
expStatus := "ok"
|
||||||
|
expService := "beckn-adapter"
|
||||||
|
|
||||||
|
if status := rr.Code; status != http.StatusOK {
|
||||||
|
t.Fatalf("HealthHandler returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusOK)
|
||||||
|
}
|
||||||
|
if contentType := rr.Header().Get("Content-Type"); contentType != expContentType {
|
||||||
|
t.Errorf("HealthHandler returned wrong Content-Type: got %v want %v",
|
||||||
|
contentType, expContentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response healthCheckResponse
|
||||||
|
err = json.NewDecoder(rr.Body).Decode(&response)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decode response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Status != expStatus {
|
||||||
|
t.Errorf("HealthHandler returned wrong status in JSON: got %v want %v",
|
||||||
|
response.Status, expStatus)
|
||||||
|
}
|
||||||
|
if response.Service != expService {
|
||||||
|
t.Errorf("HealthHandler returned wrong service in JSON: got %v want %v",
|
||||||
|
response.Service, expService)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockResponseWriter is a custom http.ResponseWriter that can simulate an error on Write.
|
||||||
|
type mockResponseWriter struct {
|
||||||
|
httptest.ResponseRecorder
|
||||||
|
writeFail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResponseWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if m.writeFail {
|
||||||
|
m.writeFail = false
|
||||||
|
return 0, fmt.Errorf("simulated write error")
|
||||||
|
}
|
||||||
|
return m.ResponseRecorder.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHealthHandlerErrors tests error scenarios for the HealthHandler.
|
||||||
|
func TestHealthHandlerErrors(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
recorder *mockResponseWriter
|
||||||
|
expStatus int
|
||||||
|
expBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Method Not Allowed",
|
||||||
|
method: http.MethodPost,
|
||||||
|
recorder: &mockResponseWriter{
|
||||||
|
ResponseRecorder: *httptest.NewRecorder(),
|
||||||
|
},
|
||||||
|
expStatus: http.StatusMethodNotAllowed,
|
||||||
|
expBody: "Method not allowed\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JSON Encoding Error",
|
||||||
|
method: http.MethodGet,
|
||||||
|
recorder: &mockResponseWriter{
|
||||||
|
ResponseRecorder: *httptest.NewRecorder(),
|
||||||
|
writeFail: true,
|
||||||
|
},
|
||||||
|
expStatus: http.StatusInternalServerError,
|
||||||
|
expBody: "Error encoding response\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(tt.method, "/health", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request for %s: %v", tt.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
HealthHandler(tt.recorder, req)
|
||||||
|
|
||||||
|
if status := tt.recorder.Code; status != tt.expStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.expStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
if body := tt.recorder.Body.String(); body != tt.expBody {
|
||||||
|
t.Errorf("handler returned unexpected body: got %q want %q", body, tt.expBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/beckn/beckn-onix/core/module/client"
|
"github.com/beckn/beckn-onix/core/module/client"
|
||||||
"github.com/beckn/beckn-onix/pkg/log"
|
"github.com/beckn/beckn-onix/pkg/log"
|
||||||
@@ -86,9 +85,6 @@ func (h *stdHandler) stepCtx(r *http.Request, rh http.Header) (*model.StepContex
|
|||||||
}
|
}
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
subID := h.subID(r.Context())
|
subID := h.subID(r.Context())
|
||||||
if len(subID) == 0 {
|
|
||||||
return nil, model.NewBadReqErr(fmt.Errorf("subscriberID not set"))
|
|
||||||
}
|
|
||||||
return &model.StepContext{
|
return &model.StepContext{
|
||||||
Context: r.Context(),
|
Context: r.Context(),
|
||||||
Request: r,
|
Request: r,
|
||||||
@@ -116,7 +112,7 @@ func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb de
|
|||||||
switch ctx.Route.TargetType {
|
switch ctx.Route.TargetType {
|
||||||
case "url":
|
case "url":
|
||||||
log.Infof(ctx.Context, "Forwarding request to URL: %s", ctx.Route.URL)
|
log.Infof(ctx.Context, "Forwarding request to URL: %s", ctx.Route.URL)
|
||||||
proxyFunc(r, w, ctx.Route.URL)
|
proxyFunc(ctx, r, w)
|
||||||
return
|
return
|
||||||
case "publisher":
|
case "publisher":
|
||||||
if pb == nil {
|
if pb == nil {
|
||||||
@@ -140,16 +136,18 @@ func route(ctx *model.StepContext, r *http.Request, w http.ResponseWriter, pb de
|
|||||||
}
|
}
|
||||||
response.SendAck(w)
|
response.SendAck(w)
|
||||||
}
|
}
|
||||||
|
func proxy(ctx *model.StepContext, r *http.Request, w http.ResponseWriter) {
|
||||||
// proxy forwards the request to a target URL using a reverse proxy.
|
target := ctx.Route.URL
|
||||||
func proxy(r *http.Request, w http.ResponseWriter, target *url.URL) {
|
|
||||||
r.URL.Scheme = target.Scheme
|
|
||||||
r.URL.Host = target.Host
|
|
||||||
r.URL.Path = target.Path
|
|
||||||
|
|
||||||
r.Header.Set("X-Forwarded-Host", r.Host)
|
r.Header.Set("X-Forwarded-Host", r.Host)
|
||||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
|
||||||
log.Infof(r.Context(), "Proxying request to: %s", target)
|
director := func(req *http.Request) {
|
||||||
|
req.URL = target
|
||||||
|
req.Host = target.Host
|
||||||
|
|
||||||
|
log.Request(req.Context(), req, ctx.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &httputil.ReverseProxy{Director: director}
|
||||||
|
|
||||||
proxy.ServeHTTP(w, r)
|
proxy.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,9 @@ func newSignStep(signer definition.Signer, km definition.KeyManager) (definition
|
|||||||
|
|
||||||
// Run executes the signing step.
|
// Run executes the signing step.
|
||||||
func (s *signStep) Run(ctx *model.StepContext) error {
|
func (s *signStep) Run(ctx *model.StepContext) error {
|
||||||
|
if len(ctx.SubID) == 0 {
|
||||||
|
return model.NewBadReqErr(fmt.Errorf("subscriberID not set"))
|
||||||
|
}
|
||||||
keySet, err := s.km.Keyset(ctx, ctx.SubID)
|
keySet, err := s.km.Keyset(ctx, ctx.SubID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get signing key: %w", err)
|
return fmt.Errorf("failed to get signing key: %w", err)
|
||||||
@@ -43,7 +46,7 @@ func (s *signStep) Run(ctx *model.StepContext) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authHeader := s.generateAuthHeader(ctx.SubID, keySet.UniqueKeyID, createdAt, validTill, sign)
|
authHeader := s.generateAuthHeader(ctx.SubID, keySet.UniqueKeyID, createdAt, validTill, sign)
|
||||||
|
log.Debugf(ctx, "Signature generated: %v", sign)
|
||||||
header := model.AuthHeaderSubscriber
|
header := model.AuthHeaderSubscriber
|
||||||
if ctx.Role == model.RoleGateway {
|
if ctx.Role == model.RoleGateway {
|
||||||
header = model.AuthHeaderGateway
|
header = model.AuthHeaderGateway
|
||||||
@@ -83,11 +86,14 @@ func (s *validateSignStep) Run(ctx *model.StepContext) error {
|
|||||||
unauthHeader := fmt.Sprintf("Signature realm=\"%s\",headers=\"(created) (expires) digest\"", ctx.SubID)
|
unauthHeader := fmt.Sprintf("Signature realm=\"%s\",headers=\"(created) (expires) digest\"", ctx.SubID)
|
||||||
headerValue := ctx.Request.Header.Get(model.AuthHeaderGateway)
|
headerValue := ctx.Request.Header.Get(model.AuthHeaderGateway)
|
||||||
if len(headerValue) != 0 {
|
if len(headerValue) != 0 {
|
||||||
|
log.Debugf(ctx, "Validating %v Header", model.AuthHeaderGateway)
|
||||||
if err := s.validate(ctx, headerValue); err != nil {
|
if err := s.validate(ctx, headerValue); err != nil {
|
||||||
ctx.RespHeader.Set(model.UnaAuthorizedHeaderGateway, unauthHeader)
|
ctx.RespHeader.Set(model.UnaAuthorizedHeaderGateway, unauthHeader)
|
||||||
return model.NewSignValidationErr(fmt.Errorf("failed to validate %s: %w", model.AuthHeaderGateway, err))
|
return model.NewSignValidationErr(fmt.Errorf("failed to validate %s: %w", model.AuthHeaderGateway, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf(ctx, "Validating %v Header", model.AuthHeaderSubscriber)
|
||||||
headerValue = ctx.Request.Header.Get(model.AuthHeaderSubscriber)
|
headerValue = ctx.Request.Header.Get(model.AuthHeaderSubscriber)
|
||||||
if len(headerValue) == 0 {
|
if len(headerValue) == 0 {
|
||||||
ctx.RespHeader.Set(model.UnaAuthorizedHeaderSubscriber, unauthHeader)
|
ctx.RespHeader.Set(model.UnaAuthorizedHeaderSubscriber, unauthHeader)
|
||||||
@@ -102,13 +108,12 @@ func (s *validateSignStep) Run(ctx *model.StepContext) error {
|
|||||||
|
|
||||||
// validate checks the validity of the provided signature header.
|
// validate checks the validity of the provided signature header.
|
||||||
func (s *validateSignStep) validate(ctx *model.StepContext, value string) error {
|
func (s *validateSignStep) validate(ctx *model.StepContext, value string) error {
|
||||||
headerParts := strings.Split(value, "|")
|
headerVals, err := parseHeader(value)
|
||||||
ids := strings.Split(headerParts[0], "\"")
|
if err != nil {
|
||||||
if len(ids) < 2 || len(headerParts) < 3 {
|
return fmt.Errorf("failed to parse header")
|
||||||
return fmt.Errorf("malformed sign header")
|
|
||||||
}
|
}
|
||||||
keyID := headerParts[1]
|
log.Debugf(ctx, "Validating Signature for subscriberID: %v", headerVals.SubscriberID)
|
||||||
signingPublicKey, _, err := s.km.LookupNPKeys(ctx, ctx.SubID, keyID)
|
signingPublicKey, _, err := s.km.LookupNPKeys(ctx, headerVals.SubscriberID, headerVals.UniqueID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get validation key: %w", err)
|
return fmt.Errorf("failed to get validation key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -118,6 +123,45 @@ func (s *validateSignStep) validate(ctx *model.StepContext, value string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParsedKeyID holds the components from the parsed Authorization header's keyId.
|
||||||
|
type authHeader struct {
|
||||||
|
SubscriberID string
|
||||||
|
UniqueID string
|
||||||
|
Algorithm string
|
||||||
|
}
|
||||||
|
|
||||||
|
// keyID extracts subscriber_id and unique_key_id from the Authorization header.
|
||||||
|
// Example keyId format: "{subscriber_id}|{unique_key_id}|{algorithm}"
|
||||||
|
func parseHeader(header string) (*authHeader, error) {
|
||||||
|
// Example: Signature keyId="bpp.example.com|key-1|ed25519",algorithm="ed25519",...
|
||||||
|
keyIDPart := ""
|
||||||
|
// Look for keyId="<value>"
|
||||||
|
const keyIdPrefix = `keyId="`
|
||||||
|
startIndex := strings.Index(header, keyIdPrefix)
|
||||||
|
if startIndex != -1 {
|
||||||
|
startIndex += len(keyIdPrefix)
|
||||||
|
endIndex := strings.Index(header[startIndex:], `"`)
|
||||||
|
if endIndex != -1 {
|
||||||
|
keyIDPart = strings.TrimSpace(header[startIndex : startIndex+endIndex])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if keyIDPart == "" {
|
||||||
|
return nil, fmt.Errorf("keyId parameter not found in Authorization header")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyIDComponents := strings.Split(keyIDPart, "|")
|
||||||
|
if len(keyIDComponents) != 3 {
|
||||||
|
return nil, fmt.Errorf("keyId parameter has incorrect format, expected 3 components separated by '|', got %d for '%s'", len(keyIDComponents), keyIDPart)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &authHeader{
|
||||||
|
SubscriberID: strings.TrimSpace(keyIDComponents[0]),
|
||||||
|
UniqueID: strings.TrimSpace(keyIDComponents[1]),
|
||||||
|
Algorithm: strings.TrimSpace(keyIDComponents[2]),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateSchemaStep represents the schema validation step.
|
// validateSchemaStep represents the schema validation step.
|
||||||
type validateSchemaStep struct {
|
type validateSchemaStep struct {
|
||||||
validator definition.SchemaValidator
|
validator definition.SchemaValidator
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ var handlerProviders = map[handler.Type]Provider{
|
|||||||
// It iterates over the module configurations, retrieves appropriate handler providers,
|
// It iterates over the module configurations, retrieves appropriate handler providers,
|
||||||
// and registers the handlers with the HTTP multiplexer.
|
// and registers the handlers with the HTTP multiplexer.
|
||||||
func Register(ctx context.Context, mCfgs []Config, mux *http.ServeMux, mgr handler.PluginManager) error {
|
func Register(ctx context.Context, mCfgs []Config, mux *http.ServeMux, mgr handler.PluginManager) error {
|
||||||
|
mux.Handle("/health", http.HandlerFunc(handler.HealthHandler))
|
||||||
|
|
||||||
log.Debugf(ctx, "Registering modules with config: %#v", mCfgs)
|
log.Debugf(ctx, "Registering modules with config: %#v", mCfgs)
|
||||||
// Iterate over the handlers in the configuration.
|
// Iterate over the handlers in the configuration.
|
||||||
for _, c := range mCfgs {
|
for _, c := range mCfgs {
|
||||||
|
|||||||
@@ -118,7 +118,15 @@ func TestRegisterSuccess(t *testing.T) {
|
|||||||
if capturedModuleName != "test-module" {
|
if capturedModuleName != "test-module" {
|
||||||
t.Errorf("expected module_id in context to be 'test-module', got %v", capturedModuleName)
|
t.Errorf("expected module_id in context to be 'test-module', got %v", capturedModuleName)
|
||||||
}
|
}
|
||||||
|
// Verifying /health endpoint registration
|
||||||
|
reqHealth := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||||
|
recHealth := httptest.NewRecorder()
|
||||||
|
mux.ServeHTTP(recHealth, reqHealth)
|
||||||
|
|
||||||
|
if status := recHealth.Code; status != http.StatusOK {
|
||||||
|
t.Errorf("handler for /health returned wrong status code: got %v want %v",
|
||||||
|
status, http.StatusOK)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRegisterFailure tests scenarios where the handler registration should fail.
|
// TestRegisterFailure tests scenarios where the handler registration should fail.
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ type Route struct {
|
|||||||
|
|
||||||
// Keyset represents a collection of cryptographic keys used for signing and encryption.
|
// Keyset represents a collection of cryptographic keys used for signing and encryption.
|
||||||
type Keyset struct {
|
type Keyset struct {
|
||||||
|
SubscriberID string
|
||||||
UniqueKeyID string // UniqueKeyID is the identifier for the key pair.
|
UniqueKeyID string // UniqueKeyID is the identifier for the key pair.
|
||||||
SigningPrivate string // SigningPrivate is the private key used for signing operations.
|
SigningPrivate string // SigningPrivate is the private key used for signing operations.
|
||||||
SigningPublic string // SigningPublic is the public key corresponding to the signing private key.
|
SigningPublic string // SigningPublic is the public key corresponding to the signing private key.
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ func (r *Router) loadRules(configPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid URL in rule: %w", err)
|
return fmt.Errorf("invalid URL in rule: %w", err)
|
||||||
}
|
}
|
||||||
|
parsedURL.Path = joinPath(parsedURL, endpoint)
|
||||||
route = &model.Route{
|
route = &model.Route{
|
||||||
TargetType: rule.TargetType,
|
TargetType: rule.TargetType,
|
||||||
URL: parsedURL,
|
URL: parsedURL,
|
||||||
@@ -126,6 +127,7 @@ func (r *Router) loadRules(configPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid URL in rule: %w", err)
|
return fmt.Errorf("invalid URL in rule: %w", err)
|
||||||
}
|
}
|
||||||
|
parsedURL.Path = joinPath(parsedURL, endpoint)
|
||||||
}
|
}
|
||||||
route = &model.Route{
|
route = &model.Route{
|
||||||
TargetType: rule.TargetType,
|
TargetType: rule.TargetType,
|
||||||
@@ -227,24 +229,23 @@ func handleProtocolMapping(route *model.Route, npURI, endpoint string) (*model.R
|
|||||||
}
|
}
|
||||||
return &model.Route{
|
return &model.Route{
|
||||||
TargetType: targetTypeURL,
|
TargetType: targetTypeURL,
|
||||||
URL: &url.URL{
|
URL: route.URL,
|
||||||
Scheme: route.URL.Scheme,
|
|
||||||
Host: route.URL.Host,
|
|
||||||
Path: path.Join(route.URL.Path, endpoint),
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
targetURL, err := url.Parse(target)
|
targetURL, err := url.Parse(target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), target, endpoint, err)
|
return nil, fmt.Errorf("invalid %s URI - %s in request body for %s: %w", strings.ToUpper(route.TargetType), target, endpoint, err)
|
||||||
}
|
}
|
||||||
|
targetURL.Path = joinPath(targetURL, endpoint)
|
||||||
return &model.Route{
|
return &model.Route{
|
||||||
TargetType: targetTypeURL,
|
TargetType: targetTypeURL,
|
||||||
URL: &url.URL{
|
URL: targetURL,
|
||||||
Scheme: targetURL.Scheme,
|
|
||||||
Host: targetURL.Host,
|
|
||||||
Path: path.Join(targetURL.Path, endpoint),
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func joinPath(u *url.URL, endpoint string) string {
|
||||||
|
if u.Path == "" {
|
||||||
|
u.Path = "/"
|
||||||
|
}
|
||||||
|
return path.Join(u.Path, endpoint)
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/beckn/beckn-onix/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed testData/*
|
//go:embed testData/*
|
||||||
@@ -47,32 +50,36 @@ func setupRouter(t *testing.T, configFile string) (*Router, func() error, string
|
|||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// List of YAML files in the testData directory
|
validConfigFile := "bap_caller.yaml"
|
||||||
yamlFiles := []string{
|
rulesFilePath := setupTestConfig(t, validConfigFile)
|
||||||
"bap_caller.yaml",
|
|
||||||
"bap_receiver.yaml",
|
|
||||||
"bpp_caller.yaml",
|
|
||||||
"bpp_receiver.yaml",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, yamlFile := range yamlFiles {
|
|
||||||
t.Run(yamlFile, func(t *testing.T) {
|
|
||||||
rulesFilePath := setupTestConfig(t, yamlFile)
|
|
||||||
defer os.RemoveAll(filepath.Dir(rulesFilePath))
|
defer os.RemoveAll(filepath.Dir(rulesFilePath))
|
||||||
|
|
||||||
// Define test cases
|
config := &Config{
|
||||||
|
RoutingConfig: rulesFilePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
router, _, err := New(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("New(%v) = %v, want nil error", config, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if router == nil {
|
||||||
|
t.Errorf("New(%v) = nil router, want non-nil", config)
|
||||||
|
}
|
||||||
|
if len(router.rules) == 0 {
|
||||||
|
t.Error("Expected router to have loaded rules, but rules map is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewErrors tests the New function for failure cases.
|
||||||
|
func TestNewErrors(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
config *Config
|
config *Config
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
|
||||||
name: "Valid configuration",
|
|
||||||
config: &Config{
|
|
||||||
RoutingConfig: rulesFilePath,
|
|
||||||
},
|
|
||||||
wantErr: "",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Empty config",
|
name: "Empty config",
|
||||||
config: nil,
|
config: nil,
|
||||||
@@ -85,39 +92,100 @@ func TestNew(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: "routingConfig path is empty",
|
wantErr: "routingConfig path is empty",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Routing config file does not exist",
|
|
||||||
config: &Config{
|
|
||||||
RoutingConfig: "/nonexistent/path/to/rules.yaml",
|
|
||||||
},
|
|
||||||
wantErr: "error reading config file",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
router, _, err := New(ctx, tt.config)
|
router, _, err := New(ctx, tt.config)
|
||||||
|
|
||||||
// Check for expected error
|
|
||||||
if tt.wantErr != "" {
|
|
||||||
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
t.Errorf("New(%v) = %v, want error containing %q", tt.config, err, tt.wantErr)
|
t.Errorf("New(%v) = %v, want error containing %q", tt.config, err, tt.wantErr)
|
||||||
}
|
}
|
||||||
return
|
if router != nil {
|
||||||
}
|
t.Errorf("New(%v) = %v, want nil router on error", tt.config, router)
|
||||||
|
|
||||||
// Ensure no error occurred
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("New(%v) = %v, want nil error", tt.config, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the router and close function are not nil
|
|
||||||
if router == nil {
|
|
||||||
t.Errorf("New(%v, %v) = nil router, want non-nil", ctx, tt.config)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRules tests the loadRules function for successful loading and map construction.
|
||||||
|
func TestLoadRules(t *testing.T) {
|
||||||
|
router := &Router{
|
||||||
|
rules: make(map[string]map[string]map[string]*model.Route),
|
||||||
|
}
|
||||||
|
rulesFilePath := setupTestConfig(t, "valid_all_routes.yaml")
|
||||||
|
defer os.RemoveAll(filepath.Dir(rulesFilePath))
|
||||||
|
|
||||||
|
err := router.loadRules(rulesFilePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("loadRules() err = %v, want nil", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected router.rules map structure based on the yaml.
|
||||||
|
expectedRules := map[string]map[string]map[string]*model.Route{
|
||||||
|
"ONDC:TRV10": {
|
||||||
|
"2.0.0": {
|
||||||
|
"search": {TargetType: targetTypeURL, URL: parseURL(t, "https://mock_gateway.com/v2/ondc/search")},
|
||||||
|
"init": {TargetType: targetTypeBAP, URL: parseURL(t, "https://mock_bpp.com/v2/ondc/init")},
|
||||||
|
"select": {TargetType: targetTypeBAP, URL: parseURL(t, "https://mock_bpp.com/v2/ondc/select")},
|
||||||
|
"on_search": {TargetType: targetTypeBAP, URL: parseURL(t, "https://mock_bap_gateway.com/v2/ondc/on_search")},
|
||||||
|
"confirm": {TargetType: targetTypePublisher, PublisherID: "beckn_onix_topic", URL: nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(router.rules, expectedRules) {
|
||||||
|
t.Errorf("Loaded rules mismatch.\nGot:\n%#v\nWant:\n%#v", router.rules, expectedRules)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustParseURL is a helper for TestLoadRules to parse URLs.
|
||||||
|
func parseURL(t *testing.T, rawURL string) *url.URL {
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse URL %s: %v", rawURL, err)
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoadRulesErrors tests the loadRules function for various error cases.
|
||||||
|
func TestLoadRulesErrors(t *testing.T) {
|
||||||
|
router := &Router{
|
||||||
|
rules: make(map[string]map[string]map[string]*model.Route),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
configPath string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty routing config path",
|
||||||
|
configPath: "",
|
||||||
|
wantErr: "routingConfig path is empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Routing config file does not exist",
|
||||||
|
configPath: "/nonexistent/path/to/rules.yaml",
|
||||||
|
wantErr: "error reading config file",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid YAML (Unmarshal error)",
|
||||||
|
configPath: setupTestConfig(t, "invalid_yaml.yaml"),
|
||||||
|
wantErr: "error parsing YAML",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if !strings.Contains(tt.configPath, "/nonexistent/") && tt.configPath != "" {
|
||||||
|
defer os.RemoveAll(filepath.Dir(tt.configPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
err := router.loadRules(tt.configPath)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
|
||||||
|
t.Errorf("loadRules(%q) = %v, want error containing %q", tt.configPath, err, tt.wantErr)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
key: value: invalid
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
# testData/all_route_types.yaml
|
||||||
|
routingRules:
|
||||||
|
- domain: ONDC:TRV10
|
||||||
|
version: 2.0.0
|
||||||
|
targetType: url
|
||||||
|
target:
|
||||||
|
url: https://mock_gateway.com/v2/ondc
|
||||||
|
endpoints:
|
||||||
|
- search
|
||||||
|
- domain: ONDC:TRV10
|
||||||
|
version: 2.0.0
|
||||||
|
targetType: bap
|
||||||
|
target:
|
||||||
|
url: https://mock_bpp.com/v2/ondc
|
||||||
|
endpoints:
|
||||||
|
- init
|
||||||
|
- select
|
||||||
|
- domain: ONDC:TRV10
|
||||||
|
version: 2.0.0
|
||||||
|
targetType: publisher
|
||||||
|
target:
|
||||||
|
publisherId: beckn_onix_topic
|
||||||
|
endpoints:
|
||||||
|
- confirm
|
||||||
|
- domain: ONDC:TRV10
|
||||||
|
version: 2.0.0
|
||||||
|
targetType: bap
|
||||||
|
target:
|
||||||
|
url: https://mock_bap_gateway.com/v2/ondc
|
||||||
|
endpoints:
|
||||||
|
- on_search
|
||||||
Reference in New Issue
Block a user