diff --git a/.github/workflows/build-and-deploy-plugins.yml b/.github/workflows/build-and-deploy-plugins.yml new file mode 100644 index 0000000..79c33e0 --- /dev/null +++ b/.github/workflows/build-and-deploy-plugins.yml @@ -0,0 +1,116 @@ +name: Build and Upload Plugins + +on: + workflow_dispatch: + inputs: + target_branch: + description: 'Branch to deploy' + required: true + default: 'beckn-onix-v1.0-develop' + + +jobs: + build-and-upload: + runs-on: ubuntu-latest + env: + GCS_BUCKET: ${{ secrets.GCS_BUCKET }} + PLUGIN_OUTPUT_DIR: ./generated + ZIP_FILE: plugins_bundle.zip + + steps: + - name: Checkout this repo + uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.target_branch }} + + - name: Show selected branch + run: echo "Deploying branch:${{ github.event.inputs.target_branch }}" + + + - name: Clone GitHub and Gerrit plugin repos + run: | + # Example GitHub clone + git clone -b beckn-onix-v1.0-develop https://${{ secrets.PAT_GITHUB }}:@github.com/beckn/beckn-onix.git github-repo + + # Example Gerrit clone + git clone https://${{ secrets.GERRIT_USERNAME }}:${{ secrets.GERRIT_PAT }}@open-networks.googlesource.com/onix-dev gerrit-repo + + - name: List directory structure + run: | + echo "📂 Contents of root:" + ls -alh + + echo "📂 Contents of GitHub repo:" + ls -alh github-repo + + echo "📂 Deep list of GitHub repo:" + find github-repo + + echo "📂 Contents of Gerrit repo:" + ls -alh gerrit-repo + + echo "📂 Deep list of Gerrit repo:" + find gerrit-repo + + + - name: Build Go plugins in Docker + run: | + set -e + mkdir -p $PLUGIN_OUTPUT_DIR + + BUILD_CMDS="" + + # GitHub plugins + for dir in github-repo/pkg/plugin/implementation/*; do + if [ -d "$dir/cmd" ]; then + plugin=$(basename "$dir") + BUILD_CMDS+="cd github-repo && go build -buildmode=plugin -buildvcs=false -o ../${PLUGIN_OUTPUT_DIR}/${plugin}.so ./pkg/plugin/implementation/${plugin}/cmd && cd - && " + fi + done + + # Gerrit plugins — build in their own repo/module context + for dir in gerrit-repo/plugins/*; do + if [ -d "$dir/cmd" ]; then + plugin=$(basename "$dir") + BUILD_CMDS+="cd gerrit-repo && go build -buildmode=plugin -buildvcs=false -o ../${PLUGIN_OUTPUT_DIR}/${plugin}.so ./plugins/${plugin}/cmd && cd - && " + fi + done + + BUILD_CMDS=${BUILD_CMDS%" && "} + echo "🛠️ Running build commands inside Docker:" + echo "$BUILD_CMDS" + + docker run --rm -v "$(pwd)":/app -w /app golang:1.24-bullseye sh -c "$BUILD_CMDS" + + - name: List built plugin files + run: | + echo "Looking in $PLUGIN_OUTPUT_DIR" + ls -lh $PLUGIN_OUTPUT_DIR || echo "⚠️ Directory does not exist" + find $PLUGIN_OUTPUT_DIR -name '*.so' || echo "⚠️ No .so files found" + + echo "Creating zip archive..." + cd "$PLUGIN_OUTPUT_DIR" + zip -r "../$ZIP_FILE" *.so + echo "Created $ZIP_FILE" + cd .. + + - name: List zip output + run: | + ls -lh plugins_bundle.zip + + + - name: Authenticate to GCP + run: | + echo '${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}' > gcloud-key.json + gcloud auth activate-service-account --key-file=gcloud-key.json + gcloud config set project trusty-relic-370809 + env: + GOOGLE_APPLICATION_CREDENTIALS: gcloud-key.json + + - name: Upload to GCS + run: | + gsutil -m cp -r $ZIP_FILE gs://${GCS_BUCKET}/plugins/ + + - name: Cleanup + run: | + rm -rf $PLUGIN_OUTPUT_DIR $ZIP_FILE gcloud-key.json diff --git a/.github/workflows/deploy-to-gke-BS.yml b/.github/workflows/deploy-to-gke-BS.yml new file mode 100644 index 0000000..d068829 --- /dev/null +++ b/.github/workflows/deploy-to-gke-BS.yml @@ -0,0 +1,57 @@ +name: CI/CD to GKE updated + +on: + #push: + workflow_dispatch: + +jobs: + deploy: + name: Build and Deploy to GKE + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + credentials_json: '${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}' + + - name: Set up gcloud CLI + uses: google-github-actions/setup-gcloud@v1 + with: + project_id: ${{ secrets.GCP_PROJECT }} + export_default_credentials: true + + - name: Install GKE Auth Plugin + run: gcloud components install gke-gcloud-auth-plugin --quiet + + - name: Configure Docker to use Artifact Registry + run: gcloud auth configure-docker ${{ secrets.GCP_REGION }}-docker.pkg.dev + + - name: Build Docker Image + run: | + IMAGE_NAME=${{ secrets.GCP_REGION }}-docker.pkg.dev/${{ secrets.GCP_PROJECT }}/${{ secrets.GCP_REPO }}/beckn-onix:${{ github.sha }} + docker build -f Dockerfile.adapter -t $IMAGE_NAME . + docker push $IMAGE_NAME + + - name: Get GKE Credentials + run: | + gcloud container clusters get-credentials ${{ secrets.GKE_CLUSTER }} \ + --zone ${{ secrets.GCP_REGION }} \ + --project ${{ secrets.GCP_PROJECT }} + + - name: Deploy to GKE using Kubernetes Manifests + run: | + IMAGE_NAME=${{ secrets.GCP_REGION }}-docker.pkg.dev/${{ secrets.GCP_PROJECT }}/${{ secrets.GCP_REPO }}/beckn-onix:${{ github.sha }} + + # Replace image in deployment YAML + sed -i "s|image: .*|image: $IMAGE_NAME|g" Deployment/deployment.yaml + + # Apply Kubernetes manifests + kubectl apply -f Deployment/deployment.yaml --namespace=onix-adapter + kubectl apply -f Deployment/service.yaml --namespace=onix-adapter + + # Wait for rollout to complete + kubectl rollout status Deployment/onix-demo-adapter --namespace=onix-adapter diff --git a/.github/workflows/deploy-to-gke.yml b/.github/workflows/deploy-to-gke.yml new file mode 100644 index 0000000..bcb43ef --- /dev/null +++ b/.github/workflows/deploy-to-gke.yml @@ -0,0 +1,44 @@ +name: Deploy to GKE + +on: + workflow_dispatch: + inputs: + service_name: + description: 'Name of the Kubernetes service to deploy' + required: true + type: string + cluster_name: + description: 'Name of the GKE cluster' + required: true + type: string + +jobs: + deploy: + runs-on: ubuntu-latest + + env: + PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} + REGION: ${{ secrets.GCP_REGION }} + GKE_CLUSTER: ${{ github.event.inputs.cluster_name }} + SERVICE_NAME: ${{ github.event.inputs.service_name }} + + steps: + - name: Checkout source + uses: actions/checkout@v3 + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCP_SA_KEY }} + + - name: Set up GKE credentials + uses: google-github-actions/get-gke-credentials@v1 + with: + cluster_name: ${{ env.GKE_CLUSTER }} + location: ${{ env.REGION }} + project_id: ${{ env.PROJECT_ID }} + + - name: Deploy to GKE + run: | + echo "Deploying service $SERVICE_NAME to cluster $GKE_CLUSTER" + kubectl set image deployment/$SERVICE_NAME $SERVICE_NAME=gcr.io/$PROJECT_ID/$SERVICE_NAME:latest --record diff --git a/.github/workflows/onix-gcp-terraform-deploy.yml b/.github/workflows/onix-gcp-terraform-deploy.yml new file mode 100644 index 0000000..478979b --- /dev/null +++ b/.github/workflows/onix-gcp-terraform-deploy.yml @@ -0,0 +1,66 @@ +name: Terraform Deploy to GCP + +on: + push: + workflow_dispatch: # Manual triggerr + +jobs: + plan: + name: Terraform Plan Only + runs-on: ubuntu-latest + + steps: + - name: Checkout this repository + uses: actions/checkout@v3 + + - name: Clone Terraform repo from Gerrit + run: | + git clone https://${{ secrets.GERRIT_USERNAME }}:${{ secrets.GERRIT_PAT }}@open-networks.googlesource.com/onix-dev gerrit-repo + echo "==== Contents of Terraform-dir ====" + pwd + cd gerrit-repo/Terraform-CICD + pwd + ls -la + + - name: Authenticate to Google Cloud + run: echo '${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}' > gcp-key.json + + - name: Set up Terraform + uses: hashicorp/setup-terraform@v3 + with: + terraform_version: 1.5.0 + + - name: Write GCP credentials to file + run: echo '${{ secrets.GOOGLE_APPLICATION_CREDENTIALS_JSON }}' > gcp-key.json + + - name: Export GCP credentials environment variable + run: echo "GOOGLE_APPLICATION_CREDENTIALS=$GITHUB_WORKSPACE/gcp-key.json" >> $GITHUB_ENV + + - name: Create backend.tf and Terraform Init + working-directory: ./gerrit-repo/Terraform-CICD + env: + GCS_BUCKET: beckn-cicd-tf-state-bucket + run: | + cat < backend.tf + terraform { + backend "gcs" { + bucket = "${GCS_BUCKET}" + prefix = "terraform/state" + credentials = "${{ github.workspace }}/gcp-key.json" + } + } + EOF + + terraform init + + - name: Terraform Plan + working-directory: ./gerrit-repo/Terraform-CICD + run: terraform plan + + - name: Terraform Apply + working-directory: ./gerrit-repo/Terraform-CICD + run: terraform apply -var="subnet_name=onix-gke-subnet" -auto-approve + + - name: Clean up credentials + run: rm -f gcp-key.json + diff --git a/Deployment/deployment.yaml b/Deployment/deployment.yaml new file mode 100644 index 0000000..561fdb5 --- /dev/null +++ b/Deployment/deployment.yaml @@ -0,0 +1,40 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: onix-demo-adapter + namespace: onix-adapter #------ +spec: + replicas: 1 + selector: + matchLabels: + app: onix-adapter + template: + metadata: + labels: + app: onix-adapter + annotations: + gke-gcsfuse/volumes: "true" + spec: + serviceAccountName: "onix-adapter-ksa" #----------- + containers: + - name: onix-adapter + image: "asia-south1-docker.pkg.dev/trusty-relic-370809/onix-adapter-cicd/beckn-onix:latest" #------ + ports: + - containerPort: 8080 + env: + - name: CONFIG_FILE + value: "/mnt/gcs/configs/onix-adapter.yaml" # Updated to GCS path + + volumeMounts: + - name: gcs-bucket + mountPath: /mnt/gcs + readOnly: false + + volumes: + - name: gcs-bucket + csi: + driver: gcsfuse.csi.storage.gke.io + readOnly: false + volumeAttributes: + bucketName: "beckn-cicd-bucket" #---------- + mountOptions: "implicit-dirs" diff --git a/Deployment/service.yaml b/Deployment/service.yaml new file mode 100644 index 0000000..c4be0d8 --- /dev/null +++ b/Deployment/service.yaml @@ -0,0 +1,13 @@ +apiVersion: v1 +kind: Service +metadata: + name: onix-adapter-service + namespace: onix-adapter # Namespace +spec: + selector: + app: onix-adapter # This should match the app name in deployment.yaml + ports: + - protocol: TCP + port: 80 + targetPort: 8080 + type: LoadBalancer #NodePort or LoadBalancer diff --git a/core/module/handler/step.go b/core/module/handler/step.go index 936ee98..3998986 100644 --- a/core/module/handler/step.go +++ b/core/module/handler/step.go @@ -31,18 +31,18 @@ func newSignStep(signer definition.Signer, km definition.KeyManager) (definition // Run executes the signing step. func (s *signStep) Run(ctx *model.StepContext) error { - keyID, key, err := s.km.SigningPrivateKey(ctx, ctx.SubID) + keySet, err := s.km.Keyset(ctx, ctx.SubID) if err != nil { return fmt.Errorf("failed to get signing key: %w", err) } createdAt := time.Now().Unix() validTill := time.Now().Add(5 * time.Minute).Unix() - sign, err := s.signer.Sign(ctx, ctx.Body, key, createdAt, validTill) + sign, err := s.signer.Sign(ctx, ctx.Body, keySet.SigningPrivate, createdAt, validTill) if err != nil { return fmt.Errorf("failed to sign request: %w", err) } - authHeader := s.generateAuthHeader(ctx.SubID, keyID, createdAt, validTill, sign) + authHeader := s.generateAuthHeader(ctx.SubID, keySet.UniqueKeyID, createdAt, validTill, sign) header := model.AuthHeaderSubscriber if ctx.Role == model.RoleGateway { @@ -107,13 +107,12 @@ func (s *validateSignStep) validate(ctx *model.StepContext, value string) error if len(ids) < 2 || len(headerParts) < 3 { return fmt.Errorf("malformed sign header") } - subID := ids[1] keyID := headerParts[1] - key, err := s.km.SigningPublicKey(ctx, subID, keyID) + signingPublicKey, _, err := s.km.LookupNPKeys(ctx, ctx.SubID, keyID) if err != nil { return fmt.Errorf("failed to get validation key: %w", err) } - if err := s.validator.Validate(ctx, ctx.Body, value, key); err != nil { + if err := s.validator.Validate(ctx, ctx.Body, value, signingPublicKey); err != nil { return fmt.Errorf("sign validation failed: %w", err) } return nil diff --git a/go.mod b/go.mod index c00aa40..5de7bcb 100644 --- a/go.mod +++ b/go.mod @@ -15,9 +15,9 @@ require ( require github.com/stretchr/testify v1.10.0 require ( - github.com/davecgh/go-spew v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/stretchr/objx v0.5.2 // indirect gopkg.in/yaml.v3 v3.0.1 ) @@ -29,14 +29,32 @@ require golang.org/x/text v0.23.0 // indirect require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/go-jose/go-jose/v4 v4.0.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/go-rootcerts v1.0.2 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.2 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect + golang.org/x/net v0.38.0 // indirect golang.org/x/sys v0.31.0 // indirect + golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 // indirect ) require ( + github.com/google/uuid v1.6.0 github.com/hashicorp/go-retryablehttp v0.7.7 github.com/redis/go-redis/v9 v9.8.0 + github.com/hashicorp/vault/api v1.16.0 + github.com/rabbitmq/amqp091-go v1.10.0 github.com/rs/zerolog v1.34.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 diff --git a/go.sum b/go.sum index ccaf9d5..f7bbfb0 100644 --- a/go.sum +++ b/go.sum @@ -4,23 +4,57 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U= +github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY= +github.com/go-test/deep v1.0.2 h1:onZX1rnHT3Wv6cqNgYyFOOlgVKJrksuCMCRvJStbMYw= +github.com/go-test/deep v1.0.2/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6 h1:om4Al8Oy7kCm/B86rLCLah4Dt5Aa0Fr5rYBG60OzwHQ= +github.com/hashicorp/go-secure-stdlib/parseutil v0.1.6/go.mod h1:QmrqtbKuxxSWTN3ETMPuB+VtEiBJ/A9XhoYGv8E1uD8= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.1/go.mod h1:gKOamz3EwoIoJq7mlMIRBpVTAUn8qPCrEclOKKWhD3U= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= +github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/vault/api v1.16.0 h1:nbEYGJiAPGzT9U4oWgaaB0g+Rj8E59QuHKyA5LhwQN4= +github.com/hashicorp/vault/api v1.16.0/go.mod h1:KhuUhzOD8lDSk29AtzNjgAu2kxRA9jL9NAbkFlqvkBA= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -28,34 +62,57 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/go-wordwrap v1.0.0/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= +github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/zenazn/pkcs7pad v0.0.0-20170308005700-253a5b1f0e03 h1:m1h+vudopHsI67FPT9MOncyndWhTcdUoBtI1R1uajGY= github.com/zenazn/pkcs7pad v0.0.0-20170308005700-253a5b1f0e03/go.mod h1:8sheVFH84v3PCyFY/O02mIgSQY9I6wMYPWsq7mDnEZY= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -63,6 +120,8 @@ golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI= +golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -70,5 +129,6 @@ gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= \ No newline at end of file diff --git a/pkg/model/model.go b/pkg/model/model.go index a91e2a3..004bf23 100644 --- a/pkg/model/model.go +++ b/pkg/model/model.go @@ -10,24 +10,24 @@ import ( // Subscriber represents a unique operational configuration of a trusted platform on a network. type Subscriber struct { - SubscriberID string `json:"subscriber_id"` - URL string `json:"url" format:"uri"` - Type string `json:"type" enum:"BAP,BPP,BG"` - Domain string `json:"domain"` + SubscriberID string `json:"subscriber_id,omitzero"` + URL string `json:"url,omitzero" format:"uri"` + Type string `json:"type,omitzero" enum:"BAP,BPP,BG"` + Domain string `json:"domain,omitzero"` } // Subscription represents subscription details of a network participant. type Subscription struct { Subscriber `json:",inline"` - KeyID string `json:"key_id" format:"uuid"` - SigningPublicKey string `json:"signing_public_key"` - EncrPublicKey string `json:"encr_public_key"` - ValidFrom time.Time `json:"valid_from" format:"date-time"` - ValidUntil time.Time `json:"valid_until" format:"date-time"` - Status string `json:"status" enum:"INITIATED,UNDER_SUBSCRIPTION,SUBSCRIBED,EXPIRED,UNSUBSCRIBED,INVALID_SSL"` - Created time.Time `json:"created" format:"date-time"` - Updated time.Time `json:"updated" format:"date-time"` - Nonce string + KeyID string `json:"key_id,omitzero" format:"uuid"` + SigningPublicKey string `json:"signing_public_key,omitzero"` + EncrPublicKey string `json:"encr_public_key,omitzero"` + ValidFrom time.Time `json:"valid_from,omitzero" format:"date-time"` + ValidUntil time.Time `json:"valid_until,omitzero" format:"date-time"` + Status string `json:"status,omitzero" enum:"INITIATED,UNDER_SUBSCRIPTION,SUBSCRIBED,EXPIRED,UNSUBSCRIBED,INVALID_SSL"` + Created time.Time `json:"created,omitzero" format:"date-time"` + Updated time.Time `json:"updated,omitzero" format:"date-time"` + Nonce string `json:"nonce,omitzero"` } // Authorization-related constants for headers. diff --git a/pkg/plugin/definition/keymanager.go b/pkg/plugin/definition/keymanager.go index f2c0e2f..8bcb5bb 100644 --- a/pkg/plugin/definition/keymanager.go +++ b/pkg/plugin/definition/keymanager.go @@ -8,13 +8,11 @@ import ( // KeyManager defines the interface for key management operations/methods. type KeyManager interface { - GenerateKeyPairs() (*model.Keyset, error) - StorePrivateKeys(ctx context.Context, keyID string, keys *model.Keyset) error - SigningPrivateKey(ctx context.Context, keyID string) (string, string, error) - EncrPrivateKey(ctx context.Context, keyID string) (string, string, error) - SigningPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) - EncrPublicKey(ctx context.Context, subscriberID, uniqueKeyID string) (string, error) - DeletePrivateKeys(ctx context.Context, keyID string) error + GenerateKeyset() (*model.Keyset, error) + InsertKeyset(ctx context.Context, keyID string, keyset *model.Keyset) error + Keyset(ctx context.Context, keyID string) (*model.Keyset, error) + LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (signingPublicKey string, encrPublicKey string, err error) + DeleteKeyset(ctx context.Context, keyID string) error } // KeyManagerProvider initializes a new signer instance. diff --git a/pkg/plugin/definition/publisher.go b/pkg/plugin/definition/publisher.go index 1e744da..55ed217 100644 --- a/pkg/plugin/definition/publisher.go +++ b/pkg/plugin/definition/publisher.go @@ -8,6 +8,7 @@ type Publisher interface { Publish(context.Context, string, []byte) error } +// PublisherProvider is the interface for creating new Publisher instances. type PublisherProvider interface { // New initializes a new publisher instance with the given configuration. New(ctx context.Context, config map[string]string) (Publisher, func() error, error) diff --git a/pkg/plugin/implementation/keymanager/cmd/plugin.go b/pkg/plugin/implementation/keymanager/cmd/plugin.go new file mode 100644 index 0000000..f7579e9 --- /dev/null +++ b/pkg/plugin/implementation/keymanager/cmd/plugin.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager" +) + +// keyManagerProvider implements the plugin provider for the KeyManager plugin. +type keyManagerProvider struct{} + +// newKeyManagerFunc is a function type that creates a new KeyManager instance. +var newKeyManagerFunc = keymanager.New + +// New creates and initializes a new KeyManager instance using the provided cache, registry lookup, and configuration. +func (k *keyManagerProvider) New(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg map[string]string) (definition.KeyManager, func() error, error) { + config := &keymanager.Config{ + VaultAddr: cfg["vaultAddr"], + KVVersion: cfg["kvVersion"], + } + log.Debugf(ctx, "Keymanager config mapped: %+v", cfg) + km, cleanup, err := newKeyManagerFunc(ctx, cache, registry, config) + if err != nil { + log.Error(ctx, err, "Failed to initialize KeyManager") + return nil, nil, err + } + log.Debugf(ctx, "KeyManager instance created successfully") + return km, cleanup, nil +} + +// Provider is the exported instance of keyManagerProvider used for plugin registration. +var Provider = keyManagerProvider{} diff --git a/pkg/plugin/implementation/keymanager/cmd/plugin_test.go b/pkg/plugin/implementation/keymanager/cmd/plugin_test.go new file mode 100644 index 0000000..bafd4bb --- /dev/null +++ b/pkg/plugin/implementation/keymanager/cmd/plugin_test.go @@ -0,0 +1,127 @@ +package main + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/beckn/beckn-onix/pkg/model" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/keymanager" +) + +type mockRegistry struct { + LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) +} + +func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + if m.LookupFunc != nil { + return m.LookupFunc(ctx, sub) + } + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + URL: "https://mock.registry/subscriber", + Type: "BPP", + Domain: "retail", + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now().Add(-2 * time.Hour), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +type mockCache struct{} + +func (m *mockCache) Get(ctx context.Context, key string) (string, error) { + return "", nil +} +func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error { + return nil +} +func (m *mockCache) Clear(ctx context.Context) error { + return nil +} + +func (m *mockCache) Delete(ctx context.Context, key string) error { + return nil +} + +func TestNewSuccess(t *testing.T) { + // Setup dummy implementations and variables + ctx := context.Background() + cache := &mockCache{} + registry := &mockRegistry{} + cfg := map[string]string{ + "vaultAddr": "http://dummy-vault", + "kvVersion": "2", + } + + cleanupCalled := false + fakeCleanup := func() error { + cleanupCalled = true + return nil + } + + newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) { + // return a mock struct pointer of *keymanager.KeyMgr or a stub instance + return &keymanager.KeyMgr{}, fakeCleanup, nil + } + + // Create provider and call New + provider := &keyManagerProvider{} + km, cleanup, err := provider.New(ctx, cache, registry, cfg) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if km == nil { + t.Fatal("Expected non-nil KeyManager instance") + } + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + // Call cleanup and check if it behaves correctly + if err := cleanup(); err != nil { + t.Fatalf("Expected no error from cleanup, got %v", err) + } + if !cleanupCalled { + t.Error("Expected cleanup function to be called") + } +} + +func TestNewFailure(t *testing.T) { + // Setup dummy variables + ctx := context.Background() + cache := &mockCache{} + registry := &mockRegistry{} + cfg := map[string]string{ + "vaultAddr": "http://dummy-vault", + "kvVersion": "2", + } + + newKeyManagerFunc = func(ctx context.Context, cache definition.Cache, registry definition.RegistryLookup, cfg *keymanager.Config) (*keymanager.KeyMgr, func() error, error) { + return nil, nil, fmt.Errorf("some error") + } + + provider := &keyManagerProvider{} + km, cleanup, err := provider.New(ctx, cache, registry, cfg) + if err == nil { + t.Fatal("Expected error, got nil") + } + if km != nil { + t.Error("Expected nil KeyManager on error") + } + if cleanup != nil { + t.Error("Expected nil cleanup function on error") + } +} diff --git a/pkg/plugin/implementation/keymanager/keymanager.go b/pkg/plugin/implementation/keymanager/keymanager.go new file mode 100644 index 0000000..7dcac67 --- /dev/null +++ b/pkg/plugin/implementation/keymanager/keymanager.go @@ -0,0 +1,328 @@ +package keymanager + +import ( + "context" + "crypto/ecdh" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + "strings" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/model" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/google/uuid" + vault "github.com/hashicorp/vault/api" +) + +// Config holds configuration parameters for connecting to Vault. +type Config struct { + VaultAddr string + KVVersion string +} + +// KeyMgr provides methods for managing cryptographic keys using Vault. +type KeyMgr struct { + VaultClient *vault.Client + Registry definition.RegistryLookup + Cache definition.Cache + KvVersion string + SecretPath string +} + +var ( + // ErrEmptyKeyID indicates that the provided key ID is empty. + ErrEmptyKeyID = errors.New("invalid request: keyID cannot be empty") + + // ErrNilKeySet indicates that the provided keyset is nil. + ErrNilKeySet = errors.New("keyset cannot be nil") + + // ErrEmptySubscriberID indicates that the provided subscriber ID is empty. + ErrEmptySubscriberID = errors.New("invalid request: subscriberID cannot be empty") + + // ErrEmptyUniqueKeyID indicates that the provided unique key ID is empty. + ErrEmptyUniqueKeyID = errors.New("invalid request: uniqueKeyID cannot be empty") + + // ErrSubscriberNotFound indicates that no subscriber was found with the provided credentials. + ErrSubscriberNotFound = errors.New("no subscriber found with given credentials") + + // ErrNilCache indicates that the cache implementation is nil. + ErrNilCache = errors.New("cache implementation cannot be nil") + + // ErrNilRegistryLookup indicates that the registry lookup implementation is nil. + ErrNilRegistryLookup = errors.New("registry lookup implementation cannot be nil") +) + +// ValidateCfg validates the Vault configuration and sets default KV version if missing. +func ValidateCfg(cfg *Config) error { + if cfg.VaultAddr == "" { + return errors.New("invalid config: VaultAddr cannot be empty") + } + kvVersion := strings.ToLower(cfg.KVVersion) + if kvVersion == "" { + kvVersion = "v1" + } else if kvVersion != "v1" && kvVersion != "v2" { + return fmt.Errorf("invalid KVVersion: must be 'v1' or 'v2'") + } + cfg.KVVersion = kvVersion + return nil +} + +// getVaultClient is a function that creates a new Vault client. +// This is exported for testing purposes. +var getVaultClient = GetVaultClient + +// New creates a new KeyMgr instance with the provided configuration, cache, and registry lookup. +func New(ctx context.Context, cache definition.Cache, registryLookup definition.RegistryLookup, cfg *Config) (*KeyMgr, func() error, error) { + log.Info(ctx, "Initializing KeyManager plugin") + // Validate configuration. + if err := ValidateCfg(cfg); err != nil { + log.Error(ctx, err, "Invalid configuration for KeyManager") + return nil, nil, err + } + // Check if cache implementation is provided. + if cache == nil { + log.Error(ctx, ErrNilCache, "Cache is nil in KeyManager initialization") + return nil, nil, ErrNilCache + } + + // Check if registry lookup implementation is provided. + if registryLookup == nil { + log.Error(ctx, ErrNilRegistryLookup, "RegistryLookup is nil in KeyManager initialization") + return nil, nil, ErrNilRegistryLookup + } + + // Initialize Vault client. + log.Debugf(ctx, "Creating Vault client with address: %s", cfg.VaultAddr) + vaultClient, err := getVaultClient(ctx, cfg.VaultAddr) + if err != nil { + log.Errorf(ctx, err, "Failed to create Vault client at address: %s", cfg.VaultAddr) + return nil, nil, fmt.Errorf("failed to create vault client: %w", err) + } + + log.Info(ctx, "Successfully created Vault client") + + // Create KeyManager instance. + km := &KeyMgr{ + VaultClient: vaultClient, + Registry: registryLookup, + Cache: cache, + KvVersion: cfg.KVVersion, + } + + // Cleanup function to release KeyManager resources. + cleanup := func() error { + log.Info(ctx, "Cleaning up KeyManager resources") + km.VaultClient = nil + km.Cache = nil + km.Registry = nil + return nil + } + + log.Info(ctx, "KeyManager plugin initialized successfully") + return km, cleanup, nil +} + +// NewVaultClient creates a new Vault client instance. +// This function is exported for testing purposes. +var NewVaultClient = vault.NewClient + +// GetVaultClient creates and authenticates a Vault client using AppRole. +func GetVaultClient(ctx context.Context, vaultAddr string) (*vault.Client, error) { + roleID := os.Getenv("VAULT_ROLE_ID") + secretID := os.Getenv("VAULT_SECRET_ID") + + if roleID == "" || secretID == "" { + log.Error(ctx, fmt.Errorf("missing credentials"), "VAULT_ROLE_ID or VAULT_SECRET_ID is not set") + return nil, fmt.Errorf("VAULT_ROLE_ID or VAULT_SECRET_ID is not set") + } + + config := vault.DefaultConfig() + config.Address = vaultAddr + + client, err := NewVaultClient(config) + if err != nil { + log.Error(ctx, err, "failed to create Vault client") + return nil, fmt.Errorf("failed to create Vault client: %w", err) + } + + data := map[string]interface{}{ + "role_id": roleID, + "secret_id": secretID, + } + + log.Info(ctx, "Logging into Vault with AppRole") + resp, err := client.Logical().Write("auth/approle/login", data) + if err != nil { + log.Error(ctx, err, "failed to login with AppRole") + return nil, fmt.Errorf("failed to login with AppRole: %w", err) + } + if resp == nil || resp.Auth == nil { + log.Error(ctx, nil, "AppRole login failed: no auth info returned") + return nil, errors.New("AppRole login failed: no auth info returned") + } + + log.Info(ctx, "Vault login successful") + client.SetToken(resp.Auth.ClientToken) + return client, nil +} + +var ( + ed25519KeyGenFunc = ed25519.GenerateKey + x25519KeyGenFunc = ecdh.X25519().GenerateKey + uuidGenFunc = uuid.NewRandom +) + +// GenerateKeyset generates a new signing (Ed25519) and encryption (X25519) key pair. +func (km *KeyMgr) GenerateKeyset() (*model.Keyset, error) { + signingPublic, signingPrivate, err := ed25519KeyGenFunc(rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate signing key pair: %w", err) + } + + encrPrivateKey, err := x25519KeyGenFunc(rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate encryption key pair: %w", err) + } + encrPublicKey := encrPrivateKey.PublicKey().Bytes() + uuid, err := uuidGenFunc() + if err != nil { + return nil, fmt.Errorf("failed to generate unique key id uuid: %w", err) + } + return &model.Keyset{ + UniqueKeyID: uuid.String(), + SigningPrivate: encodeBase64(signingPrivate.Seed()), + SigningPublic: encodeBase64(signingPublic), + EncrPrivate: encodeBase64(encrPrivateKey.Bytes()), + EncrPublic: encodeBase64(encrPublicKey), + }, nil +} + +// getSecretPath constructs the Vault secret path for storing keys based on the KV version. +func (km *KeyMgr) getSecretPath(keyID string) string { + if km.KvVersion == "v2" { + return fmt.Sprintf("secret/data/keys/%s", keyID) + } + return fmt.Sprintf("secret/keys/%s", keyID) +} + +// InsertKeyset stores the given keyset in Vault under the specified key ID. +func (km *KeyMgr) InsertKeyset(ctx context.Context, keyID string, keys *model.Keyset) error { + if keyID == "" { + return ErrEmptyKeyID + } + if keys == nil { + return ErrNilKeySet + } + + keyData := map[string]interface{}{ + "uniqueKeyID": keys.UniqueKeyID, + "signingPublicKey": keys.SigningPublic, + "signingPrivateKey": keys.SigningPrivate, + "encrPublicKey": keys.EncrPublic, + "encrPrivateKey": keys.EncrPrivate, + } + path := km.getSecretPath(keyID) + var payload map[string]interface{} + if km.KvVersion == "v2" { + payload = map[string]interface{}{"data": keyData} + } else { + payload = keyData + } + + _, err := km.VaultClient.Logical().Write(path, payload) + if err != nil { + return fmt.Errorf("failed to store secret in Vault at path %s: %w", path, err) + } + return nil +} + +// DeleteKeyset deletes the private keys for the given key ID from Vault. +func (km *KeyMgr) DeleteKeyset(ctx context.Context, keyID string) error { + if keyID == "" { + return ErrEmptyKeyID + } + path := km.getSecretPath(keyID) + return km.VaultClient.KVv2(path).Delete(ctx, keyID) +} + +// Keyset retrieves the keyset for the given key ID from Vault and public keys from the registry. +func (km *KeyMgr) Keyset(ctx context.Context, keyID string) (*model.Keyset, error) { + if keyID == "" { + return nil, ErrEmptyKeyID + } + + path := km.getSecretPath(keyID) + + secret, err := km.VaultClient.Logical().Read(path) + if err != nil || secret == nil { + return nil, fmt.Errorf("failed to read secret from Vault: %w", err) + } + + var data map[string]interface{} + if km.KvVersion == "v2" { + dataRaw, ok := secret.Data["data"] + if !ok { + return nil, errors.New("missing 'data' in secret response") + } + data, ok = dataRaw.(map[string]interface{}) + if !ok { + return nil, errors.New("invalid 'data' format in Vault response") + } + } else { + data = secret.Data + } + + return &model.Keyset{ + UniqueKeyID: data["uniqueKeyID"].(string), + SigningPublic: data["signingPublicKey"].(string), + SigningPrivate: data["signingPrivateKey"].(string), + EncrPublic: data["encrPublicKey"].(string), + EncrPrivate: data["encrPrivateKey"].(string), + }, nil +} + +// LookupNPKeys retrieves the signing and encryption public keys for the given subscriber ID and unique key ID. +func (km *KeyMgr) LookupNPKeys(ctx context.Context, subscriberID, uniqueKeyID string) (string, string, error) { + cacheKey := fmt.Sprintf("%s_%s", subscriberID, uniqueKeyID) + cachedData, err := km.Cache.Get(ctx, cacheKey) + if err == nil { + var keys model.Keyset + if err := json.Unmarshal([]byte(cachedData), &keys); err == nil { + return keys.SigningPublic, keys.EncrPublic, nil + } + } + subscribers, err := km.Registry.Lookup(ctx, &model.Subscription{ + Subscriber: model.Subscriber{ + SubscriberID: subscriberID, + }, + KeyID: uniqueKeyID, + }) + if err != nil { + return "", "", fmt.Errorf("failed to lookup registry: %w", err) + } + if len(subscribers) == 0 { + return "", "", ErrSubscriberNotFound + } + return subscribers[0].SigningPublicKey, subscribers[0].EncrPublicKey, nil +} + +// validateParams checks that subscriberID and uniqueKeyID are not empty. +func validateParams(subscriberID, uniqueKeyID string) error { + if subscriberID == "" { + return ErrEmptySubscriberID + } + if uniqueKeyID == "" { + return ErrEmptyUniqueKeyID + } + return nil +} + +// encodeBase64 returns the base64-encoded string of the given data. +func encodeBase64(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} diff --git a/pkg/plugin/implementation/keymanager/keymanager_test.go b/pkg/plugin/implementation/keymanager/keymanager_test.go new file mode 100644 index 0000000..5415f46 --- /dev/null +++ b/pkg/plugin/implementation/keymanager/keymanager_test.go @@ -0,0 +1,1114 @@ +package keymanager + +import ( + "context" + "crypto/ecdh" + "crypto/ed25519" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/beckn/beckn-onix/pkg/model" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/google/uuid" + "github.com/hashicorp/vault/api" + vault "github.com/hashicorp/vault/api" +) + +type mockRegistry struct { + LookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) +} + +func (m *mockRegistry) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + if m.LookupFunc != nil { + return m.LookupFunc(ctx, sub) + } + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + URL: "https://mock.registry/subscriber", + Type: "BPP", + Domain: "retail", + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now().Add(-2 * time.Hour), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +type mockCache struct { + GetFunc func(ctx context.Context, key string) (string, error) +} + +func (m *mockCache) Get(ctx context.Context, key string) (string, error) { + return "", nil +} +func (m *mockCache) Set(ctx context.Context, key string, value string, ttl time.Duration) error { + return nil +} +func (m *mockCache) Clear(ctx context.Context) error { + return nil +} + +func (m *mockCache) Delete(ctx context.Context, key string) error { + return nil +} + +func TestValidateCfgSuccess(t *testing.T) { + tests := []struct { + name string + cfg *Config + wantKV string + }{ + { + name: "valid config with v1", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v1"}, + wantKV: "v1", + }, + { + name: "valid config with v2", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v2"}, + wantKV: "v2", + }, + { + name: "default KV version applied", + cfg: &Config{VaultAddr: "http://localhost:8200"}, + wantKV: "v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCfg(tt.cfg) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if tt.cfg.KVVersion != tt.wantKV { + t.Errorf("expected KVVersion %s, got %s", tt.wantKV, tt.cfg.KVVersion) + } + }) + } +} + +func TestValidateCfgFailure(t *testing.T) { + tests := []struct { + name string + cfg *Config + }{ + { + name: "missing Vault address", + cfg: &Config{VaultAddr: "", KVVersion: "v1"}, + }, + { + name: "invalid KV version", + cfg: &Config{VaultAddr: "http://localhost:8200", KVVersion: "v3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateCfg(tt.cfg) + if err == nil { + t.Errorf("expected error, got nil") + } + }) + } +} + +func TestGenerateKeyPairs(t *testing.T) { + originalEd25519 := ed25519KeyGenFunc + originalX25519 := x25519KeyGenFunc + originalUUID := uuidGenFunc + + defer func() { + ed25519KeyGenFunc = originalEd25519 + x25519KeyGenFunc = originalX25519 + uuidGenFunc = originalUUID + }() + + tests := []struct { + name string + mockEd25519Err error + mockX25519Err error + mockUUIDErr error + expectErr bool + }{ + { + name: "success case", + expectErr: false, + }, + { + name: "ed25519 key generation failure", + mockEd25519Err: errors.New("mock ed25519 failure"), + expectErr: true, + }, + { + name: "x25519 key generation failure", + mockX25519Err: errors.New("mock x25519 failure"), + expectErr: true, + }, + { + name: "UUID generation failure", + mockUUIDErr: errors.New("mock uuid failure"), + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mockEd25519Err != nil { + ed25519KeyGenFunc = func(_ io.Reader) (ed25519.PublicKey, ed25519.PrivateKey, error) { + return nil, nil, tt.mockEd25519Err + } + } else { + ed25519KeyGenFunc = ed25519.GenerateKey + } + + if tt.mockX25519Err != nil { + x25519KeyGenFunc = func(_ io.Reader) (*ecdh.PrivateKey, error) { + return nil, tt.mockX25519Err + } + } else { + x25519KeyGenFunc = ecdh.X25519().GenerateKey + } + + if tt.mockUUIDErr != nil { + uuidGenFunc = func() (uuid.UUID, error) { + return uuid.Nil, tt.mockUUIDErr + } + } else { + uuidGenFunc = uuid.NewRandom + } + + km := &KeyMgr{} + keyset, err := km.GenerateKeyset() + + if tt.expectErr { + if err == nil { + t.Errorf("expected error, got nil") + } + if keyset != nil { + t.Errorf("expected nil keyset, got non-nil") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if keyset == nil { + t.Fatal("expected keyset, got nil") + } + if keyset.SigningPrivate == "" || keyset.SigningPublic == "" || keyset.EncrPrivate == "" || keyset.EncrPublic == "" { + t.Error("expected all keys to be populated and base64-encoded") + } + if keyset.UniqueKeyID == "" { + t.Error("expected UniqueKeyID to be non-empty") + } + } + }) + } +} + +func TestGetVaultClient_Failures(t *testing.T) { + originalNewVaultClient := NewVaultClient + defer func() { NewVaultClient = originalNewVaultClient }() + + ctx := context.Background() + + tests := []struct { + name string + roleID string + secretID string + setupServer func(t *testing.T) *httptest.Server + expectErr string + }{ + { + name: "missing credentials", + roleID: "", + secretID: "", + expectErr: "VAULT_ROLE_ID or VAULT_SECRET_ID is not set", + }, + { + name: "vault client creation failure", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + return nil, errors.New("mock client creation error") + } + return nil + }, + expectErr: "failed to create Vault client: mock client creation error", + }, + { + name: "AppRole login failure", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "login failed", http.StatusBadRequest) + })) + }, + expectErr: "failed to login with AppRole: Error making API request", + }, + { + name: "AppRole login returns nil auth", + roleID: "test-role", + secretID: "test-secret", + setupServer: func(t *testing.T) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if _, err := io.WriteString(w, `{ "auth": null }`); err != nil { + t.Fatalf("failed to write response: %v", err) + } + })) + }, + expectErr: "AppRole login failed: no auth info returned", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv("VAULT_ROLE_ID", tt.roleID) + os.Setenv("VAULT_SECRET_ID", tt.secretID) + + var server *httptest.Server + if tt.setupServer != nil { + server = tt.setupServer(t) + if server != nil { + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + cfg.Address = server.URL + return vault.NewClient(cfg) + } + defer server.Close() + } + } + + client, err := GetVaultClient(ctx, "http://ignored") + if err == nil || !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error to contain '%s', got: %v", tt.expectErr, err) + } + if client != nil { + t.Error("expected client to be nil on failure") + } + }) + } +} + +func TestGetVaultClient_Success(t *testing.T) { + originalNewVaultClient := NewVaultClient + defer func() { NewVaultClient = originalNewVaultClient }() + + ctx := context.Background() + + os.Setenv("VAULT_ROLE_ID", "test-role") + os.Setenv("VAULT_SECRET_ID", "test-secret") + + // Mock Vault server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/auth/approle/login") { + t.Errorf("unexpected request path: %s", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if _, err := io.WriteString(w, `{ + "auth": { + "client_token": "mock-token" + } + }`); err != nil { + t.Fatalf("failed to write response: %v", err) + } + })) + defer server.Close() + + NewVaultClient = func(cfg *vault.Config) (*vault.Client, error) { + cfg.Address = server.URL + return vault.NewClient(cfg) + } + + client, err := GetVaultClient(ctx, "http://ignored") + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } + if token := client.Token(); token != "mock-token" { + t.Errorf("expected token to be 'mock-token', got: %s", token) + } +} + +type mockRegistryLookup struct{} + +func (m *mockRegistryLookup) Lookup(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + Type: sub.Type, + }, + KeyID: "mock-key-id", + SigningPublicKey: "mock-signing-pubkey", + EncrPublicKey: "mock-encryption-pubkey", + ValidFrom: time.Now().Add(-time.Hour), + ValidUntil: time.Now().Add(time.Hour), + Status: "SUBSCRIBED", + Created: time.Now(), + Updated: time.Now(), + Nonce: "mock-nonce", + }, + }, nil +} + +func TestNewSuccess(t *testing.T) { + tests := []struct { + name string + cfg *Config + cache definition.Cache + registry definition.RegistryLookup + mockVaultStatus int + mockVaultBody string + }{ + { + name: "valid config", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + } + + originalGetVaultClient := getVaultClient + defer func() { getVaultClient = originalGetVaultClient }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.mockVaultStatus) + fmt.Fprint(w, tt.mockVaultBody) + })) + defer vaultServer.Close() + + tt.cfg.VaultAddr = vaultServer.URL + + getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) { + cfg := vault.DefaultConfig() + cfg.Address = addr + return vault.NewClient(cfg) + } + + ctx := context.Background() + km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if km == nil { + t.Fatalf("expected KeyMgr instance, got nil") + } + if cleanup == nil { + t.Fatalf("expected cleanup function, got nil") + } + _ = cleanup() + }) + } +} + +func TestNewFailure(t *testing.T) { + tests := []struct { + name string + cfg *Config + cache definition.Cache + registry definition.RegistryLookup + mockVaultStatus int + mockVaultBody string + }{ + { + name: "nil cache", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: nil, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "nil registry", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: nil, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "invalid config", + cfg: &Config{ + VaultAddr: "", // Invalid + KVVersion: "v3", // Unsupported + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + { + name: "vault client creation failure", + cfg: &Config{ + VaultAddr: "http://dummy", + KVVersion: "v2", + }, + cache: &mockCache{}, + registry: &mockRegistryLookup{}, + mockVaultStatus: http.StatusOK, + mockVaultBody: `{}`, + }, + } + + originalGetVaultClient := getVaultClient + defer func() { getVaultClient = originalGetVaultClient }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.mockVaultStatus) + fmt.Fprint(w, tt.mockVaultBody) + })) + defer vaultServer.Close() + + if tt.cfg != nil { + tt.cfg.VaultAddr = vaultServer.URL + } + + getVaultClient = func(ctx context.Context, addr string) (*vault.Client, error) { + if tt.name == "vault client creation failure" { + return nil, errors.New("simulated vault client creation error") + } + cfg := vault.DefaultConfig() + cfg.Address = addr + return vault.NewClient(cfg) + } + + ctx := context.Background() + km, cleanup, err := New(ctx, tt.cache, tt.registry, tt.cfg) + + if err == nil { + t.Error("expected error, got nil") + } + if km != nil { + t.Error("expected KeyMgr to be nil, got non-nil") + } + if cleanup != nil { + t.Error("expected cleanup to be nil, got non-nil") + } + }) + } + +} + +func TestStorePrivateKeysSuccess(t *testing.T) { + ctx := context.Background() + + keys := &model.Keyset{ + UniqueKeyID: "uuid", + SigningPublic: "signPub", + SigningPrivate: "signPriv", + EncrPublic: "encrPub", + EncrPrivate: "encrPriv", + } + + tests := []struct { + name string + kvVersion string + keyID string + keys *model.Keyset + }{ + { + name: "success kv v1", + kvVersion: "v1", + keyID: "mykeyid", + keys: keys, + }, + { + name: "success kv v2", + kvVersion: "v2", + keyID: "mykeyid", + keys: keys, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := "" + if tt.kvVersion == "v2" { + expectedPath = "/v1/secret/data/keys/" + tt.keyID + } else { + expectedPath = "/v1/secret/keys/" + tt.keyID + } + + if r.URL.Path != expectedPath { + t.Errorf("unexpected request path: got %s, want %s", r.URL.Path, expectedPath) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + fmt.Fprintln(w, `{"data":{}}`) + })) + defer server.Close() + + config := api.DefaultConfig() + config.Address = server.URL + client, err := api.NewClient(config) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + err = km.InsertKeyset(ctx, tt.keyID, tt.keys) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }) + } +} + +func TestStorePrivateKeysFailure(t *testing.T) { + ctx := context.Background() + + keys := &model.Keyset{ + UniqueKeyID: "uuid", + SigningPublic: "signPub", + SigningPrivate: "signPriv", + EncrPublic: "encrPub", + EncrPrivate: "encrPriv", + } + + tests := []struct { + name string + kvVersion string + keyID string + keys *model.Keyset + statusCode int // for HTTP error simulation + expectedErr string + }{ + { + name: "empty keyID", + keyID: "", + keys: keys, + expectedErr: ErrEmptyKeyID.Error(), + }, + { + name: "nil keys", + keyID: "mykeyid", + keys: nil, + expectedErr: ErrNilKeySet.Error(), + }, + { + name: "vault write error", + kvVersion: "v1", + keyID: "mykeyid", + keys: keys, + statusCode: 500, + expectedErr: "failed to store secret in Vault at path secret/keys/mykeyid: Error making API request.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var server *httptest.Server + if tt.statusCode != 0 { + // Setup test HTTP server to simulate Vault error + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", tt.statusCode) + })) + defer server.Close() + } + + var client *api.Client + var err error + if server != nil { + config := api.DefaultConfig() + config.Address = server.URL + client, err = api.NewClient(config) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + } else { + client = nil + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + err = km.InsertKeyset(ctx, tt.keyID, tt.keys) + + if err == nil { + t.Fatalf("expected error %q but got nil", tt.expectedErr) + } + if !strings.Contains(err.Error(), tt.expectedErr) { + t.Errorf("expected error containing %q, got %v", tt.expectedErr, err) + } + }) + } +} + +func TestDeletePrivateKeys(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + wantPath string + wantErr error + }{ + { + name: "empty keyID", + kvVersion: "v1", + keyID: "", + wantErr: ErrEmptyKeyID, + }, + { + name: "v1 delete", + kvVersion: "v1", + keyID: "key123", + wantPath: "/v1/secret/keys/key123/data/key123", + wantErr: nil, + }, + { + name: "v2 delete", + kvVersion: "v2", + keyID: "key123", + wantPath: "/v1/secret/data/keys/key123/data/key123", + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // If empty keyID, no Vault calls, just check error + if tt.keyID == "" { + km := &KeyMgr{ + KvVersion: tt.kvVersion, + VaultClient: nil, + } + err := km.DeleteKeyset(context.Background(), tt.keyID) + if err != tt.wantErr { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + return + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + t.Errorf("Expected DELETE method, got %s", r.Method) + } + if r.URL.Path != tt.wantPath { + t.Errorf("Expected path %s, got %s", tt.wantPath, r.URL.Path) + } + w.WriteHeader(http.StatusNoContent) + })) + defer ts.Close() + + vaultClient, err := NewVaultClient(&vault.Config{Address: ts.URL}) + if err != nil { + t.Fatalf("failed to create vault client: %v", err) + } + + km := &KeyMgr{ + KvVersion: tt.kvVersion, + VaultClient: vaultClient, + } + + err = km.DeleteKeyset(context.Background(), tt.keyID) + if err != tt.wantErr { + t.Errorf("DeletePrivateKeys() error = %v, want %v", err, tt.wantErr) + } + }) + } +} + +func setupMockVaultServer(t *testing.T, kvVersion, keyID string, success bool) *httptest.Server { + t.Helper() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPathV1 := fmt.Sprintf("/v1/secret/keys/%s", keyID) + expectedPathV2 := fmt.Sprintf("/v1/secret/data/keys/%s", keyID) + + if (kvVersion == "v2" && r.URL.Path != expectedPathV2) || (kvVersion != "v2" && r.URL.Path != expectedPathV1) { + http.Error(w, "not found", http.StatusNotFound) + return + } + + if !success { + http.Error(w, `{"errors":["key not found"]}`, http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + + if kvVersion == "v2" { + resp := fmt.Sprintf(`{ + "request_id": "req-1234", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "data": { + "uniqueKeyID": "%s", + "signingPublicKey": "sign-pub", + "signingPrivateKey": "sign-priv", + "encrPublicKey": "encr-pub", + "encrPrivateKey": "encr-priv" + }, + "metadata": { + "created_time": "2025-05-28T00:00:00Z", + "deletion_time": "", + "destroyed": false, + "version": 1 + } + }, + "warnings": null, + "auth": null + }`, keyID) + if _, err := w.Write([]byte(resp)); err != nil { + t.Fatalf("failed to write response: %v", err) + } + } else { + resp := fmt.Sprintf(`{ + "request_id": "req-1234", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "uniqueKeyID": "%s", + "signingPublicKey": "sign-pub", + "signingPrivateKey": "sign-priv", + "encrPublicKey": "encr-pub", + "encrPrivateKey": "encr-priv" + }, + "warnings": null, + "auth": null + }`, keyID) + if _, err := w.Write([]byte(resp)); err != nil { + t.Fatalf("failed to write response: %v", err) + } + } + }) + + return httptest.NewServer(handler) +} + +func TestKeysetSuccess(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + }{ + { + name: "success with KV v2", + kvVersion: "v2", + keyID: "test-key-v2", + }, + { + name: "success with KV v1", + kvVersion: "v1", + keyID: "test-key-v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts := setupMockVaultServer(t, tt.kvVersion, tt.keyID, true) + defer ts.Close() + + cfg := vault.DefaultConfig() + cfg.Address = ts.URL + + client, err := vault.NewClient(cfg) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + keys, err := km.Keyset(context.Background(), tt.keyID) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if keys == nil { + t.Fatalf("expected keys but got nil") + } + if keys.UniqueKeyID != tt.keyID { + t.Errorf("expected UniqueKeyID %q, got %q", tt.keyID, keys.UniqueKeyID) + } + if keys.SigningPrivate != "sign-priv" { + t.Errorf("expected SigningPrivate 'sign-priv', got %q", keys.SigningPrivate) + } + }) + } +} + +func TestKeysetFailure(t *testing.T) { + tests := []struct { + name string + kvVersion string + keyID string + success bool + }{ + { + name: "failure: vault returns 404 v2", + kvVersion: "v2", + keyID: "missing-key-v2", + success: false, + }, + { + name: "failure: vault returns 404 v1", + kvVersion: "v1", + keyID: "missing-key-v1", + success: false, + }, + { + name: "failure: empty keyID", + kvVersion: "v2", + keyID: "", + success: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ts *httptest.Server + if tt.keyID != "" { + ts = setupMockVaultServer(t, tt.kvVersion, tt.keyID, tt.success) + defer ts.Close() + } + + cfg := vault.DefaultConfig() + if ts != nil { + cfg.Address = ts.URL + } else { + // For empty keyID case or no mock server, use invalid URL to force error + cfg.Address = "http://invalid" + } + + client, err := vault.NewClient(cfg) + if err != nil { + t.Fatalf("failed to create Vault client: %v", err) + } + + km := &KeyMgr{ + VaultClient: client, + KvVersion: tt.kvVersion, + } + + keys, err := km.Keyset(context.Background(), tt.keyID) + if err == nil { + t.Fatalf("expected error but got nil") + } + if keys != nil { + t.Fatalf("expected nil keys but got %+v", keys) + } + }) + } +} + +func TestValidateParamsSuccess(t *testing.T) { + err := validateParams("someSubscriberID", "someUniqueKeyID") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestValidateParamsFailure(t *testing.T) { + tests := []struct { + name string + subscriberID string + uniqueKeyID string + wantErr error + }{ + { + name: "empty subscriberID", + subscriberID: "", + uniqueKeyID: "validKeyID", + wantErr: ErrEmptySubscriberID, + }, + { + name: "empty uniqueKeyID", + subscriberID: "validSubscriberID", + uniqueKeyID: "", + wantErr: ErrEmptyUniqueKeyID, + }, + { + name: "both empty", + subscriberID: "", + uniqueKeyID: "", + wantErr: ErrEmptySubscriberID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateParams(tt.subscriberID, tt.uniqueKeyID) + if err == nil { + t.Fatalf("expected error %v but got nil", tt.wantErr) + } + if err != tt.wantErr { + t.Errorf("expected error %v, got %v", tt.wantErr, err) + } + }) + } +} + +func TestLookupNPKeysSuccess(t *testing.T) { + tests := []struct { + name string + cacheGetFunc func(ctx context.Context, key string) (string, error) + registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) + expectedSigningPub string + expectedEncrPub string + }{ + { + name: "Cache hit with valid keys", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return `{"SigningPublic":"mock-signing-public-key","EncrPublic":"mock-encryption-public-key"}`, nil + }, + registryLookupFunc: nil, + expectedSigningPub: "mock-signing-public-key", + expectedEncrPub: "mock-encryption-public-key", + }, + { + name: "Cache miss and registry success", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return []model.Subscription{ + { + Subscriber: model.Subscriber{ + SubscriberID: sub.SubscriberID, + }, + KeyID: sub.KeyID, + SigningPublicKey: "mock-signing-public-key", + EncrPublicKey: "mock-encryption-public-key", + }, + }, nil + }, + expectedSigningPub: "mock-signing-public-key", + expectedEncrPub: "mock-encryption-public-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the KeyMgr with mocks + km := &KeyMgr{ + Cache: &mockCache{ + GetFunc: tt.cacheGetFunc, + }, + Registry: &mockRegistry{ + LookupFunc: tt.registryLookupFunc, + }, + } + + // Call the method + signingPublic, encrPublic, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id") + + // Validate no errors in success cases + if err != nil { + t.Fatalf("LookupNPKeys() unexpected error: %v", err) + } + + // Validate returned public keys + if signingPublic != tt.expectedSigningPub { + t.Errorf("SigningPublic = %v, want %v", signingPublic, tt.expectedSigningPub) + } + if encrPublic != tt.expectedEncrPub { + t.Errorf("EncrPublic = %v, want %v", encrPublic, tt.expectedEncrPub) + } + }) + } +} + +func TestLookupNPKeysFailure(t *testing.T) { + tests := []struct { + name string + cacheGetFunc func(ctx context.Context, key string) (string, error) + registryLookupFunc func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) + expectedError string + }{ + { + name: "Cache miss and registry failure", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, fmt.Errorf("registry down") + }, + expectedError: "registry down", + }, + { + name: "Cache miss and registry returns no subscriber", + cacheGetFunc: func(ctx context.Context, key string) (string, error) { + return "", nil + }, + registryLookupFunc: func(ctx context.Context, sub *model.Subscription) ([]model.Subscription, error) { + return nil, nil + }, + expectedError: "no subscriber found with given credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the KeyMgr with mocks + km := &KeyMgr{ + Cache: &mockCache{ + GetFunc: tt.cacheGetFunc, + }, + Registry: &mockRegistry{ + LookupFunc: tt.registryLookupFunc, + }, + } + _, _, err := km.LookupNPKeys(context.Background(), "sub-id", "key-id") + if err == nil { + t.Fatalf("expected an error but got none") + } + + if !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error to contain %v, got %v", tt.expectedError, err.Error()) + } + }) + } +} diff --git a/pkg/plugin/implementation/publisher/cmd/plugin.go b/pkg/plugin/implementation/publisher/cmd/plugin.go new file mode 100644 index 0000000..ccf87fa --- /dev/null +++ b/pkg/plugin/implementation/publisher/cmd/plugin.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/plugin/definition" + "github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher" +) + +// publisherProvider implements the PublisherProvider interface. +// It is responsible for creating a new Publisher instance. +type publisherProvider struct{} + +// New creates a new Publisher instance based on the provided configuration. +func (p *publisherProvider) New(ctx context.Context, config map[string]string) (definition.Publisher, func() error, error) { + cfg := &publisher.Config{ + Addr: config["addr"], + Exchange: config["exchange"], + RoutingKey: config["routing_key"], + Durable: config["durable"] == "true", + UseTLS: config["use_tls"] == "true", + } + log.Debugf(ctx, "Publisher config mapped: %+v", cfg) + + pub, cleanup, err := publisher.New(cfg) + if err != nil { + log.Errorf(ctx, err, "Failed to create publisher instance") + return nil, nil, err + } + + log.Infof(ctx, "Publisher instance created successfully") + return pub, cleanup, nil +} + +// Provider is the instance of publisherProvider that implements the PublisherProvider interface. +var Provider = publisherProvider{} diff --git a/pkg/plugin/implementation/publisher/cmd/plugin_test.go b/pkg/plugin/implementation/publisher/cmd/plugin_test.go new file mode 100644 index 0000000..e3f9837 --- /dev/null +++ b/pkg/plugin/implementation/publisher/cmd/plugin_test.go @@ -0,0 +1,106 @@ +package main + +import ( + "context" + "errors" + "strings" + "testing" + + "github.com/beckn/beckn-onix/pkg/plugin/implementation/publisher" + "github.com/rabbitmq/amqp091-go" +) + +type mockChannel struct{} + +func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + return nil +} +func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} +func (m *mockChannel) Close() error { + return nil +} + +func TestPublisherProvider_New_Success(t *testing.T) { + // Save original dialFunc and channelFunc + originalDialFunc := publisher.DialFunc + originalChannelFunc := publisher.ChannelFunc + defer func() { + publisher.DialFunc = originalDialFunc + publisher.ChannelFunc = originalChannelFunc + }() + + // Override mocks + publisher.DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, nil + } + publisher.ChannelFunc = func(conn *amqp091.Connection) (publisher.Channel, error) { + return &mockChannel{}, nil + } + + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + config := map[string]string{ + "addr": "localhost", + "exchange": "test-exchange", + "routing_key": "test.key", + "durable": "true", + "use_tls": "false", + } + + ctx := context.Background() + pub, cleanup, err := Provider.New(ctx, config) + + if err != nil { + t.Fatalf("Provider.New returned error: %v", err) + } + if pub == nil { + t.Fatal("Expected non-nil publisher") + } + if cleanup == nil { + t.Fatal("Expected non-nil cleanup function") + } + + if err := cleanup(); err != nil { + t.Errorf("Cleanup returned error: %v", err) + } +} + +func TestPublisherProvider_New_Failure(t *testing.T) { + // Save and restore dialFunc + originalDialFunc := publisher.DialFunc + defer func() { publisher.DialFunc = originalDialFunc }() + + // Simulate dial failure + publisher.DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, errors.New("dial failed") + } + + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + config := map[string]string{ + "addr": "localhost", + "exchange": "test-exchange", + "routing_key": "test.key", + "durable": "true", + } + + ctx := context.Background() + pub, cleanup, err := Provider.New(ctx, config) + + if err == nil { + t.Fatal("Expected error from Provider.New but got nil") + } + if !strings.Contains(err.Error(), "dial failed") { + t.Errorf("Expected 'dial failed' error, got: %v", err) + } + if pub != nil { + t.Errorf("Expected nil publisher, got: %v", pub) + } + if cleanup != nil { + t.Error("Expected nil cleanup, got non-nil") + } +} diff --git a/pkg/plugin/implementation/publisher/publisher.go b/pkg/plugin/implementation/publisher/publisher.go new file mode 100644 index 0000000..db3e577 --- /dev/null +++ b/pkg/plugin/implementation/publisher/publisher.go @@ -0,0 +1,196 @@ +package publisher + +import ( + "context" + "errors" + "fmt" + "net/url" + "os" + "strings" + + "github.com/beckn/beckn-onix/pkg/log" + "github.com/beckn/beckn-onix/pkg/model" + "github.com/rabbitmq/amqp091-go" +) + +// Config holds the configuration required to establish a connection with RabbitMQ. +type Config struct { + Addr string + Exchange string + RoutingKey string + Durable bool + UseTLS bool +} + +// Channel defines the interface for publishing messages to RabbitMQ. +type Channel interface { + PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error + ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error + Close() error +} + +// Publisher manages the RabbitMQ connection and channel to publish messages. +type Publisher struct { + Conn *amqp091.Connection + Channel Channel + Config *Config +} + +// Error variables representing different failure scenarios. +var ( + ErrEmptyConfig = errors.New("empty config") + ErrAddrMissing = errors.New("missing required field 'Addr'") + ErrExchangeMissing = errors.New("missing required field 'Exchange'") + ErrCredentialMissing = errors.New("missing RabbitMQ credentials in environment") + ErrConnectionFailed = errors.New("failed to connect to RabbitMQ") + ErrChannelFailed = errors.New("failed to open channel") + ErrExchangeDeclare = errors.New("failed to declare exchange") +) + +// Validate checks whether the provided Config is valid for connecting to RabbitMQ. +func Validate(cfg *Config) error { + if cfg == nil { + return model.NewBadReqErr(fmt.Errorf("config is nil")) + } + if strings.TrimSpace(cfg.Addr) == "" { + return model.NewBadReqErr(fmt.Errorf("missing config.Addr")) + } + if strings.TrimSpace(cfg.Exchange) == "" { + return model.NewBadReqErr(fmt.Errorf("missing config.Exchange")) + } + return nil +} + +// GetConnURL constructs the RabbitMQ connection URL using the config and environment credentials. +func GetConnURL(cfg *Config) (string, error) { + user := os.Getenv("RABBITMQ_USERNAME") + pass := os.Getenv("RABBITMQ_PASSWORD") + if user == "" || pass == "" { + return "", model.NewBadReqErr(fmt.Errorf("missing RabbitMQ credentials in environment")) + } + parts := strings.SplitN(strings.TrimSpace(cfg.Addr), "/", 2) + hostPort := parts[0] + vhost := "/" + if len(parts) > 1 { + vhost = parts[1] + } + + if !strings.Contains(hostPort, ":") { + if cfg.UseTLS { + hostPort += ":5671" + } else { + hostPort += ":5672" + } + } + + encodedUser := url.QueryEscape(user) + encodedPass := url.QueryEscape(pass) + encodedVHost := url.QueryEscape(vhost) + protocol := "amqp" + if cfg.UseTLS { + protocol = "amqps" + } + + connURL := fmt.Sprintf("%s://%s:%s@%s/%s", protocol, encodedUser, encodedPass, hostPort, encodedVHost) + log.Debugf(context.Background(), "Generated RabbitMQ connection details: protocol=%s, hostPort=%s, vhost=%s", protocol, hostPort, vhost) + + return connURL, nil +} + +// Publish sends a message to the configured RabbitMQ exchange with the specified routing key. +// If routingKey is empty, the default routing key from Config is used. +func (p *Publisher) Publish(ctx context.Context, routingKey string, msg []byte) error { + if routingKey == "" { + routingKey = p.Config.RoutingKey + } + log.Debugf(ctx, "Attempting to publish message. Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + err := p.Channel.PublishWithContext( + ctx, + p.Config.Exchange, + routingKey, + false, + false, + amqp091.Publishing{ + ContentType: "application/json", + Body: msg, + }, + ) + + if err != nil { + log.Errorf(ctx, err, "Publish failed for Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + return model.NewBadReqErr(fmt.Errorf("publish message failed: %w", err)) + } + + log.Infof(ctx, "Message published successfully to Exchange: %s, RoutingKey: %s", p.Config.Exchange, routingKey) + return nil +} + +// DialFunc is a function variable used to establish a connection to RabbitMQ. +var DialFunc = amqp091.Dial + +// ChannelFunc is a function variable used to open a channel on the given RabbitMQ connection. +var ChannelFunc = func(conn *amqp091.Connection) (Channel, error) { + return conn.Channel() +} + +// New initializes a new Publisher with the given config, opens a connection, +// channel, and declares the exchange. Returns the publisher and a cleanup function. +func New(cfg *Config) (*Publisher, func() error, error) { + // Step 1: Validate config + if err := Validate(cfg); err != nil { + return nil, nil, err + } + + // Step 2: Build connection URL + connURL, err := GetConnURL(cfg) + if err != nil { + return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err) + } + + // Step 3: Dial connection + conn, err := DialFunc(connURL) + if err != nil { + return nil, nil, fmt.Errorf("%w: %v", ErrConnectionFailed, err) + } + + // Step 4: Open channel + ch, err := ChannelFunc(conn) + if err != nil { + conn.Close() + return nil, nil, fmt.Errorf("%w: %v", ErrChannelFailed, err) + } + + // Step 5: Declare exchange + if err := ch.ExchangeDeclare( + cfg.Exchange, + "topic", + cfg.Durable, + false, + false, + false, + nil, + ); err != nil { + ch.Close() + conn.Close() + return nil, nil, fmt.Errorf("%w: %v", ErrExchangeDeclare, err) + } + + // Step 6: Construct publisher + pub := &Publisher{ + Conn: conn, + Channel: ch, + Config: cfg, + } + + cleanup := func() error { + if ch != nil { + _ = ch.Close() + } + if conn != nil { + return conn.Close() + } + return nil + } + + return pub, cleanup, nil +} diff --git a/pkg/plugin/implementation/publisher/publisher_test.go b/pkg/plugin/implementation/publisher/publisher_test.go new file mode 100644 index 0000000..82b8404 --- /dev/null +++ b/pkg/plugin/implementation/publisher/publisher_test.go @@ -0,0 +1,362 @@ +package publisher + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/rabbitmq/amqp091-go" +) + +func TestGetConnURLSuccess(t *testing.T) { + tests := []struct { + name string + config *Config + }{ + { + name: "Valid config with connection address", + config: &Config{ + Addr: "localhost:5672", + UseTLS: false, + }, + }, + + { + name: "Valid config with vhost", + config: &Config{ + Addr: "localhost:5672/myvhost", + UseTLS: false, + }, + }, + { + name: "Addr with leading and trailing spaces", + config: &Config{ + Addr: " localhost:5672/myvhost ", + UseTLS: false, + }, + }, + } + + // Set valid credentials + t.Setenv("RABBITMQ_USERNAME", "guest") + t.Setenv("RABBITMQ_PASSWORD", "guest") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url, err := GetConnURL(tt.config) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if url == "" { + t.Error("expected non-empty URL, got empty string") + } + }) + } +} + +func TestGetConnURLFailure(t *testing.T) { + tests := []struct { + name string + username string + password string + config *Config + wantErr bool + }{ + { + name: "Missing credentials", + username: "", + password: "", + config: &Config{Addr: "localhost:5672"}, + wantErr: true, + }, + { + name: "Missing config address", + username: "guest", + password: "guest", + config: &Config{}, // this won't error unless Validate() is called separately + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.username != "" { + t.Setenv("RABBITMQ_USERNAME", tt.username) + } + + if tt.password != "" { + t.Setenv("RABBITMQ_PASSWORD", tt.password) + } + + url, err := GetConnURL(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("unexpected error. gotErr = %v, wantErr = %v", err != nil, tt.wantErr) + } + + if err == nil && url == "" { + t.Errorf("expected non-empty URL, got empty string") + } + }) + } +} + +func TestValidateSuccess(t *testing.T) { + tests := []struct { + name string + config *Config + }{ + { + name: "Valid config with Addr and Exchange", + config: &Config{ + Addr: "localhost:5672", + Exchange: "ex", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := Validate(tt.config); err != nil { + t.Errorf("expected no error, got: %v", err) + } + }) + } +} + +func TestValidateFailure(t *testing.T) { + tests := []struct { + name string + config *Config + expectedErrr string + }{ + { + name: "Nil config", + config: nil, + expectedErrr: "config is nil", + }, + { + name: "Missing Addr", + config: &Config{Exchange: "ex"}, + expectedErrr: "missing config.Addr", + }, + { + name: "Missing Exchange", + config: &Config{Addr: "localhost:5672"}, + expectedErrr: "missing config.Exchange", + }, + { + name: "Empty Addr and Exchange", + config: &Config{Addr: " ", Exchange: " "}, + expectedErrr: "missing config.Addr", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := Validate(tt.config) + if err == nil { + t.Errorf("expected error for invalid config, got nil") + return + } + if !strings.Contains(err.Error(), tt.expectedErrr) { + t.Errorf("expected error to contain %q, got: %v", tt.expectedErrr, err) + } + }) + } +} + +type mockChannelForPublish struct { + published bool + exchange string + key string + body []byte + fail bool +} + +func (m *mockChannelForPublish) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + if m.fail { + return fmt.Errorf("simulated publish failure") + } + m.published = true + m.exchange = exchange + m.key = key + m.body = msg.Body + return nil +} + +func (m *mockChannelForPublish) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} + +func (m *mockChannelForPublish) Close() error { + return nil +} + +func TestPublishSuccess(t *testing.T) { + mockCh := &mockChannelForPublish{} + + p := &Publisher{ + Channel: mockCh, + Config: &Config{ + Exchange: "mock.exchange", + RoutingKey: "mock.key", + }, + } + + err := p.Publish(context.Background(), "", []byte(`{"test": true}`)) + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + + if !mockCh.published { + t.Error("expected message to be published, but it wasn't") + } + + if mockCh.exchange != "mock.exchange" || mockCh.key != "mock.key" { + t.Errorf("unexpected exchange or key. got (%s, %s)", mockCh.exchange, mockCh.key) + } +} + +func TestPublishFailure(t *testing.T) { + mockCh := &mockChannelForPublish{fail: true} + + p := &Publisher{ + Channel: mockCh, + Config: &Config{ + Exchange: "mock.exchange", + RoutingKey: "mock.key", + }, + } + + err := p.Publish(context.Background(), "", []byte(`{"test": true}`)) + if err == nil { + t.Error("expected error from failed publish, got nil") + } +} + +type mockChannel struct{} + +func (m *mockChannel) PublishWithContext(ctx context.Context, exchange, key string, mandatory, immediate bool, msg amqp091.Publishing) error { + return nil +} +func (m *mockChannel) ExchangeDeclare(name, kind string, durable, autoDelete, internal, noWait bool, args amqp091.Table) error { + return nil +} +func (m *mockChannel) Close() error { + return nil +} + +func TestNewPublisherSucess(t *testing.T) { + originalDialFunc := DialFunc + originalChannelFunc := ChannelFunc + defer func() { + DialFunc = originalDialFunc + ChannelFunc = originalChannelFunc + }() + + // mockedConn := &mockConnection{} + + DialFunc = func(url string) (*amqp091.Connection, error) { + return nil, nil + } + + ChannelFunc = func(conn *amqp091.Connection) (Channel, error) { + return &mockChannel{}, nil + } + + cfg := &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + } + + t.Setenv("RABBITMQ_USERNAME", "user") + t.Setenv("RABBITMQ_PASSWORD", "pass") + + pub, cleanup, err := New(cfg) + if err != nil { + t.Fatalf("New() failed: %v", err) + } + if pub == nil { + t.Fatal("Publisher should not be nil") + } + if cleanup == nil { + t.Fatal("Cleanup should not be nil") + } + if err := cleanup(); err != nil { + t.Errorf("Cleanup failed: %v", err) + } +} + +func TestNewPublisherFailures(t *testing.T) { + tests := []struct { + name string + cfg *Config + dialFunc func(url string) (*amqp091.Connection, error) // Mocked dial function + envVars map[string]string + expectedError string + }{ + { + name: "ValidateFailure", + cfg: &Config{}, // invalid config + expectedError: "missing config.Addr", + }, + { + name: "GetConnURLFailure", + cfg: &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + }, + envVars: map[string]string{ + "RABBITMQ_USERNAME": "", + "RABBITMQ_PASSWORD": "", + }, + expectedError: "missing RabbitMQ credentials in environment", + }, + { + name: "ConnectionFailure", + cfg: &Config{ + Addr: "localhost", + Exchange: "test-ex", + Durable: true, + RoutingKey: "test.key", + }, + dialFunc: func(url string) (*amqp091.Connection, error) { + return nil, fmt.Errorf("simulated connection failure") + }, + envVars: map[string]string{ + "RABBITMQ_USERNAME": "user", + "RABBITMQ_PASSWORD": "pass", + }, + expectedError: "failed to connect to RabbitMQ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set environment variables + for key, value := range tt.envVars { + t.Setenv(key, value) + } + + // Mock dialFunc if needed + originalDialFunc := DialFunc + if tt.dialFunc != nil { + DialFunc = tt.dialFunc + defer func() { + DialFunc = originalDialFunc + }() + } + + _, _, err := New(tt.cfg) + + if err == nil || (tt.expectedError != "" && !strings.Contains(err.Error(), tt.expectedError)) { + t.Errorf("Test %s failed: expected error containing %v, got: %v", tt.name, tt.expectedError, err) + } + }) + } +}