mirror of
https://github.com/hashicorp/vault.git
synced 2026-02-03 20:40:45 -05:00
* add a new struct for the total number of successful requests for transit and transform * implement tracking for encrypt path * implement tracking in encrypt path * add tracking in rewrap * add tracking to datakey path * add tracking to hmac path * add tracking to sign path * add tracking to verify path * unit tests for verify path * add tracking to cmac path * reset the global counter in each unit test * add tracking to hmac verify * add methods to retrieve and flush transit count * modify the methods that store and update data protection call counts * update the methods * add a helper method to combine replicated and local data call counts * add tracking to the endpoint * fix some formatting errors * add unit tests to path encrypt for tracking * add unit tests to decrypt path * fix linter error * add unit tests to test update and store methods for data protection calls * stub fix: do not create separate files * fix the tracking by coordinating replicated and local data, add unit tests * update all reference to the new data struct * revert to previous design with just one global counter for all calls for each cluster * complete external test * no need to check if current count is greater than 0, remove it * feedback: remove unnacassary comments about atomic addition, standardize comments * leave jira id on todo comment, remove unused method * rename mathods by removing HWM and max in names, update jira id in todo comment, update response field key name * feedback: remove explicit counter in cmac tests, instead put in the expected number * feedback: remove explicit tracking in the rest of the tests * feedback: separate transit testing into its own external test * Update vault/consumption_billing_util_test.go * update comment after test name change * fix comments * fix comments in test * another comment fix * feedback: remove incorrect comment * fix a CE test * fix the update method: instead of storing max, increment by the current count value * update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths * update the external test * fix a bug: reset the counter everyime we update the stored counter value to prevent double-counting * update one of the tests * update external test --------- Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com> Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>
This commit is contained in:
parent
8edcbc5a04
commit
81c1c3778b
19 changed files with 612 additions and 17 deletions
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
|
|
@ -18,6 +19,7 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -146,6 +148,11 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error
|
|||
return size, nil
|
||||
}
|
||||
|
||||
// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions
|
||||
func (b *backend) incrementDataProtectionCounter(count int64) {
|
||||
atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count)
|
||||
}
|
||||
|
||||
// Update cache size and get policy
|
||||
func (b *backend) GetPolicy(ctx context.Context, polReq keysutil.PolicyRequest, rand io.Reader) (retP *keysutil.Policy, retUpserted bool, retErr error) {
|
||||
// Acquire read lock to read cacheSizeChanged
|
||||
|
|
|
|||
|
|
@ -169,6 +169,11 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
|
|||
resp.Data["plaintext"] = plaintext
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
// Since there are not batched operations, we can add one successful
|
||||
// request to the transit request counter.
|
||||
b.incrementDataProtectionCounter(1)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -15,6 +16,9 @@ import (
|
|||
// TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme
|
||||
// args for the returned keys
|
||||
func TestDataKeyWithPaddingScheme(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
keyName := "test"
|
||||
createKeyReq := &logical.Request{
|
||||
|
|
@ -90,11 +94,16 @@ func TestDataKeyWithPaddingScheme(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
// We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation))
|
||||
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a
|
||||
// padding_scheme value on an invalid key type (non-RSA)
|
||||
func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
keyName := "test"
|
||||
createKeyReq := &logical.Request{
|
||||
|
|
@ -122,4 +131,6 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
|
|||
require.ErrorContains(t, err, "invalid request")
|
||||
require.NotNil(t, resp, "response should not be nil")
|
||||
require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key")
|
||||
// We expect 0 successful requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,6 +196,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
defer p.Unlock()
|
||||
|
||||
successesInBatch := false
|
||||
successfulRequests := 0
|
||||
for i, item := range batchInputItems {
|
||||
if batchResponseItems[i].Error != "" {
|
||||
continue
|
||||
|
|
@ -252,6 +253,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
}
|
||||
successesInBatch = true
|
||||
batchResponseItems[i].Plaintext = plaintext
|
||||
successfulRequests++
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
|
@ -276,6 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
|||
}
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ import (
|
|||
|
||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_BatchDecryption(t *testing.T) {
|
||||
|
|
@ -21,6 +23,9 @@ func TestTransit_BatchDecryption(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchEncryptionInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string
|
||||
map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline
|
||||
|
|
@ -69,6 +74,9 @@ func TestTransit_BatchDecryption(t *testing.T) {
|
|||
if err != nil || err == nil && string(jsonResponse) != expectedResult {
|
||||
t.Fatalf("bad: expected json response [%s]", jsonResponse)
|
||||
}
|
||||
|
||||
// We expect 6 successful requests (3 for batch encryption, 3 for batch decryption)
|
||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
||||
|
|
@ -78,6 +86,9 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Create a derived key.
|
||||
req = &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
|
|
@ -278,4 +289,7 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
// We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption)
|
||||
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -531,6 +531,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
// collection and continue to process other items.
|
||||
warnAboutNonceUsage := false
|
||||
successesInBatch := false
|
||||
successfulRequests := 0
|
||||
for i, item := range batchInputItems {
|
||||
if batchResponseItems[i].Error != "" {
|
||||
userErrorInBatch = true
|
||||
|
|
@ -613,6 +614,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
|
||||
batchResponseItems[i].Ciphertext = ciphertext
|
||||
batchResponseItems[i].KeyVersion = keyVersion
|
||||
successfulRequests++
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
|
@ -647,6 +649,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
|||
resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand")
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,13 +15,18 @@ import (
|
|||
uuid "github.com/hashicorp/go-uuid"
|
||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_MissingPlaintext(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Create the policy
|
||||
|
|
@ -45,12 +50,17 @@ func TestTransit_MissingPlaintext(t *testing.T) {
|
|||
if resp == nil || !resp.IsError() {
|
||||
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
||||
}
|
||||
// We expect 0 successful calls
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
|
||||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Create the policy
|
||||
|
|
@ -81,6 +91,8 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
||||
}
|
||||
// We expect 0 successful calls
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case1: Ensure that batch encryption did not affect the normal flow of
|
||||
|
|
@ -89,6 +101,9 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Create the policy
|
||||
|
|
@ -143,6 +158,9 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
|
|||
if resp.Data["plaintext"] != plaintext {
|
||||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
|
||||
}
|
||||
|
||||
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case2: Ensure that batch encryption did not affect the normal flow of
|
||||
|
|
@ -152,6 +170,9 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
|
|||
var err error
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Upsert the key and encrypt the data
|
||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
|
|
@ -205,11 +226,16 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
|
|||
if resp.Data["plaintext"] != plaintext {
|
||||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
|
||||
}
|
||||
|
||||
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case3: If batch encryption input is not base64 encoded, it should fail.
|
||||
func TestTransit_BatchEncryptionCase3(t *testing.T) {
|
||||
var err error
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
|
|
@ -228,6 +254,9 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
|
||||
// We expect 0 successful requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case4: Test batch encryption with an existing key (and test references)
|
||||
|
|
@ -235,6 +264,9 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
policyReq := &logical.Request{
|
||||
|
|
@ -297,6 +329,9 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
|
|||
t.Fatalf("reference mismatch. Expected %s, Actual: %s", inputItem["reference"], item.Reference)
|
||||
}
|
||||
}
|
||||
|
||||
// We expect 4 successful requests (2 batch requests + 2 decrypt requests)
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case5: Test batch encryption with an existing derived key
|
||||
|
|
@ -304,6 +339,9 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
policyData := map[string]interface{}{
|
||||
|
|
@ -370,6 +408,8 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
|
|||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
|
||||
}
|
||||
}
|
||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case6: Test batch encryption with an upserted non-derived key
|
||||
|
|
@ -377,6 +417,9 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
batchInput := []interface{}{
|
||||
|
|
@ -430,6 +473,9 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
|
|||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
|
||||
}
|
||||
}
|
||||
|
||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case7: Test batch encryption with an upserted derived key
|
||||
|
|
@ -439,6 +485,9 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||
|
|
@ -486,6 +535,8 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
|
|||
t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"])
|
||||
}
|
||||
}
|
||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case8: If plaintext is not base64 encoded, encryption should fail
|
||||
|
|
@ -495,6 +546,9 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Create the policy
|
||||
policyReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
|
|
@ -539,6 +593,8 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
}
|
||||
// We expect 0 successful transit requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case9: If both plaintext and batch inputs are supplied, plaintext should be
|
||||
|
|
@ -549,6 +605,9 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||
|
|
@ -573,6 +632,9 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
|
|||
if ok {
|
||||
t.Fatal("ciphertext field should not be set")
|
||||
}
|
||||
|
||||
// We expect 2 successful batch encryptions
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case10: Inconsistent presence of 'context' in batch input should be caught
|
||||
|
|
@ -581,6 +643,9 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||
|
|
@ -600,6 +665,8 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
// We expect no successful transit requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case11: Incorrect inputs for context and nonce should not fail the operation
|
||||
|
|
@ -608,6 +675,9 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"},
|
||||
|
|
@ -626,6 +696,8 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We expect 1 successful encryption out of the 2-item batch
|
||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case12: Invalid batch input
|
||||
|
|
@ -633,6 +705,9 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
|
|||
var err error
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
batchInput := []interface{}{
|
||||
map[string]interface{}{},
|
||||
"unexpected_interface",
|
||||
|
|
@ -651,6 +726,8 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
|
|||
if err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
// We expect no successful requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation
|
||||
|
|
@ -659,6 +736,9 @@ func TestTransit_EncryptionCase13(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Non-batch first
|
||||
data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"}
|
||||
req := &logical.Request{
|
||||
|
|
@ -693,6 +773,8 @@ func TestTransit_EncryptionCase13(t *testing.T) {
|
|||
if v, ok := resp.Data["http_status_code"]; !ok || v.(int) != http.StatusBadRequest {
|
||||
t.Fatal("expected request error")
|
||||
}
|
||||
// We expect no successful transit requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Case14: Incorrect input for nonce when we are in convergent version 3 should fail
|
||||
|
|
@ -701,6 +783,9 @@ func TestTransit_EncryptionCase14(t *testing.T) {
|
|||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
cReq := &logical.Request{
|
||||
Operation: logical.UpdateOperation,
|
||||
Path: "keys/my-key",
|
||||
|
|
@ -750,6 +835,8 @@ func TestTransit_EncryptionCase14(t *testing.T) {
|
|||
if v, ok := resp.Data["http_status_code"]; !ok || v.(int) != http.StatusBadRequest {
|
||||
t.Fatal("expected request error")
|
||||
}
|
||||
// We expect no successful transit requests
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.
|
||||
|
|
|
|||
|
|
@ -189,6 +189,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
|
||||
response := make([]batchResponseHMACItem, len(batchInputItems))
|
||||
|
||||
// Count successful HMAC operations
|
||||
successfulRequests := 0
|
||||
|
||||
for i, item := range batchInputItems {
|
||||
rawInput, ok := item["input"]
|
||||
if !ok {
|
||||
|
|
@ -225,6 +228,7 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
retStr := base64.StdEncoding.EncodeToString(retBytes)
|
||||
retStr = fmt.Sprintf("vault:v%s:%s", strconv.Itoa(ver), retStr)
|
||||
response[i].HMAC = retStr
|
||||
successfulRequests++
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
|
|
@ -250,6 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
}
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
@ -320,6 +327,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
|
|||
|
||||
response := make([]batchResponseHMACItem, len(batchInputItems))
|
||||
|
||||
successfulRequests := 0
|
||||
for i, item := range batchInputItems {
|
||||
rawInput, ok := item["input"]
|
||||
if !ok {
|
||||
|
|
@ -398,6 +406,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
|
|||
hf.Write(input)
|
||||
retBytes := hf.Sum(nil)
|
||||
response[i].Valid = hmac.Equal(retBytes, verBytes)
|
||||
successfulRequests++
|
||||
}
|
||||
|
||||
// Generate the response
|
||||
|
|
@ -423,6 +432,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
|
|||
}
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,10 +16,14 @@ import (
|
|||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTransit_HMAC(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, storage := createBackendWithSysView(t)
|
||||
|
||||
cases := []struct {
|
||||
|
|
@ -244,9 +248,14 @@ func TestTransit_HMAC(t *testing.T) {
|
|||
t.Fatalf("expected invalid request error, got %v", err)
|
||||
}
|
||||
}
|
||||
// Verify the total successful transit requests
|
||||
require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
func TestTransit_batchHMAC(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, storage := createBackendWithSysView(t)
|
||||
|
||||
// First create a key
|
||||
|
|
@ -393,12 +402,16 @@ func TestTransit_batchHMAC(t *testing.T) {
|
|||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
// do not increment the counter for failed HMAC verify operations
|
||||
|
||||
batchHMACVerifyResponseItems = resp.Data["batch_results"].([]batchResponseHMACItem)
|
||||
|
||||
if batchHMACVerifyResponseItems[0].Valid {
|
||||
t.Fatalf("expected error validating hmac\nreq\n%#v\nresp\n%#v", *req, *resp)
|
||||
}
|
||||
|
||||
// Verify the total successful transit requests
|
||||
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input
|
||||
|
|
@ -427,6 +440,9 @@ func TestHMACBatchResultsFields(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
keyName := "hmac-test-key"
|
||||
_, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -487,4 +503,7 @@ func TestHMACBatchResultsFields(t *testing.T) {
|
|||
require.Contains(t, result, "reference")
|
||||
require.Contains(t, result, "valid")
|
||||
}
|
||||
|
||||
// We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification)
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
|
|||
defer p.Unlock()
|
||||
|
||||
warnAboutNonceUsage := false
|
||||
successfulRequests := 0
|
||||
for i, item := range batchInputItems {
|
||||
if batchResponseItems[i].Error != "" {
|
||||
continue
|
||||
|
|
@ -281,6 +282,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
|
|||
|
||||
batchResponseItems[i].Ciphertext = ciphertext
|
||||
batchResponseItems[i].KeyVersion = keyVersion
|
||||
successfulRequests++
|
||||
}
|
||||
|
||||
resp := &logical.Response{}
|
||||
|
|
@ -306,6 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
|
|||
resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.")
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Check the normal flow of rewrap
|
||||
|
|
@ -17,6 +19,9 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
|
|||
var err error
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Upsert the key and encrypt the data
|
||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
|
|
@ -112,6 +117,9 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
|
|||
if keyVersion != 2 {
|
||||
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
|
||||
}
|
||||
|
||||
// We expect 2 successful requests (1 for encrypt, 1 for rewrap)
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Check the normal flow of rewrap with upserted key
|
||||
|
|
@ -120,6 +128,9 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
|
|||
var err error
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Upsert the key and encrypt the data
|
||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
|
||||
|
|
@ -217,6 +228,8 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
|
|||
if keyVersion != 2 {
|
||||
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
|
||||
}
|
||||
// We expect 2 successful transit requests (1 for encrypt, 1 for rewrap)
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts
|
||||
|
|
@ -224,6 +237,9 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
batchEncryptionInput := []interface{}{
|
||||
|
|
@ -323,8 +339,9 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
|
|||
if resp.Data["plaintext"] != plaintext1 && resp.Data["plaintext"] != plaintext2 {
|
||||
t.Fatalf("bad: plaintext. Expected: %q or %q, Actual: %q", plaintext1, plaintext2, resp.Data["plaintext"])
|
||||
}
|
||||
|
||||
}
|
||||
// We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption)
|
||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
// TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes
|
||||
|
|
@ -332,6 +349,9 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
|
|||
var resp *logical.Response
|
||||
var err error
|
||||
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, s := createBackendWithStorage(t)
|
||||
|
||||
batchEncryptionInput := []interface{}{
|
||||
|
|
@ -438,4 +458,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
|
|||
t.Fatalf("bad: plaintext. Expected: %q or %q, Actual: %q", plaintext1, plaintext2, resp.Data["plaintext"])
|
||||
}
|
||||
}
|
||||
// We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap)
|
||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -409,6 +409,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
|
||||
batchInputRaw := d.Raw["batch_input"]
|
||||
var batchInputItems []batchRequestSignItem
|
||||
successfulRequests := 0
|
||||
if batchInputRaw != nil {
|
||||
err = mapstructure.Decode(batchInputRaw, &batchInputItems)
|
||||
if err != nil {
|
||||
|
|
@ -458,6 +459,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
response[i].Signature = sig.Signature
|
||||
response[i].PublicKey = sig.PublicKey
|
||||
response[i].KeyVersion = keyVersion
|
||||
successfulRequests++
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -490,6 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
|||
}
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
@ -694,6 +699,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
|||
|
||||
response := make([]batchResponseVerifyItem, len(batchInputItems))
|
||||
|
||||
successfulRequests := 0
|
||||
for i, item := range batchInputItems {
|
||||
pva, err := b.getPolicyVerifyArgs(ctx, p, apiArgs, item)
|
||||
if err != nil {
|
||||
|
|
@ -716,6 +722,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
|||
}
|
||||
} else {
|
||||
response[i].Valid = valid
|
||||
successfulRequests++
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -743,6 +750,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
|||
}
|
||||
}
|
||||
|
||||
// Increment the counter for successful operations
|
||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/hashicorp/vault/helper/constants"
|
||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
|
|
@ -32,6 +33,9 @@ type signOutcome struct {
|
|||
}
|
||||
|
||||
func TestTransit_SignVerify_ECDSA(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
t.Run("256", func(t *testing.T) {
|
||||
testTransit_SignVerify_ECDSA(t, 256)
|
||||
})
|
||||
|
|
@ -41,6 +45,9 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) {
|
|||
t.Run("521", func(t *testing.T) {
|
||||
testTransit_SignVerify_ECDSA(t, 521)
|
||||
})
|
||||
|
||||
// Verify the total successful transit requests
|
||||
require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
|
||||
|
|
@ -373,6 +380,9 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex
|
|||
// TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a
|
||||
// Ed25519ph/ctx signature fail on CE and ENT if invalid
|
||||
func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, storage := createBackendWithSysView(t)
|
||||
|
||||
// First create a key
|
||||
|
|
@ -465,9 +475,19 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
// Verify the total successful transit requests
|
||||
if constants.IsEnterprise {
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
} else {
|
||||
// We expect 0 successful calls on CE because we expect the verify to fail
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransit_SignVerify_ED25519(t *testing.T) {
|
||||
// Reset the transit counter
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
b, storage := createBackendWithSysView(t)
|
||||
|
||||
// First create a key
|
||||
|
|
@ -735,6 +755,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
|
|||
// Repeat with the other key
|
||||
sig = signRequest(req, false, "bar")
|
||||
verifyRequest(req, false, outcome, "bar", sig, true)
|
||||
// Now try the v1
|
||||
verifyRequest(req, true, outcome, "bar", v1sig, true)
|
||||
|
||||
// Test Batch Signing
|
||||
|
|
@ -809,6 +830,9 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
|
|||
outcome[1].requestOk = true
|
||||
outcome[1].valid = false
|
||||
verifyRequest(req, false, outcome, "bar", goodsig, true)
|
||||
|
||||
// Verify the total successful transit requests
|
||||
require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
}
|
||||
|
||||
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -10,13 +10,15 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
BillingSubPath = "billing/"
|
||||
ReplicatedPrefix = "replicated/"
|
||||
RoleHWMCountsHWM = "maxRoleCounts/"
|
||||
KvHWMCountsHWM = "maxKvCounts/"
|
||||
LocalPrefix = "local/"
|
||||
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
|
||||
BillingWriteInterval = 10 * time.Minute
|
||||
BillingSubPath = "billing/"
|
||||
ReplicatedPrefix = "replicated/"
|
||||
RoleHWMCountsHWM = "maxRoleCounts/"
|
||||
KvHWMCountsHWM = "maxKvCounts/"
|
||||
DataProtectionCallCountsMetric = "dataProtectionCallCounts/"
|
||||
LocalPrefix = "local/"
|
||||
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
|
||||
|
||||
BillingWriteInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
var BillingMonthStorageFormat = "%s%d/%02d/%s" // e.g replicated/2026/01/maxKvCounts/
|
||||
|
|
@ -41,3 +43,12 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri
|
|||
month := int(now.Month())
|
||||
return fmt.Sprintf(BillingMonthStorageFormat, localPrefix, year, month, billingMetric)
|
||||
}
|
||||
|
||||
type DataProtectionCallCounts struct {
|
||||
Transit int64 `json:"transit,omitempty"`
|
||||
// TODO: Uncomment when we add support for Transform tracking (VAULT-41205)
|
||||
// Transform int64 `json:"transform,omitempty"`
|
||||
}
|
||||
|
||||
// Global counter for all data protection calls on this cluster
|
||||
var CurrentDataProtectionCallCounts = DataProtectionCallCounts{}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error {
|
|||
c.UpdateReplicatedHWMMetrics(ctx, currentMonth)
|
||||
}
|
||||
c.UpdateLocalHWMMetrics(ctx, currentMonth)
|
||||
c.UpdateLocalAggregatedMetrics(ctx, currentMonth)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -112,5 +113,16 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time
|
|||
} else {
|
||||
c.logger.Info("updated local max external plugin counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error {
|
||||
if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil {
|
||||
c.logger.Error("error updating local max data protection call counts", "error", err)
|
||||
} else {
|
||||
c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ package vault
|
|||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
|
|
@ -66,7 +67,7 @@ func (c *Core) GetStoredThirdPartyPluginCounts(ctx context.Context, month time.T
|
|||
return c.getStoredThirdPartyPluginCountsLocked(ctx, billing.LocalPrefix, month)
|
||||
}
|
||||
|
||||
func combineRoleCounts(ctx context.Context, a, b *RoleCounts) *RoleCounts {
|
||||
func combineRoleCounts(a, b *RoleCounts) *RoleCounts {
|
||||
if a == nil && b == nil {
|
||||
return &RoleCounts{}
|
||||
}
|
||||
|
|
@ -263,3 +264,64 @@ func (c *Core) compareCounts(current, previous int, metricName string) int {
|
|||
func (c *Core) GetBillingSubView() *BarrierView {
|
||||
return c.systemBarrierView.SubView(billing.BillingSubPath)
|
||||
}
|
||||
|
||||
// storeDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
||||
func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error {
|
||||
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
|
||||
entry, err := logical.StorageEntryJSON(billingPath, maxCounts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.GetBillingSubView().Put(ctx, entry)
|
||||
}
|
||||
|
||||
// getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
||||
func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
||||
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
|
||||
entry, err := c.GetBillingSubView().Get(ctx, billingPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return &billing.DataProtectionCallCounts{}, nil
|
||||
}
|
||||
var maxCounts billing.DataProtectionCallCounts
|
||||
if err := entry.DecodeJSON(&maxCounts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &maxCounts, nil
|
||||
}
|
||||
|
||||
func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
||||
c.consumptionBilling.BillingStorageLock.RLock()
|
||||
defer c.consumptionBilling.BillingStorageLock.RUnlock()
|
||||
return c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, month)
|
||||
}
|
||||
|
||||
func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth time.Time) (*billing.DataProtectionCallCounts, error) {
|
||||
c.consumptionBilling.BillingStorageLock.Lock()
|
||||
defer c.consumptionBilling.BillingStorageLock.Unlock()
|
||||
|
||||
// Read and atomically reset the current counter value
|
||||
// TODO: Reset Transform call counts too (VAULT-41205)
|
||||
currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0)
|
||||
|
||||
storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if storedDataProtectionCallCounts == nil {
|
||||
storedDataProtectionCallCounts = &billing.DataProtectionCallCounts{}
|
||||
}
|
||||
|
||||
// Sum the current count with the stored count
|
||||
storedDataProtectionCallCounts.Transit += currentTransitCount
|
||||
// TODO: Update Transform call counts (VAULT-41205)
|
||||
|
||||
err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storedDataProtectionCallCounts, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
package vault
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -38,3 +39,10 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) {
|
|||
}
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
// testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature.
|
||||
// Returns the current count unchanged.
|
||||
func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 {
|
||||
// CMAC is not supported in OSS, so we don't perform any operations
|
||||
return currentCount
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import (
|
|||
logicalDatabase "github.com/hashicorp/vault/builtin/logical/database"
|
||||
logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad"
|
||||
logicalRabbitMQ "github.com/hashicorp/vault/builtin/logical/rabbitmq"
|
||||
"github.com/hashicorp/vault/builtin/logical/transit"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/pluginconsts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
|
|
@ -397,6 +398,258 @@ func TestHWMKvSecretsCounts(t *testing.T) {
|
|||
require.Equal(t, 5, counts)
|
||||
}
|
||||
|
||||
// TestDataProtectionCallCounts tests that we correctly store and track the data protection call counts
|
||||
func TestDataProtectionCallCounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
coreConfig := &CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"transit": transit.Factory,
|
||||
},
|
||||
BillingConfig: billing.BillingConfig{
|
||||
MetricsUpdateCadence: 3 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
core, _, root := TestCoreUnsealedWithConfig(t, coreConfig)
|
||||
|
||||
// Mount transit backend
|
||||
req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/transit")
|
||||
req.Data["type"] = "transit"
|
||||
req.ClientToken = root
|
||||
ctx := namespace.RootContext(context.Background())
|
||||
_, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reset the transit counters
|
||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
||||
|
||||
// Create an encryption key
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo")
|
||||
req.Data["type"] = "aes256-gcm96"
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Perform encryption on the key
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
|
||||
req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.ClientToken = root
|
||||
resp, err := core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented (replicated mount by default)
|
||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Get the ciphertext from the encryption response
|
||||
ciphertext, ok := resp.Data["ciphertext"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, ciphertext)
|
||||
|
||||
// Now perform decryption using the ciphertext
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/decrypt/foo")
|
||||
req.Data["ciphertext"] = ciphertext
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Test rewrap operation
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo")
|
||||
req.Data["ciphertext"] = ciphertext
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Get the new ciphertext from rewrap
|
||||
newCiphertext, ok := resp.Data["ciphertext"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, newCiphertext)
|
||||
|
||||
// Test datakey generation
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/datakey/plaintext/foo")
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Test HMAC generation
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo")
|
||||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Get the HMAC value
|
||||
hmacValue, ok := resp.Data["hmac"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, hmacValue)
|
||||
|
||||
// Test HMAC verification
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/verify/foo")
|
||||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.Data["hmac"] = hmacValue
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Verify the HMAC is valid
|
||||
hmacValid, ok := resp.Data["valid"].(bool)
|
||||
require.True(t, ok)
|
||||
require.True(t, hmacValid)
|
||||
|
||||
// Create a signing key for sign/verify operations
|
||||
req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/signing-key")
|
||||
req.Data["type"] = "ecdsa-p256"
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test sign operation
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/sign/signing-key")
|
||||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Get the signature
|
||||
signature, ok := resp.Data["signature"].(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, signature)
|
||||
|
||||
// Test verify operation
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/verify/signing-key")
|
||||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.Data["signature"] = signature
|
||||
req.ClientToken = root
|
||||
resp, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.Data)
|
||||
|
||||
// Verify that the transit counter is incremented
|
||||
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Verify the signature is valid
|
||||
signatureValid, ok := resp.Data["valid"].(bool)
|
||||
require.True(t, ok)
|
||||
require.True(t, signatureValid)
|
||||
|
||||
// Test CMAC operations (ENT only - will be no-op in OSS)
|
||||
currentCount := billing.CurrentDataProtectionCallCounts.Transit
|
||||
currentCount = testCMACOperations(t, core, ctx, root, currentCount)
|
||||
|
||||
// Verify that the transit counter matches expected count
|
||||
require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Now test persisting the summed counts - store and retrieve counts
|
||||
// First, update the data protection call counts (this will sum current counter with stored value)
|
||||
summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, summedCounts)
|
||||
require.Equal(t, currentCount, summedCounts.Transit)
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
||||
|
||||
// Retrieve the stored counts
|
||||
storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedCounts)
|
||||
require.Equal(t, currentCount, storedCounts.Transit)
|
||||
|
||||
// Perform more operations to increase the counter
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
|
||||
req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Counter should now be 1 (reset + 1 operation)
|
||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Update counts again - should sum the new count (1) with the stored count (currentCount)
|
||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
expectedSum := currentCount + 1
|
||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current")
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts are now the sum
|
||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSum, storedCounts.Transit)
|
||||
|
||||
// Add more operations without manually resetting
|
||||
for i := 0; i < 3; i++ {
|
||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
|
||||
req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||
req.ClientToken = root
|
||||
_, err = core.HandleRequest(ctx, req)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Counter should be 3
|
||||
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
|
||||
|
||||
// Update counts - should sum 3 with the previous stored sum
|
||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
expectedSum = expectedSum + 3
|
||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum")
|
||||
|
||||
// Verify the counter was reset after update
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
||||
|
||||
// Verify stored counts
|
||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSum, storedCounts.Transit)
|
||||
|
||||
// Update again without any new operations
|
||||
// This verifies we don't double-count
|
||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred")
|
||||
|
||||
// Verify stored counts haven't changed
|
||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same")
|
||||
|
||||
// Verify counter is still at 0
|
||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0")
|
||||
}
|
||||
|
||||
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {
|
||||
raw, ok := core.router.root.Get(mount + "/")
|
||||
if !ok {
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path {
|
|||
Type: framework.TypeMap,
|
||||
Description: "High watermark (for this month) role counts for this cluster.",
|
||||
},
|
||||
"data_protection_call_counts": {
|
||||
Type: framework.TypeMap,
|
||||
Description: "Count of data protection calls on this cluster.",
|
||||
},
|
||||
},
|
||||
}},
|
||||
http.StatusNoContent: {{
|
||||
|
|
@ -81,10 +85,19 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic
|
|||
return nil, fmt.Errorf("error retrieving local max kv counts: %w", err)
|
||||
}
|
||||
|
||||
// Data protection call counts are stored to local path only
|
||||
// Each cluster tracks its own total requests to avoid double counting
|
||||
localDataProtectionCallCounts, err := b.Core.UpdateDataProtectionCallCounts(ctx, currentMonth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving local max data protection call counts: %w", err)
|
||||
}
|
||||
|
||||
// If we are the primary, then combine the replicated and local max role counts. Else just output the local
|
||||
// max role counts. replicatedMaxRoleCounts will be empty if we are not a primary, so this is taken care of for us.
|
||||
combinedMaxRoleCounts := combineRoleCounts(ctx, replicatedMaxRoleCounts, localMaxRoleCounts)
|
||||
combinedMaxRoleCounts := combineRoleCounts(replicatedMaxRoleCounts, localMaxRoleCounts)
|
||||
combinedMaxKvCounts := replicatedKvHWMCounts + localKvHWMCounts
|
||||
// Data protection counts are not combined - each cluster reports its own total
|
||||
combinedMaxDataProtectionCallCounts := localDataProtectionCallCounts
|
||||
|
||||
var replicatedPreviousMonthRoleCounts *RoleCounts
|
||||
replicatedPreviousMonthKvHWMCounts := 0
|
||||
|
|
@ -107,19 +120,29 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic
|
|||
return nil, fmt.Errorf("error retrieving local max kv counts for previous month: %w", err)
|
||||
}
|
||||
|
||||
combinedPreviousMonthRoleCounts := combineRoleCounts(ctx, replicatedPreviousMonthRoleCounts, localPreviousMonthRoleCounts)
|
||||
// Data protection counts for previous month
|
||||
localPreviousMonthDataProtectionCallCounts, err := b.Core.GetStoredDataProtectionCallCounts(ctx, previousMonth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving local max data protection call counts for previous month: %w", err)
|
||||
}
|
||||
|
||||
combinedPreviousMonthRoleCounts := combineRoleCounts(replicatedPreviousMonthRoleCounts, localPreviousMonthRoleCounts)
|
||||
combinedPreviousMonthKvHWMCounts := replicatedPreviousMonthKvHWMCounts + localPreviousMonthKvHWMCounts
|
||||
// Data protection counts are not combined - each cluster reports its own total
|
||||
combinedPreviousMonthDataProtectionCallCounts := localPreviousMonthDataProtectionCallCounts
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"current_month": map[string]interface{}{
|
||||
"timestamp": timeutil.StartOfMonth(currentMonth),
|
||||
"maximum_role_counts": combinedMaxRoleCounts,
|
||||
"maximum_kv_counts": combinedMaxKvCounts,
|
||||
"timestamp": timeutil.StartOfMonth(currentMonth),
|
||||
"maximum_role_counts": combinedMaxRoleCounts,
|
||||
"maximum_kv_counts": combinedMaxKvCounts,
|
||||
"data_protection_call_counts": combinedMaxDataProtectionCallCounts,
|
||||
},
|
||||
"previous_month": map[string]interface{}{
|
||||
"timestamp": previousMonth,
|
||||
"maximum_role_counts": combinedPreviousMonthRoleCounts,
|
||||
"maximum_kv_counts": combinedPreviousMonthKvHWMCounts,
|
||||
"timestamp": previousMonth,
|
||||
"maximum_role_counts": combinedPreviousMonthRoleCounts,
|
||||
"maximum_kv_counts": combinedPreviousMonthKvHWMCounts,
|
||||
"data_protection_call_counts": combinedPreviousMonthDataProtectionCallCounts,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue