vault/builtin/logical/transit/path_decrypt_test.go
Vault Automation 81c1c3778b
VAULT-41092: transit engine metrics (#11814) (#12103)
* add a new struct for the total number of successful requests for transit and transform

* implement tracking for encrypt path

* implement tracking in encrypt path

* add tracking in rewrap

* add tracking to datakey path

* add tracking to  hmac path

* add tracking to sign  path

* add tracking to verify path

* unit tests for verify path

* add tracking to cmac path

* reset the global counter in each unit test

* add tracking to hmac verify

* add methods to retrieve and flush transit count

* modify the methods that store and update data protection call counts

* update the methods

* add a helper method to combine replicated and local data call counts

* add tracking to the endpoint

* fix some formatting errors

* add unit tests to path encrypt for tracking

* add unit tests to decrypt path

* fix linter error

* add unit tests to test update and store methods for data protection calls

* stub fix: do not create separate files

* fix the tracking by coordinating replicated and local data, add unit tests

* update all reference to the new data struct

* revert to previous design with just one global counter for all calls for each cluster

* complete external test

* no need to check if current count is greater than 0, remove it

* feedback: remove unnecessary comments about atomic addition, standardize comments

* leave jira id on todo comment, remove unused method

* rename methods by removing HWM and max in names, update jira id in todo comment, update response field key name

* feedback: remove explicit counter in cmac tests, instead put in the expected number

* feedback: remove explicit tracking in the rest of the tests

* feedback: separate transit testing into its own external test

* Update vault/consumption_billing_util_test.go



* update comment after test name change

* fix comments

* fix comments in test

* another comment fix

* feedback: remove incorrect comment

* fix a CE test

* fix the update method: instead of storing max, increment by the current count value

* update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths

* update the external test

* fix a bug: reset the counter every time we update the stored counter value to prevent double-counting

* update one of the tests

* update external test

---------

Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>
2026-01-30 15:16:05 -05:00

295 lines
9.6 KiB
Go

// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: BUSL-1.1
package transit
import (
"context"
"encoding/json"
"net/http"
"reflect"
"testing"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
func TestTransit_BatchDecryption(t *testing.T) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchEncryptionInput := []interface{}{
map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string
map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "reference": "baz"},
}
batchEncryptionData := map[string]interface{}{
"batch_input": batchEncryptionInput,
}
batchEncryptionReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "encrypt/upserted_key",
Storage: s,
Data: batchEncryptionData,
}
resp, err = b.HandleRequest(context.Background(), batchEncryptionReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
batchResponseItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
batchDecryptionInput := make([]interface{}, len(batchResponseItems))
for i, item := range batchResponseItems {
batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "reference": item.Reference}
}
batchDecryptionData := map[string]interface{}{
"batch_input": batchDecryptionInput,
}
batchDecryptionReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "decrypt/upserted_key",
Storage: s,
Data: batchDecryptionData,
}
resp, err = b.HandleRequest(context.Background(), batchDecryptionReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
batchDecryptionResponseItems := resp.Data["batch_results"].([]DecryptBatchResponseItem)
// This seems fragile
expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]"
jsonResponse, err := json.Marshal(batchDecryptionResponseItems)
if err != nil || err == nil && string(jsonResponse) != expectedResult {
t.Fatalf("bad: expected json response [%s]", jsonResponse)
}
// We expect 6 successful requests (3 for batch encryption, 3 for batch decryption)
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
}
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
var req *logical.Request
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Create a derived key.
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
Data: map[string]interface{}{
"derived": true,
},
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Encrypt some values for use in test cases.
plaintextItems := []struct {
plaintext, context string
}{
{plaintext: "dGhlIHF1aWNrIGJyb3duIGZveA==", context: "dGVzdGNvbnRleHQ="},
{plaintext: "anVtcGVkIG92ZXIgdGhlIGxhenkgZG9n", context: "dGVzdGNvbnRleHQy"},
}
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "encrypt/existing_key",
Storage: s,
Data: map[string]interface{}{
"batch_input": []interface{}{
map[string]interface{}{"plaintext": plaintextItems[0].plaintext, "context": plaintextItems[0].context},
map[string]interface{}{"plaintext": plaintextItems[1].plaintext, "context": plaintextItems[1].context},
},
},
}
resp, err = b.HandleRequest(context.Background(), req)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
encryptedItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
tests := []struct {
name string
in []interface{}
want []DecryptBatchResponseItem
shouldErr bool
wantHTTPStatus int
params map[string]interface{}
}{
{
name: "nil-input",
in: nil,
shouldErr: true,
},
{
name: "empty-input",
in: []interface{}{},
shouldErr: true,
},
{
name: "single-item-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
},
},
{
name: "single-item-invalid-ciphertext",
in: []interface{}{
map[string]interface{}{"ciphertext": "xxx", "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "invalid ciphertext: no prefix"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "single-item-wrong-context",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[0].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Plaintext: plaintextItems[0].plaintext},
{Plaintext: plaintextItems[1].plaintext},
},
},
{
name: "batch-partial-success",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Plaintext: plaintextItems[1].plaintext},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-partial-success-overridden-response",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[1].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Plaintext: plaintextItems[1].plaintext},
},
params: map[string]interface{}{"partial_failure_response_code": http.StatusAccepted},
wantHTTPStatus: http.StatusAccepted,
},
{
name: "batch-full-failure",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Error: "cipher: message authentication failed"},
},
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "batch-full-failure-overridden-response",
in: []interface{}{
map[string]interface{}{"ciphertext": encryptedItems[0].Ciphertext, "context": plaintextItems[1].context},
map[string]interface{}{"ciphertext": encryptedItems[1].Ciphertext, "context": plaintextItems[0].context},
},
want: []DecryptBatchResponseItem{
{Error: "cipher: message authentication failed"},
{Error: "cipher: message authentication failed"},
},
params: map[string]interface{}{"partial_failure_response_code": http.StatusAccepted},
// Full failure, shouldn't affect status code
wantHTTPStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req = &logical.Request{
Operation: logical.UpdateOperation,
Path: "decrypt/existing_key",
Storage: s,
Data: map[string]interface{}{
"batch_input": tt.in,
},
}
for k, v := range tt.params {
req.Data[k] = v
}
resp, err = b.HandleRequest(context.Background(), req)
didErr := err != nil || (resp != nil && resp.IsError())
if didErr {
if !tt.shouldErr {
t.Fatalf("unexpected error err:%v, resp:%#v", err, resp)
}
} else {
if tt.shouldErr {
t.Fatal("expected error, but none occurred")
}
if rawRespBody, ok := resp.Data[logical.HTTPRawBody]; ok {
httpResp := &logical.HTTPResponse{}
err = jsonutil.DecodeJSON([]byte(rawRespBody.(string)), httpResp)
if err != nil {
t.Fatalf("failed to unmarshal nested response: err:%v, resp:%#v", err, resp)
}
if respStatus, ok := resp.Data[logical.HTTPStatusCode]; !ok || respStatus != tt.wantHTTPStatus {
t.Fatalf("HTTP response status code mismatch, want:%d, got:%d", tt.wantHTTPStatus, respStatus)
}
resp = logical.HTTPResponseToLogicalResponse(httpResp)
}
var respItems []DecryptBatchResponseItem
err = mapstructure.Decode(resp.Data["batch_results"], &respItems)
if err != nil {
t.Fatalf("problem decoding response items: err:%v, resp:%#v", err, resp)
}
if !reflect.DeepEqual(tt.want, respItems) {
t.Fatalf("response items mismatch, want:%#v, got:%#v", tt.want, respItems)
}
}
})
}
// We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption)
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
}