From c5bfd0a1224e57ce5e3a888cdf07d37f4457fa88 Mon Sep 17 00:00:00 2001 From: rupinder-syngh <118511490+rupinder-syngh@users.noreply.github.com> Date: Fri, 21 Mar 2025 19:23:40 +0530 Subject: [PATCH] feat: decryption plugin (#430) * feat: Implemented decryption plugin * fix: Removed interface test file * fix: Test case * fix: Test case for plugin * fix: test case change * fix: resolved pr comments * fix: resolved pr comments * fix: removed mock dcrypter * fix: formatting * fix: removed config, close function, enhanced test cases * fix: test cases enhancement, formatting --- pkg/plugin/definition/decrypter.go | 15 ++ .../implementation/decrypter/cmd/plugin.go | 19 ++ .../decrypter/cmd/plugin_test.go | 49 ++++ .../implementation/decrypter/decrypter.go | 85 ++++++ .../decrypter/decrypter_test.go | 251 ++++++++++++++++++ pkg/plugin/manager.go | 23 +- 6 files changed, 441 insertions(+), 1 deletion(-) create mode 100644 pkg/plugin/definition/decrypter.go create mode 100644 pkg/plugin/implementation/decrypter/cmd/plugin.go create mode 100644 pkg/plugin/implementation/decrypter/cmd/plugin_test.go create mode 100644 pkg/plugin/implementation/decrypter/decrypter.go create mode 100644 pkg/plugin/implementation/decrypter/decrypter_test.go diff --git a/pkg/plugin/definition/decrypter.go b/pkg/plugin/definition/decrypter.go new file mode 100644 index 0000000..8bd0b6a --- /dev/null +++ b/pkg/plugin/definition/decrypter.go @@ -0,0 +1,15 @@ +package definition + +import "context" + +// Decrypter defines the methods for decryption. +type Decrypter interface { + // Decrypt decrypts the given body using the provided privateKeyBase64 and publicKeyBase64. + Decrypt(ctx context.Context, encryptedData string, privateKeyBase64, publicKeyBase64 string) (string, error) +} + +// DecrypterProvider initializes a new decrypter instance with the given config. +type DecrypterProvider interface { + // New creates a new decrypter instance based on the provided config. + New(ctx context.Context, config map[string]string) (Decrypter, func() error, error) +} diff --git a/pkg/plugin/implementation/decrypter/cmd/plugin.go b/pkg/plugin/implementation/decrypter/cmd/plugin.go new file mode 100644 index 0000000..cb988a9 --- /dev/null +++ b/pkg/plugin/implementation/decrypter/cmd/plugin.go @@ -0,0 +1,19 @@ +package main + +import ( + "context" + + "github.com/beckn/beckn-onix/pkg/plugin/definition" + decrypter "github.com/beckn/beckn-onix/pkg/plugin/implementation/decrypter" +) + +// DecrypterProvider implements the definition.DecrypterProvider interface. +type DecrypterProvider struct{} + +// New creates a new Decrypter instance using the provided configuration. +func (dp DecrypterProvider) New(ctx context.Context, config map[string]string) (definition.Decrypter, func() error, error) { + return decrypter.New(ctx) +} + +// Provider is the exported symbol that the plugin manager will look for. +var Provider definition.DecrypterProvider = DecrypterProvider{} diff --git a/pkg/plugin/implementation/decrypter/cmd/plugin_test.go b/pkg/plugin/implementation/decrypter/cmd/plugin_test.go new file mode 100644 index 0000000..6a4f168 --- /dev/null +++ b/pkg/plugin/implementation/decrypter/cmd/plugin_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "context" + "testing" +) + +func TestDecrypterProviderSuccess(t *testing.T) { + tests := []struct { + name string + ctx context.Context + config map[string]string + }{ + { + name: "Valid context with empty config", + ctx: context.Background(), + config: map[string]string{}, + }, + { + name: "Valid context with non-empty config", + ctx: context.Background(), + config: map[string]string{"key": "value"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider := DecrypterProvider{} + decrypter, cleanup, err := provider.New(tt.ctx, tt.config) + + // Check error. + if err != nil { + t.Errorf("New() error = %v, want no error", err) + } + + // Check decrypter. + if decrypter == nil { + t.Error("New() decrypter is nil, want non-nil") + } + + // Test cleanup function if it exists. + if cleanup != nil { + if err := cleanup(); err != nil { + t.Errorf("cleanup() error = %v", err) + } + } + }) + } +} diff --git a/pkg/plugin/implementation/decrypter/decrypter.go b/pkg/plugin/implementation/decrypter/decrypter.go new file mode 100644 index 0000000..f312f16 --- /dev/null +++ b/pkg/plugin/implementation/decrypter/decrypter.go @@ -0,0 +1,85 @@ +package decryption + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "encoding/base64" + "fmt" + + "github.com/zenazn/pkcs7pad" +) + +// decrypter implements the Decrypter interface and handles the decryption process. +type decrypter struct { +} + +// New creates a new decrypter instance with the given configuration. +func New(ctx context.Context) (*decrypter, func() error, error) { + return &decrypter{}, nil, nil +} + +// Decrypt decrypts the given encryptedData using the provided privateKeyBase64 and publicKeyBase64. +func (d *decrypter) Decrypt(ctx context.Context, encryptedData, privateKeyBase64, publicKeyBase64 string) (string, error) { + privateKeyBytes, err := base64.StdEncoding.DecodeString(privateKeyBase64) + if err != nil { + return "", fmt.Errorf("invalid private key: %w", err) + } + + publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyBase64) + if err != nil { + return "", fmt.Errorf("invalid public key: %w", err) + } + + // Decode the Base64 encoded encrypted data. + messageByte, err := base64.StdEncoding.DecodeString(encryptedData) + if err != nil { + return "", fmt.Errorf("failed to decode encrypted data: %w", err) + } + + aesCipher, err := createAESCipher(privateKeyBytes, publicKeyBytes) + if err != nil { + return "", fmt.Errorf("failed to create AES cipher: %w", err) + } + + blocksize := aesCipher.BlockSize() + if len(messageByte)%blocksize != 0 { + return "", fmt.Errorf("ciphertext is not a multiple of the blocksize") + } + + for i := 0; i < len(messageByte); i += aesCipher.BlockSize() { + executionSlice := messageByte[i : i+aesCipher.BlockSize()] + aesCipher.Decrypt(executionSlice, executionSlice) + } + + messageByte, err = pkcs7pad.Unpad(messageByte) + if err != nil { + return "", fmt.Errorf("failed to unpad data: %w", err) + } + + return string(messageByte), nil +} + +func createAESCipher(privateKey, publicKey []byte) (cipher.Block, error) { + x25519Curve := ecdh.X25519() + x25519PrivateKey, err := x25519Curve.NewPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create private key: %w", err) + } + x25519PublicKey, err := x25519Curve.NewPublicKey(publicKey) + if err != nil { + return nil, fmt.Errorf("failed to create public key: %w", err) + } + sharedSecret, err := x25519PrivateKey.ECDH(x25519PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to derive shared secret: %w", err) + } + + aesCipher, err := aes.NewCipher(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AES cipher: %w", err) + } + + return aesCipher, nil +} diff --git a/pkg/plugin/implementation/decrypter/decrypter_test.go b/pkg/plugin/implementation/decrypter/decrypter_test.go new file mode 100644 index 0000000..a2bbe11 --- /dev/null +++ b/pkg/plugin/implementation/decrypter/decrypter_test.go @@ -0,0 +1,251 @@ +package decryption + +import ( + "context" + "crypto/aes" + "crypto/ecdh" + "crypto/rand" + "encoding/base64" + "strings" + "testing" + + "github.com/zenazn/pkcs7pad" +) + +// Helper function to generate valid test keys. +func generateTestKeys(t *testing.T) (privateKeyB64, publicKeyB64 string) { + curve := ecdh.X25519() + privateKey, err := curve.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate private key: %v", err) + } + + publicKey := privateKey.PublicKey() + privateKeyB64 = base64.StdEncoding.EncodeToString(privateKey.Bytes()) + publicKeyB64 = base64.StdEncoding.EncodeToString(publicKey.Bytes()) + + return privateKeyB64, publicKeyB64 +} + +// Helper function to encrypt test data. +func encryptTestData(t *testing.T, data []byte, privateKeyBase64, publicKeyBase64 string) string { + privateKeyBytes, err := base64.StdEncoding.DecodeString(privateKeyBase64) + if err != nil { + t.Fatalf("Invalid private key: %v", err) + } + + publicKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyBase64) + if err != nil { + t.Fatalf("Invalid public key: %v", err) + } + + x25519Curve := ecdh.X25519() + x25519PrivateKey, err := x25519Curve.NewPrivateKey(privateKeyBytes) + if err != nil { + t.Fatalf("Failed to create private key: %v", err) + } + x25519PublicKey, err := x25519Curve.NewPublicKey(publicKeyBytes) + if err != nil { + t.Fatalf("Failed to create public key: %v", err) + } + + // Generate shared secret for encryption. + sharedSecret, err := x25519PrivateKey.ECDH(x25519PublicKey) + if err != nil { + t.Fatalf("Failed to create shared secret: %v", err) + } + + // Create AES cipher. + block, err := aes.NewCipher(sharedSecret) + if err != nil { + t.Fatalf("Failed to create AES cipher: %v", err) + } + + // Pad the data. + paddedData := pkcs7pad.Pad(data, block.BlockSize()) + + // Encrypt the data. + ciphertext := make([]byte, len(paddedData)) + for i := 0; i < len(paddedData); i += block.BlockSize() { + block.Encrypt(ciphertext[i:i+block.BlockSize()], paddedData[i:i+block.BlockSize()]) + } + + return base64.StdEncoding.EncodeToString(ciphertext) +} + +// TestDecrypterSuccess tests successful decryption scenarios. +func TestDecrypterSuccess(t *testing.T) { + senderPrivateKeyB64, senderPublicKeyB64 := generateTestKeys(t) + receiverPrivateKeyB64, receiverPublicKeyB64 := generateTestKeys(t) + + tests := []struct { + name string + data []byte + }{ + { + name: "Valid decryption with small data", + data: []byte("test"), + }, + { + name: "Valid decryption with medium data", + data: []byte("medium length test data that spans multiple blocks"), + }, + { + name: "Valid decryption with empty data", + data: []byte{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Encrypt the test data. + encryptedData := encryptTestData(t, tt.data, senderPrivateKeyB64, receiverPublicKeyB64) + + decrypter, _, err := New(context.Background()) + if err != nil { + t.Fatalf("Failed to create decrypter: %v", err) + } + + result, err := decrypter.Decrypt(context.Background(), encryptedData, receiverPrivateKeyB64, senderPublicKeyB64) + if err != nil { + t.Errorf("Decrypt() error = %v", err) + } + + if err == nil { + if result != string(tt.data) { + t.Errorf("Decrypt() = %v, want %v", result, string(tt.data)) + } + } + }) + } +} + +// TestDecrypterFailure tests various failure scenarios. +func TestDecrypterFailure(t *testing.T) { + _, senderPublicKeyB64 := generateTestKeys(t) + receiverPrivateKeyB64, _ := generateTestKeys(t) + + tests := []struct { + name string + encryptedData string + privateKey string + publicKey string + expectedErr string + }{ + { + name: "Invalid private key format", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: "invalid-base64!@#$", + publicKey: senderPublicKeyB64, + expectedErr: "invalid private key", + }, + { + name: "Invalid public key format", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: receiverPrivateKeyB64, + publicKey: "invalid-base64!@#$", + expectedErr: "invalid public key", + }, + { + name: "Invalid encrypted data format", + encryptedData: "invalid-base64!@#$", + privateKey: receiverPrivateKeyB64, + publicKey: senderPublicKeyB64, + expectedErr: "failed to decode encrypted data", + }, + { + name: "Empty private key", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: "", + publicKey: senderPublicKeyB64, + expectedErr: "invalid private key", + }, + { + name: "Empty public key", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: receiverPrivateKeyB64, + publicKey: "", + expectedErr: "invalid public key", + }, + { + name: "Invalid base64 data", + encryptedData: "=invalid-base64", // Invalid encrypted data. + privateKey: receiverPrivateKeyB64, + publicKey: senderPublicKeyB64, + expectedErr: "failed to decode encrypted data", + }, + { + name: "Invalid private key size", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: base64.StdEncoding.EncodeToString([]byte("short")), + publicKey: senderPublicKeyB64, + expectedErr: "failed to create private key", + }, + { + name: "Invalid public key size", + encryptedData: base64.StdEncoding.EncodeToString(make([]byte, 32)), + privateKey: receiverPrivateKeyB64, + publicKey: base64.StdEncoding.EncodeToString([]byte("short")), + expectedErr: "failed to create public key", + }, + { + name: "Invalid block size", + encryptedData: base64.StdEncoding.EncodeToString([]byte("not-block-size")), + privateKey: receiverPrivateKeyB64, + publicKey: senderPublicKeyB64, + expectedErr: "ciphertext is not a multiple of the blocksize", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decrypter, _, err := New(context.Background()) + if err != nil { + t.Fatalf("Failed to create decrypter: %v", err) + } + + _, err = decrypter.Decrypt(context.Background(), tt.encryptedData, tt.privateKey, tt.publicKey) + if err == nil { + t.Error("Expected error but got none") + } + + if err != nil { + if !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("Expected error containing %q, got %q", tt.expectedErr, err.Error()) + } + } + }) + } +} + +// TestNewDecrypter tests the creation of new Decrypter instances. +func TestNewDecrypter(t *testing.T) { + tests := []struct { + name string + ctx context.Context + }{ + { + name: "Valid context", + ctx: context.Background(), + }, + { + name: "Nil context", + ctx: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decrypter, _, err := New(tt.ctx) + if err != nil { + t.Errorf("New() error = %v", err) + } + + if err == nil { + if decrypter == nil { + t.Error("Expected non-nil decrypter") + } + } + }) + } +} diff --git a/pkg/plugin/manager.go b/pkg/plugin/manager.go index 6fff4bf..86a0b02 100644 --- a/pkg/plugin/manager.go +++ b/pkg/plugin/manager.go @@ -15,6 +15,7 @@ type Config struct { Root string `yaml:"root"` Signer PluginConfig `yaml:"signer"` Verifier PluginConfig `yaml:"verifier"` + Decrypter PluginConfig `yaml:"decrypter"` Encrypter PluginConfig `yaml:"encrypter"` Publisher PluginConfig `yaml:"publisher"` } @@ -29,6 +30,7 @@ type PluginConfig struct { type Manager struct { sp definition.SignerProvider vp definition.VerifierProvider + dp definition.DecrypterProvider ep definition.EncrypterProvider pb definition.PublisherProvider cfg *Config @@ -58,13 +60,19 @@ func NewManager(ctx context.Context, cfg *Config) (*Manager, error) { return nil, fmt.Errorf("failed to load Verifier plugin: %w", err) } + // Load decrypter plugin. + dp, err := provider[definition.DecrypterProvider](cfg.Root, cfg.Decrypter.ID) + if err != nil { + return nil, fmt.Errorf("failed to load Decrypter plugin: %w", err) + } + // Load encryption plugin. ep, err := provider[definition.EncrypterProvider](cfg.Root, cfg.Encrypter.ID) if err != nil { return nil, fmt.Errorf("failed to load encryption plugin: %w", err) } - return &Manager{sp: sp, vp: vp, pb: pb, ep: ep, cfg: cfg}, nil + return &Manager{sp: sp, vp: vp, pb: pb, ep: ep, dp: dp, cfg: cfg}, nil } // provider loads a plugin dynamically and retrieves its provider instance. @@ -123,6 +131,19 @@ func (m *Manager) Verifier(ctx context.Context) (definition.Verifier, func() err return Verifier, close, nil } +// Decrypter retrieves the decryption plugin instance. +func (m *Manager) Decrypter(ctx context.Context) (definition.Decrypter, func() error, error) { + if m.dp == nil { + return nil, nil, fmt.Errorf("decrypter plugin provider not loaded") + } + + decrypter, close, err := m.dp.New(ctx, m.cfg.Decrypter.Config) + if err != nil { + return nil, nil, fmt.Errorf("failed to initialize Decrypter: %w", err) + } + return decrypter, close, nil +} + // Encrypter retrieves the encryption plugin instance. func (m *Manager) Encrypter(ctx context.Context) (definition.Encrypter, func() error, error) { if m.ep == nil {