feat: enhance network membership validation and add tests for extractStringSlice function

This commit is contained in:
Nirmal N R
2026-04-01 18:54:31 +05:30
parent d2d211031b
commit db330663bd
2 changed files with 76 additions and 3 deletions

View File

@@ -165,7 +165,7 @@ func (c *DeDiRegistryClient) Lookup(ctx context.Context, req *model.Subscription
detailsSubscriberID, _ := details["subscriber_id"].(string) detailsSubscriberID, _ := details["subscriber_id"].(string)
// Validate network memberships if configured. // Validate network memberships if configured.
networkMemberships := extractStringSlice(data["network_memberships"]) networkMemberships := extractStringSlice(ctx, "network_memberships", data["network_memberships"])
if len(c.config.AllowedNetworkIDs) > 0 { if len(c.config.AllowedNetworkIDs) > 0 {
if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) { if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) {
return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID) return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID)
@@ -210,7 +210,7 @@ func parseTime(timeStr string) time.Time {
return parsedTime return parsedTime
} }
func extractStringSlice(value interface{}) []string { func extractStringSlice(ctx context.Context, fieldName string, value interface{}) []string {
if value == nil { if value == nil {
return nil return nil
} }
@@ -219,9 +219,10 @@ func extractStringSlice(value interface{}) []string {
return v return v
case []interface{}: case []interface{}:
out := make([]string, 0, len(v)) out := make([]string, 0, len(v))
for _, item := range v { for i, item := range v {
str, ok := item.(string) str, ok := item.(string)
if !ok { if !ok {
log.Warnf(ctx, "Ignoring invalid %s entry at index %d during registry lookup: expected a string network ID, got %T. This entry will not be considered for allowlist validation.", fieldName, i, item)
continue continue
} }
if str != "" { if str != "" {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"time" "time"
@@ -111,6 +112,33 @@ func TestNew(t *testing.T) {
}) })
} }
func TestExtractStringSlice(t *testing.T) {
ctx := context.Background()
t.Run("returns strings from []string", func(t *testing.T) {
got := extractStringSlice(ctx, "network_memberships", []string{"commerce-network.org/prod", "local-commerce.org/production"})
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("expected %v, got %v", want, got)
}
})
t.Run("filters non-string entries from []interface{}", func(t *testing.T) {
got := extractStringSlice(ctx, "network_memberships", []interface{}{"commerce-network.org/prod", 42, true, "", "local-commerce.org/production"})
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("expected %v, got %v", want, got)
}
})
t.Run("returns nil for unsupported type", func(t *testing.T) {
got := extractStringSlice(ctx, "network_memberships", "commerce-network.org/prod")
if got != nil {
t.Fatalf("expected nil, got %v", got)
}
})
}
func TestLookup(t *testing.T) { func TestLookup(t *testing.T) {
ctx := context.Background() ctx := context.Background()
@@ -284,6 +312,50 @@ func TestLookup(t *testing.T) {
} }
}) })
t.Run("allowed network IDs match with mixed network membership types", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := map[string]interface{}{
"message": "Record retrieved from registry cache",
"data": map[string]interface{}{
"details": map[string]interface{}{
"url": "http://dev.np2.com/beckn/bap",
"type": "BAP",
"domain": "energy",
"subscriber_id": "dev.np2.com",
"signing_public_key": "384qqkIIpxo71WaJPsWqQNWUDGAFnfnJPxuDmtuBiLo=",
},
"network_memberships": []interface{}{123, "commerce-network.org/prod", map[string]interface{}{"invalid": true}},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
config := &Config{
URL: server.URL + "/dedi",
RegistryName: "subscribers.beckn.one",
AllowedNetworkIDs: []string{"commerce-network.org/prod"},
}
client, closer, err := New(ctx, config)
if err != nil {
t.Fatalf("New() error = %v", err)
}
defer closer()
req := &model.Subscription{
Subscriber: model.Subscriber{
SubscriberID: "dev.np2.com",
},
KeyID: "test-key-id",
}
_, err = client.Lookup(ctx, req)
if err != nil {
t.Errorf("Lookup() error = %v", err)
}
})
// Test empty subscriber ID // Test empty subscriber ID
t.Run("empty subscriber ID", func(t *testing.T) { t.Run("empty subscriber ID", func(t *testing.T) {
config := &Config{ config := &Config{