feat: enhance network membership validation and add tests for extractStringSlice function
This commit is contained in:
@@ -165,7 +165,7 @@ func (c *DeDiRegistryClient) Lookup(ctx context.Context, req *model.Subscription
|
|||||||
detailsSubscriberID, _ := details["subscriber_id"].(string)
|
detailsSubscriberID, _ := details["subscriber_id"].(string)
|
||||||
|
|
||||||
// Validate network memberships if configured.
|
// Validate network memberships if configured.
|
||||||
networkMemberships := extractStringSlice(data["network_memberships"])
|
networkMemberships := extractStringSlice(ctx, "network_memberships", data["network_memberships"])
|
||||||
if len(c.config.AllowedNetworkIDs) > 0 {
|
if len(c.config.AllowedNetworkIDs) > 0 {
|
||||||
if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) {
|
if len(networkMemberships) == 0 || !containsAny(networkMemberships, c.config.AllowedNetworkIDs) {
|
||||||
return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID)
|
return nil, fmt.Errorf("registry entry with subscriber_id '%s' does not belong to any configured networks (registry.config.allowedNetworkIDs)", detailsSubscriberID)
|
||||||
@@ -210,7 +210,7 @@ func parseTime(timeStr string) time.Time {
|
|||||||
return parsedTime
|
return parsedTime
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractStringSlice(value interface{}) []string {
|
func extractStringSlice(ctx context.Context, fieldName string, value interface{}) []string {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -219,9 +219,10 @@ func extractStringSlice(value interface{}) []string {
|
|||||||
return v
|
return v
|
||||||
case []interface{}:
|
case []interface{}:
|
||||||
out := make([]string, 0, len(v))
|
out := make([]string, 0, len(v))
|
||||||
for _, item := range v {
|
for i, item := range v {
|
||||||
str, ok := item.(string)
|
str, ok := item.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
log.Warnf(ctx, "Ignoring invalid %s entry at index %d during registry lookup: expected a string network ID, got %T. This entry will not be considered for allowlist validation.", fieldName, i, item)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if str != "" {
|
if str != "" {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -111,6 +112,33 @@ func TestNew(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractStringSlice(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("returns strings from []string", func(t *testing.T) {
|
||||||
|
got := extractStringSlice(ctx, "network_memberships", []string{"commerce-network.org/prod", "local-commerce.org/production"})
|
||||||
|
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("expected %v, got %v", want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filters non-string entries from []interface{}", func(t *testing.T) {
|
||||||
|
got := extractStringSlice(ctx, "network_memberships", []interface{}{"commerce-network.org/prod", 42, true, "", "local-commerce.org/production"})
|
||||||
|
want := []string{"commerce-network.org/prod", "local-commerce.org/production"}
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("expected %v, got %v", want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns nil for unsupported type", func(t *testing.T) {
|
||||||
|
got := extractStringSlice(ctx, "network_memberships", "commerce-network.org/prod")
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("expected nil, got %v", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestLookup(t *testing.T) {
|
func TestLookup(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@@ -284,6 +312,50 @@ func TestLookup(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("allowed network IDs match with mixed network membership types", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"message": "Record retrieved from registry cache",
|
||||||
|
"data": map[string]interface{}{
|
||||||
|
"details": map[string]interface{}{
|
||||||
|
"url": "http://dev.np2.com/beckn/bap",
|
||||||
|
"type": "BAP",
|
||||||
|
"domain": "energy",
|
||||||
|
"subscriber_id": "dev.np2.com",
|
||||||
|
"signing_public_key": "384qqkIIpxo71WaJPsWqQNWUDGAFnfnJPxuDmtuBiLo=",
|
||||||
|
},
|
||||||
|
"network_memberships": []interface{}{123, "commerce-network.org/prod", map[string]interface{}{"invalid": true}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(response)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
URL: server.URL + "/dedi",
|
||||||
|
RegistryName: "subscribers.beckn.one",
|
||||||
|
AllowedNetworkIDs: []string{"commerce-network.org/prod"},
|
||||||
|
}
|
||||||
|
|
||||||
|
client, closer, err := New(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error = %v", err)
|
||||||
|
}
|
||||||
|
defer closer()
|
||||||
|
|
||||||
|
req := &model.Subscription{
|
||||||
|
Subscriber: model.Subscriber{
|
||||||
|
SubscriberID: "dev.np2.com",
|
||||||
|
},
|
||||||
|
KeyID: "test-key-id",
|
||||||
|
}
|
||||||
|
_, err = client.Lookup(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Lookup() error = %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Test empty subscriber ID
|
// Test empty subscriber ID
|
||||||
t.Run("empty subscriber ID", func(t *testing.T) {
|
t.Run("empty subscriber ID", func(t *testing.T) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
|
|||||||
Reference in New Issue
Block a user