From db330663bde2b088437106906c50c67fa9d98e12 Mon Sep 17 00:00:00 2001 From: Nirmal N R Date: Wed, 1 Apr 2026 18:54:31 +0530 Subject: [PATCH] feat: enhance network membership validation and add tests for extractStringSlice function --- .../dediregistry/dediregistry.go | 7 +- .../dediregistry/dediregistry_test.go | 72 +++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/pkg/plugin/implementation/dediregistry/dediregistry.go b/pkg/plugin/implementation/dediregistry/dediregistry.go index 1cecb2b..787e53a 100644 --- a/pkg/plugin/implementation/dediregistry/dediregistry.go +++ b/pkg/plugin/implementation/dediregistry/dediregistry.go @@ -165,7 +165,7 @@ func (c *DeDiRegistryClient) Lookup(ctx context.Context, req *model.Subscription detailsSubscriberID, _ := details["subscriber_id"].(string) // Validate network memberships if configured. - networkMemberships := extractStringSlice(data["network_memberships"]) + networkMemberships := extractStringSlice(ctx, "network_memberships", data["network_memberships"]) if len(c.config.AllowedNetworkIDs) > 0 { if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) { return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID) @@ -210,7 +210,7 @@ func parseTime(timeStr string) time.Time { return parsedTime } -func extractStringSlice(value interface{}) []string { +func extractStringSlice(ctx context.Context, fieldName string, value interface{}) []string { if value == nil { return nil } @@ -219,9 +219,10 @@ func extractStringSlice(value interface{}) []string { return v case []interface{}: out := make([]string, 0, len(v)) - for _, item := range v { + for i, item := range v { str, ok := item.(string) if !ok { + log.Warnf(ctx, "Ignoring invalid %s entry at index %d during registry lookup: expected a string network ID, got %T. This entry will not be considered for allowlist validation.", fieldName, i, item) continue } if str != "" { diff --git a/pkg/plugin/implementation/dediregistry/dediregistry_test.go b/pkg/plugin/implementation/dediregistry/dediregistry_test.go index 4f6a0cf..6633298 100644 --- a/pkg/plugin/implementation/dediregistry/dediregistry_test.go +++ b/pkg/plugin/implementation/dediregistry/dediregistry_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" "time" @@ -111,6 +112,33 @@ func TestNew(t *testing.T) { }) } +func TestExtractStringSlice(t *testing.T) { + ctx := context.Background() + + t.Run("returns strings from []string", func(t *testing.T) { + got := extractStringSlice(ctx, "network_memberships", []string{"commerce-network.org/prod", "local-commerce.org/production"}) + want := []string{"commerce-network.org/prod", "local-commerce.org/production"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("expected %v, got %v", want, got) + } + }) + + t.Run("filters non-string entries from []interface{}", func(t *testing.T) { + got := extractStringSlice(ctx, "network_memberships", []interface{}{"commerce-network.org/prod", 42, true, "", "local-commerce.org/production"}) + want := []string{"commerce-network.org/prod", "local-commerce.org/production"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("expected %v, got %v", want, got) + } + }) + + t.Run("returns nil for unsupported type", func(t *testing.T) { + got := extractStringSlice(ctx, "network_memberships", "commerce-network.org/prod") + if got != nil { + t.Fatalf("expected nil, got %v", got) + } + }) +} + func TestLookup(t *testing.T) { ctx := context.Background() @@ -284,6 +312,50 @@ func TestLookup(t *testing.T) { } }) + t.Run("allowed network IDs match with mixed network membership types", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "message": "Record retrieved from registry cache", + "data": map[string]interface{}{ + "details": map[string]interface{}{ + "url": "http://dev.np2.com/beckn/bap", + "type": "BAP", + "domain": "energy", + "subscriber_id": "dev.np2.com", + "signing_public_key": "384qqkIIpxo71WaJPsWqQNWUDGAFnfnJPxuDmtuBiLo=", + }, + "network_memberships": []interface{}{123, "commerce-network.org/prod", map[string]interface{}{"invalid": true}}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + config := &Config{ + URL: server.URL + "/dedi", + RegistryName: "subscribers.beckn.one", + AllowedNetworkIDs: []string{"commerce-network.org/prod"}, + } + + client, closer, err := New(ctx, config) + if err != nil { + t.Fatalf("New() error = %v", err) + } + defer closer() + + req := &model.Subscription{ + Subscriber: model.Subscriber{ + SubscriberID: "dev.np2.com", + }, + KeyID: "test-key-id", + } + _, err = client.Lookup(ctx, req) + if err != nil { + t.Errorf("Lookup() error = %v", err) + } + }) + // Test empty subscriber ID t.Run("empty subscriber ID", func(t *testing.T) { config := &Config{