vault/builtin/logical/transit/path_rewrap_test.go
Vault Automation caf642b7d2
Backport Vault 42177 Add Backend Field into ce/main (#12152)
* Vault 42177 Add Backend Field (#12092)

* add a new struct for the total number of successful requests for transit and transform

* implement tracking for encrypt path

* implement tracking in encrypt path

* add tracking in rewrap

* add tracking to datakey path

* add tracking to hmac path

* add tracking to sign path

* add tracking to verify path

* unit tests for verify path

* add tracking to cmac path

* reset the global counter in each unit test

* add tracking to hmac verify

* add methods to retrieve and flush transit count

* modify the methods that store and update data protection call counts

* update the methods

* add a helper method to combine replicated and local data call counts

* add tracking to the endpoint

* fix some formatting errors

* add unit tests to path encrypt for tracking

* add unit tests to decrypt path

* fix linter error

* add unit tests to test update and store methods for data protection calls

* stub fix: do not create separate files

* fix the tracking by coordinating replicated and local data, add unit tests

* update all reference to the new data struct

* revert to previous design with just one global counter for all calls for each cluster

* complete external test

* no need to check if current count is greater than 0, remove it

* feedback: remove unnecessary comments about atomic addition, standardize comments

* leave jira id on todo comment, remove unused method

* rename methods by removing HWM and max from names, update jira id in todo comment, update response field key name

* feedback: remove explicit counter in cmac tests, instead put in the expected number

* feedback: remove explicit tracking in the rest of the tests

* feedback: separate transit testing into its own external test

* Update vault/consumption_billing_util_test.go

Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>

* update comment after test name change

* fix comments

* fix comments in test

* another comment fix

* feedback: remove incorrect comment

* fix a CE test

* fix the update method: instead of storing max, increment by the current count value

* update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths

* update the external test

* Adds a field to backend to track billing data

removed file

* Changed implementation to use a map instead

* Some more comments

* Add more implementation

* Edited grpc server backend

* Refactored a bit

* Fix one more test

* Modified map:

* Revert "Modified map:"

This reverts commit 1730fe1f358b210e6abae43fbdca09e585aaaaa8.

* Removed some other things

* Edited consumption billing files a bit

* Testing function

* Fix transit stuff and make sure tests pass

* Changes

* More changes

* More changes

* Edited external test

* Edited some more tests

* Edited and fixed tests

* One more fix

* Fix some more tests

* Moved some testing structures around and added error checking

* Fixed some nits

* Update builtin/logical/transit/path_sign_verify.go

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* Edited some errors

* Fixed error logs

* Edited one more thing

* Decorate the error

* Update vault/consumption_billing.go

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

---------

Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* Edited stub function

---------

Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>
Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
Co-authored-by: divyaac <divyaac@berkeley.edu>
2026-02-03 22:48:12 +00:00

450 lines
14 KiB
Go

// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: BUSL-1.1
package transit
import (
"context"
"strings"
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
// Check the normal flow of rewrap
func TestTransit_BatchRewrapCase1(t *testing.T) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
// Upsert the key and encrypt the data
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
encData := map[string]interface{}{
"plaintext": plaintext,
}
// Create a key and encrypt a plaintext
encReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "encrypt/upserted_key",
Storage: s,
Data: encData,
}
resp, err = b.HandleRequest(context.Background(), encReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Cache the ciphertext
ciphertext := resp.Data["ciphertext"]
if !strings.HasPrefix(ciphertext.(string), "vault:v1") {
t.Fatalf("bad: ciphertext version: expected: 'vault:v1', actual: %s", ciphertext)
}
keyVersion := resp.Data["key_version"].(int)
if keyVersion != 1 {
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 1)
}
rewrapData := map[string]interface{}{
"ciphertext": ciphertext,
}
// Read the policy and check if the latest version is 1
policyReq := &logical.Request{
Operation: logical.ReadOperation,
Path: "keys/upserted_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["latest_version"] != 1 {
t.Fatalf("bad: latest_version: expected: 1, actual: %d", resp.Data["latest_version"])
}
rotateReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/upserted_key/rotate",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), rotateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Read the policy again and the latest version is 2
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["latest_version"] != 2 {
t.Fatalf("bad: latest_version: expected: 2, actual: %d", resp.Data["latest_version"])
}
// Rewrap the ciphertext and check that they are different
rewrapReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "rewrap/upserted_key",
Storage: s,
Data: rewrapData,
}
resp, err = b.HandleRequest(context.Background(), rewrapReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if ciphertext.(string) == resp.Data["ciphertext"].(string) {
t.Fatalf("bad: ciphertexts are same before and after rewrap")
}
if !strings.HasPrefix(resp.Data["ciphertext"].(string), "vault:v2") {
t.Fatalf("bad: ciphertext version: expected: 'vault:v2', actual: %s", resp.Data["ciphertext"].(string))
}
keyVersion = resp.Data["key_version"].(int)
if keyVersion != 2 {
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
}
// We expect 2 successful requests (1 for encrypt, 1 for rewrap)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Check the normal flow of rewrap with upserted key
func TestTransit_BatchRewrapCase2(t *testing.T) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
// Upsert the key and encrypt the data
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
encData := map[string]interface{}{
"plaintext": plaintext,
"context": "dmlzaGFsCg==",
}
// Create a key and encrypt a plaintext
encReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "encrypt/upserted_key",
Storage: s,
Data: encData,
}
resp, err = b.HandleRequest(context.Background(), encReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Cache the ciphertext
ciphertext := resp.Data["ciphertext"]
if !strings.HasPrefix(ciphertext.(string), "vault:v1") {
t.Fatalf("bad: ciphertext version: expected: 'vault:v1', actual: %s", ciphertext)
}
keyVersion := resp.Data["key_version"].(int)
if keyVersion != 1 {
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 1)
}
rewrapData := map[string]interface{}{
"ciphertext": ciphertext,
"context": "dmlzaGFsCg==",
}
// Read the policy and check if the latest version is 1
policyReq := &logical.Request{
Operation: logical.ReadOperation,
Path: "keys/upserted_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["latest_version"] != 1 {
t.Fatalf("bad: latest_version: expected: 1, actual: %d", resp.Data["latest_version"])
}
rotateReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/upserted_key/rotate",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), rotateReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
// Read the policy again and the latest version is 2
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Data["latest_version"] != 2 {
t.Fatalf("bad: latest_version: expected: 2, actual: %d", resp.Data["latest_version"])
}
// Rewrap the ciphertext and check that they are different
rewrapReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "rewrap/upserted_key",
Storage: s,
Data: rewrapData,
}
resp, err = b.HandleRequest(context.Background(), rewrapReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if ciphertext.(string) == resp.Data["ciphertext"].(string) {
t.Fatalf("bad: ciphertexts are same before and after rewrap")
}
if !strings.HasPrefix(resp.Data["ciphertext"].(string), "vault:v2") {
t.Fatalf("bad: ciphertext version: expected: 'vault:v2', actual: %s", resp.Data["ciphertext"].(string))
}
keyVersion = resp.Data["key_version"].(int)
if keyVersion != 2 {
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
}
// We expect 2 successful transit requests (1 for encrypt, 1 for rewrap)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// TestTransit_BatchRewrapCase3 batch-encrypts two plaintexts under an upserted
// key, rotates the key, batch-rewraps both ciphertexts onto key version 2, and
// then decrypts each rewrapped ciphertext with an individual request.
//
// Batch results are index-aligned with their inputs; the reference check below
// already relies on this. Each decrypted value is therefore compared against
// the plaintext submitted at the same index, not against "either input", so a
// swapped or duplicated result is detected.
func TestTransit_BatchRewrapCase3(t *testing.T) {
	var resp *logical.Response
	var err error
	b, s := createBackendWithStorage(t)
	batchEncryptionInput := []interface{}{
		map[string]interface{}{"plaintext": "dmlzaGFsCg==", "reference": "ek"},
		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "reference": "do"},
	}
	batchEncryptionData := map[string]interface{}{
		"batch_input": batchEncryptionInput,
	}
	// Encrypting against a missing key upserts it at version 1.
	batchReq := &logical.Request{
		Operation: logical.CreateOperation,
		Path:      "encrypt/upserted_key",
		Storage:   s,
		Data:      batchEncryptionData,
	}
	resp, err = b.HandleRequest(context.Background(), batchReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	batchEncryptionResponseItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
	// Feed every ciphertext (with its reference) back in as rewrap input.
	batchRewrapInput := make([]interface{}, len(batchEncryptionResponseItems))
	for i, item := range batchEncryptionResponseItems {
		batchRewrapInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext, "reference": item.Reference}
	}
	batchRewrapData := map[string]interface{}{
		"batch_input": batchRewrapInput,
	}
	rotateReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "keys/upserted_key/rotate",
		Storage:   s,
	}
	resp, err = b.HandleRequest(context.Background(), rotateReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	rewrapReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "rewrap/upserted_key",
		Storage:   s,
		Data:      batchRewrapData,
	}
	resp, err = b.HandleRequest(context.Background(), rewrapReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	batchRewrapResponseItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
	if len(batchRewrapResponseItems) != len(batchEncryptionResponseItems) {
		t.Fatalf("bad: length of input and output or rewrap are not matching; expected: %d, actual: %d", len(batchEncryptionResponseItems), len(batchRewrapResponseItems))
	}
	// A single decrypt request is reused; only its Data changes per item.
	decReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "decrypt/upserted_key",
		Storage:   s,
	}
	for i, eItem := range batchEncryptionResponseItems {
		rItem := batchRewrapResponseItems[i]
		input := batchEncryptionInput[i].(map[string]interface{})
		inputRef := input["reference"]
		if eItem.Reference != inputRef {
			t.Fatalf("bad: reference mismatch. Expected %s, Actual: %s", inputRef, eItem.Reference)
		}
		if eItem.Ciphertext == rItem.Ciphertext {
			t.Fatalf("bad: rewrap input and output are the same")
		}
		if !strings.HasPrefix(rItem.Ciphertext, "vault:v2") {
			t.Fatalf("bad: invalid version of ciphertext in rewrap response; expected: 'vault:v2', actual: %s", rItem.Ciphertext)
		}
		if rItem.KeyVersion != 2 {
			t.Fatalf("unexpected key version; got: %d, expected: %d", rItem.KeyVersion, 2)
		}
		decReq.Data = map[string]interface{}{
			"ciphertext": rItem.Ciphertext,
		}
		resp, err = b.HandleRequest(context.Background(), decReq)
		if err != nil || (resp != nil && resp.IsError()) {
			t.Fatalf("err:%v resp:%#v", err, resp)
		}
		// The rewrapped ciphertext must round-trip to the plaintext that was
		// submitted at this same batch index.
		if resp.Data["plaintext"] != input["plaintext"] {
			t.Fatalf("bad: plaintext for item %d. Expected: %q, Actual: %q", i, input["plaintext"], resp.Data["plaintext"])
		}
	}
	// We expect 6 successful transit requests: 2 for the batch encryption and
	// 2 for the batch rewrap (each batch item is counted), plus 2 individual
	// decryptions.
	require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
}
// TestTransit_BatchRewrapCase4 exercises batch rewrap with RSA padding schemes.
// It batch-encrypts two plaintexts under an upserted rsa-2048 key using
// pkcs1v15 padding and rotates the key. It then batch-rewraps both ciphertexts,
// decrypting with pkcs1v15 and re-encrypting with oaep. Finally it decrypts
// each rewrapped ciphertext with an individual request that names no padding
// scheme.
//
// Batch results are index-aligned with their inputs; the reference check below
// already relies on this. Each decrypted value is therefore compared against
// the plaintext submitted at the same index, not against "either input".
func TestTransit_BatchRewrapCase4(t *testing.T) {
	var resp *logical.Response
	var err error
	b, s := createBackendWithStorage(t)
	batchEncryptionInput := []interface{}{
		map[string]interface{}{"plaintext": "dmlzaGFsCg==", "reference": "ek", "padding_scheme": "pkcs1v15"},
		map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "reference": "do", "padding_scheme": "pkcs1v15"},
	}
	// "type" applies only when the key is upserted by this request.
	batchEncryptionData := map[string]interface{}{
		"type":        "rsa-2048",
		"batch_input": batchEncryptionInput,
	}
	batchReq := &logical.Request{
		Operation: logical.CreateOperation,
		Path:      "encrypt/upserted_key",
		Storage:   s,
		Data:      batchEncryptionData,
	}
	resp, err = b.HandleRequest(context.Background(), batchReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	batchEncryptionResponseItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
	// Rewrap each ciphertext, switching its padding from pkcs1v15 to oaep.
	batchRewrapInput := make([]interface{}, len(batchEncryptionResponseItems))
	for i, item := range batchEncryptionResponseItems {
		batchRewrapInput[i] = map[string]interface{}{
			"ciphertext":             item.Ciphertext,
			"reference":              item.Reference,
			"decrypt_padding_scheme": "pkcs1v15",
			"encrypt_padding_scheme": "oaep",
		}
	}
	batchRewrapData := map[string]interface{}{
		"batch_input": batchRewrapInput,
	}
	rotateReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "keys/upserted_key/rotate",
		Storage:   s,
	}
	resp, err = b.HandleRequest(context.Background(), rotateReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	rewrapReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "rewrap/upserted_key",
		Storage:   s,
		Data:      batchRewrapData,
	}
	resp, err = b.HandleRequest(context.Background(), rewrapReq)
	if err != nil || (resp != nil && resp.IsError()) {
		t.Fatalf("err:%v resp:%#v", err, resp)
	}
	batchRewrapResponseItems := resp.Data["batch_results"].([]EncryptBatchResponseItem)
	if len(batchRewrapResponseItems) != len(batchEncryptionResponseItems) {
		t.Fatalf("bad: length of input and output or rewrap are not matching; expected: %d, actual: %d", len(batchEncryptionResponseItems), len(batchRewrapResponseItems))
	}
	// A single decrypt request is reused; only its Data changes per item.
	// NOTE(review): no padding_scheme is passed, so this relies on the
	// backend's default matching the oaep used by the rewrap — confirm.
	decReq := &logical.Request{
		Operation: logical.UpdateOperation,
		Path:      "decrypt/upserted_key",
		Storage:   s,
	}
	for i, eItem := range batchEncryptionResponseItems {
		rItem := batchRewrapResponseItems[i]
		input := batchEncryptionInput[i].(map[string]interface{})
		inputRef := input["reference"]
		if eItem.Reference != inputRef {
			t.Fatalf("bad: reference mismatch. Expected %s, Actual: %s", inputRef, eItem.Reference)
		}
		if eItem.Ciphertext == rItem.Ciphertext {
			t.Fatalf("bad: rewrap input and output are the same")
		}
		if !strings.HasPrefix(rItem.Ciphertext, "vault:v2") {
			t.Fatalf("bad: invalid version of ciphertext in rewrap response; expected: 'vault:v2', actual: %s", rItem.Ciphertext)
		}
		if rItem.KeyVersion != 2 {
			t.Fatalf("unexpected key version; got: %d, expected: %d", rItem.KeyVersion, 2)
		}
		decReq.Data = map[string]interface{}{
			"ciphertext": rItem.Ciphertext,
		}
		resp, err = b.HandleRequest(context.Background(), decReq)
		if err != nil || (resp != nil && resp.IsError()) {
			t.Fatalf("err:%v resp:%#v", err, resp)
		}
		// The rewrapped ciphertext must round-trip to the plaintext that was
		// submitted at this same batch index.
		if resp.Data["plaintext"] != input["plaintext"] {
			t.Fatalf("bad: plaintext for item %d. Expected: %q, Actual: %q", i, input["plaintext"], resp.Data["plaintext"])
		}
	}
	// We expect 6 successful calls to the transit backend: 2 for the batch
	// encryption and 2 for the batch rewrap (each batch item is counted), plus
	// 2 individual decryptions.
	require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
}