updated as per the review comments
This commit is contained in:
24
pkg/plugin/implementation/signvalidator/cmd/plugin.go
Normal file
24
pkg/plugin/implementation/signvalidator/cmd/plugin.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/definition"
|
||||
"github.com/beckn/beckn-onix/pkg/plugin/implementation/signvalidator"
|
||||
)
|
||||
|
||||
// provider provides instances of Verifier.
|
||||
type provider struct{}
|
||||
|
||||
// New initializes a new Verifier instance.
|
||||
func (vp provider) New(ctx context.Context, config map[string]string) (definition.SignValidator, func() error, error) {
|
||||
if ctx == nil {
|
||||
return nil, nil, errors.New("context cannot be nil")
|
||||
}
|
||||
|
||||
return signvalidator.New(ctx, &signvalidator.Config{})
|
||||
}
|
||||
|
||||
// Provider is the exported symbol that the plugin manager will look for.
|
||||
var Provider = provider{}
|
||||
89
pkg/plugin/implementation/signvalidator/cmd/plugin_test.go
Normal file
89
pkg/plugin/implementation/signvalidator/cmd/plugin_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestVerifierProviderSuccess tests successful creation of a verifier.
|
||||
func TestVerifierProviderSuccess(t *testing.T) {
|
||||
provider := provider{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
config map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Successful creation",
|
||||
ctx: context.Background(),
|
||||
config: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "Nil context",
|
||||
ctx: context.TODO(),
|
||||
config: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "Empty config",
|
||||
ctx: context.Background(),
|
||||
config: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier, close, err := provider.New(tt.ctx, tt.config)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, but got: %v", err)
|
||||
}
|
||||
if verifier == nil {
|
||||
t.Fatal("Expected verifier instance to be non-nil")
|
||||
}
|
||||
if close != nil {
|
||||
if err := close(); err != nil {
|
||||
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifierProviderFailure tests cases where verifier creation should fail.
|
||||
func TestVerifierProviderFailure(t *testing.T) {
|
||||
provider := provider{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
config map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Nil context failure",
|
||||
ctx: nil,
|
||||
config: map[string]string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifierInstance, close, err := provider.New(tt.ctx, tt.config)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("Expected error: %v, but got: %v", tt.wantErr, err)
|
||||
}
|
||||
if verifierInstance != nil {
|
||||
t.Fatal("Expected verifier instance to be nil")
|
||||
}
|
||||
if close != nil {
|
||||
if err := close(); err != nil {
|
||||
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
115
pkg/plugin/implementation/signvalidator/signvalidator.go
Normal file
115
pkg/plugin/implementation/signvalidator/signvalidator.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package signvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
// Config struct for Verifier.
|
||||
type Config struct {
|
||||
}
|
||||
|
||||
// validator implements the validator interface.
|
||||
type validator struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
// New creates a new Verifier instance.
|
||||
func New(ctx context.Context, config *Config) (*validator, func() error, error) {
|
||||
v := &validator{config: config}
|
||||
|
||||
return v, nil, nil
|
||||
}
|
||||
|
||||
// Verify checks the signature for the given payload and public key.
|
||||
func (v *validator) Validate(ctx context.Context, body []byte, header string, publicKeyBase64 string) error {
|
||||
createdTimestamp, expiredTimestamp, signature, err := parseAuthHeader(header)
|
||||
if err != nil {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return fmt.Errorf("error parsing header: %w", err)
|
||||
}
|
||||
|
||||
signatureBytes, err := base64.StdEncoding.DecodeString(signature)
|
||||
if err != nil {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return fmt.Errorf("error decoding signature: %w", err)
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
if createdTimestamp > currentTime || currentTime > expiredTimestamp {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return fmt.Errorf("signature is expired or not yet valid")
|
||||
}
|
||||
|
||||
createdTime := time.Unix(createdTimestamp, 0)
|
||||
expiredTime := time.Unix(expiredTimestamp, 0)
|
||||
|
||||
signingString := hash(body, createdTime.Unix(), expiredTime.Unix())
|
||||
|
||||
decodedPublicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
|
||||
if err != nil {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return fmt.Errorf("error decoding public key: %w", err)
|
||||
}
|
||||
|
||||
if !ed25519.Verify(ed25519.PublicKey(decodedPublicKey), []byte(signingString), signatureBytes) {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return fmt.Errorf("signature verification failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAuthHeader extracts signature values from the Authorization header.
|
||||
func parseAuthHeader(header string) (int64, int64, string, error) {
|
||||
header = strings.TrimPrefix(header, "Signature ")
|
||||
|
||||
parts := strings.Split(header, ",")
|
||||
signatureMap := make(map[string]string)
|
||||
|
||||
for _, part := range parts {
|
||||
keyValue := strings.SplitN(strings.TrimSpace(part), "=", 2)
|
||||
if len(keyValue) == 2 {
|
||||
key := strings.TrimSpace(keyValue[0])
|
||||
value := strings.Trim(keyValue[1], "\"")
|
||||
signatureMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
createdTimestamp, err := strconv.ParseInt(signatureMap["created"], 10, 64)
|
||||
if err != nil {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return 0, 0, "", fmt.Errorf("invalid created timestamp: %w", err)
|
||||
}
|
||||
|
||||
expiredTimestamp, err := strconv.ParseInt(signatureMap["expires"], 10, 64)
|
||||
if err != nil {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return 0, 0, "", fmt.Errorf("invalid expires timestamp: %w", err)
|
||||
}
|
||||
|
||||
signature := signatureMap["signature"]
|
||||
if signature == "" {
|
||||
// TODO: Return appropriate error code when Error Code Handling Module is ready
|
||||
return 0, 0, "", fmt.Errorf("signature missing in header")
|
||||
}
|
||||
|
||||
return createdTimestamp, expiredTimestamp, signature, nil
|
||||
}
|
||||
|
||||
// hash constructs a signing string for verification.
|
||||
func hash(payload []byte, createdTimestamp, expiredTimestamp int64) string {
|
||||
hasher, _ := blake2b.New512(nil)
|
||||
hasher.Write(payload)
|
||||
hashSum := hasher.Sum(nil)
|
||||
digestB64 := base64.StdEncoding.EncodeToString(hashSum)
|
||||
|
||||
return fmt.Sprintf("(created): %d\n(expires): %d\ndigest: BLAKE-512=%s", createdTimestamp, expiredTimestamp, digestB64)
|
||||
}
|
||||
147
pkg/plugin/implementation/signvalidator/signvalidator_test.go
Normal file
147
pkg/plugin/implementation/signvalidator/signvalidator_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package signvalidator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateTestKeyPair generates a new ED25519 key pair for testing.
|
||||
func generateTestKeyPair() (string, string) {
|
||||
publicKey, privateKey, _ := ed25519.GenerateKey(nil)
|
||||
return base64.StdEncoding.EncodeToString(privateKey), base64.StdEncoding.EncodeToString(publicKey)
|
||||
}
|
||||
|
||||
// signTestData creates a valid signature for test cases.
|
||||
func signTestData(privateKeyBase64 string, body []byte, createdAt, expiresAt int64) string {
|
||||
privateKeyBytes, _ := base64.StdEncoding.DecodeString(privateKeyBase64)
|
||||
privateKey := ed25519.PrivateKey(privateKeyBytes)
|
||||
|
||||
signingString := hash(body, createdAt, expiresAt)
|
||||
signature := ed25519.Sign(privateKey, []byte(signingString))
|
||||
|
||||
return base64.StdEncoding.EncodeToString(signature)
|
||||
}
|
||||
|
||||
// TestVerifySuccessCases tests all valid signature verification cases.
|
||||
func TestVerifySuccess(t *testing.T) {
|
||||
privateKeyBase64, publicKeyBase64 := generateTestKeyPair()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
createdAt int64
|
||||
expiresAt int64
|
||||
}{
|
||||
{
|
||||
name: "Valid Signature",
|
||||
body: []byte("Test Payload"),
|
||||
createdAt: time.Now().Unix(),
|
||||
expiresAt: time.Now().Unix() + 3600,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
signature := signTestData(privateKeyBase64, tt.body, tt.createdAt, tt.expiresAt)
|
||||
header := `Signature created="` + strconv.FormatInt(tt.createdAt, 10) +
|
||||
`", expires="` + strconv.FormatInt(tt.expiresAt, 10) +
|
||||
`", signature="` + signature + `"`
|
||||
|
||||
verifier, close, _ := New(context.Background(), &Config{})
|
||||
err := verifier.Validate(context.Background(), tt.body, header, publicKeyBase64)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, but got: %v", err)
|
||||
}
|
||||
if close != nil {
|
||||
if err := close(); err != nil {
|
||||
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyFailureCases tests all invalid signature verification cases.
|
||||
func TestVerifyFailure(t *testing.T) {
|
||||
privateKeyBase64, publicKeyBase64 := generateTestKeyPair()
|
||||
_, wrongPublicKeyBase64 := generateTestKeyPair()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
header string
|
||||
pubKey string
|
||||
}{
|
||||
{
|
||||
name: "Missing Authorization Header",
|
||||
body: []byte("Test Payload"),
|
||||
header: "",
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Malformed Header",
|
||||
body: []byte("Test Payload"),
|
||||
header: `InvalidSignature created="wrong"`,
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Invalid Base64 Signature",
|
||||
body: []byte("Test Payload"),
|
||||
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
|
||||
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) +
|
||||
`", signature="!!INVALIDBASE64!!"`,
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Expired Signature",
|
||||
body: []byte("Test Payload"),
|
||||
header: `Signature created="` + strconv.FormatInt(time.Now().Unix()-7200, 10) +
|
||||
`", expires="` + strconv.FormatInt(time.Now().Unix()-3600, 10) +
|
||||
`", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix()-7200, time.Now().Unix()-3600) + `"`,
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Invalid Public Key",
|
||||
body: []byte("Test Payload"),
|
||||
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
|
||||
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) +
|
||||
`", signature="` + signTestData(privateKeyBase64, []byte("Test Payload"), time.Now().Unix(), time.Now().Unix()+3600) + `"`,
|
||||
pubKey: wrongPublicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Invalid Expires Timestamp",
|
||||
body: []byte("Test Payload"),
|
||||
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
|
||||
`", expires="invalid_timestamp"`,
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
{
|
||||
name: "Signature Missing in Headers",
|
||||
body: []byte("Test Payload"),
|
||||
header: `Signature created="` + strconv.FormatInt(time.Now().Unix(), 10) +
|
||||
`", expires="` + strconv.FormatInt(time.Now().Unix()+3600, 10) + `"`,
|
||||
pubKey: publicKeyBase64,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier, close, _ := New(context.Background(), &Config{})
|
||||
err := verifier.Validate(context.Background(), tt.body, tt.header, tt.pubKey)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("Expected an error but got none")
|
||||
}
|
||||
if close != nil {
|
||||
if err := close(); err != nil {
|
||||
t.Fatalf("Test %q failed: cleanup function returned an error: %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user