feat: enhance network membership validation and add tests for extractStringSlice function
This commit is contained in:
@@ -165,7 +165,7 @@ func (c *DeDiRegistryClient) Lookup(ctx context.Context, req *model.Subscription
|
||||
detailsSubscriberID, _ := details["subscriber_id"].(string)
|
||||
|
||||
// Validate network memberships if configured.
|
||||
networkMemberships := extractStringSlice(data["network_memberships"])
|
||||
networkMemberships := extractStringSlice(ctx, "network_memberships", data["network_memberships"])
|
||||
if len(c.config.AllowedNetworkIDs) > 0 {
|
||||
if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) {
|
||||
return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID)
|
||||
@@ -210,7 +210,7 @@ func parseTime(timeStr string) time.Time {
|
||||
return parsedTime
|
||||
}
|
||||
|
||||
func extractStringSlice(value interface{}) []string {
|
||||
func extractStringSlice(ctx context.Context, fieldName string, value interface{}) []string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -219,9 +219,10 @@ func extractStringSlice(value interface{}) []string {
|
||||
return v
|
||||
case []interface{}:
|
||||
out := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
for i, item := range v {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
log.Warnf(ctx, "Ignoring invalid %s entry at index %d during registry lookup: expected a string network ID, got %T. This entry will not be considered for allowlist validation.", fieldName, i, item)
|
||||
continue
|
||||
}
|
||||
if str != "" {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -111,6 +112,33 @@ func TestNew(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractStringSlice(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("returns strings from []string", func(t *testing.T) {
|
||||
got := extractStringSlice(ctx, "network_memberships", []string{"commerce-network.org/prod", "local-commerce.org/production"})
|
||||
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("expected %v, got %v", want, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filters non-string entries from []interface{}", func(t *testing.T) {
|
||||
got := extractStringSlice(ctx, "network_memberships", []interface{}{"commerce-network.org/prod", 42, true, "", "local-commerce.org/production"})
|
||||
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("expected %v, got %v", want, got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns nil for unsupported type", func(t *testing.T) {
|
||||
got := extractStringSlice(ctx, "network_memberships", "commerce-network.org/prod")
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil, got %v", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -284,6 +312,50 @@ func TestLookup(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowed network IDs match with mixed network membership types", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := map[string]interface{}{
|
||||
"message": "Record retrieved from registry cache",
|
||||
"data": map[string]interface{}{
|
||||
"details": map[string]interface{}{
|
||||
"url": "http://dev.np2.com/beckn/bap",
|
||||
"type": "BAP",
|
||||
"domain": "energy",
|
||||
"subscriber_id": "dev.np2.com",
|
||||
"signing_public_key": "384qqkIIpxo71WaJPsWqQNWUDGAFnfnJPxuDmtuBiLo=",
|
||||
},
|
||||
"network_memberships": []interface{}{123, "commerce-network.org/prod", map[string]interface{}{"invalid": true}},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
config := &Config{
|
||||
URL: server.URL + "/dedi",
|
||||
RegistryName: "subscribers.beckn.one",
|
||||
AllowedNetworkIDs: []string{"commerce-network.org/prod"},
|
||||
}
|
||||
|
||||
client, closer, err := New(ctx, config)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
defer closer()
|
||||
|
||||
req := &model.Subscription{
|
||||
Subscriber: model.Subscriber{
|
||||
SubscriberID: "dev.np2.com",
|
||||
},
|
||||
KeyID: "test-key-id",
|
||||
}
|
||||
_, err = client.Lookup(ctx, req)
|
||||
if err != nil {
|
||||
t.Errorf("Lookup() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test empty subscriber ID
|
||||
t.Run("empty subscriber ID", func(t *testing.T) {
|
||||
config := &Config{
|
||||
|
||||
Reference in New Issue
Block a user