From 81c1c3778bcc07f5b819e32d7bbb951a2b27ebbf Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Fri, 30 Jan 2026 15:16:05 -0500 Subject: [PATCH] VAULT-41092: transit engine metrics (#11814) (#12103) * add a new struct for the total number of successful requests for transit and transform * implement tracking for encrypt path * implement tracking in encrypt path * add tracking in rewrap * add tracking to datakey path * add tracking to hmac path * add tracking to sign path * add tracking to verify path * unit tests for verify path * add tracking to cmac path * reset the global counter in each unit test * add tracking to hmac verify * add methods to retrieve and flush transit count * modify the methods that store and update data protection call counts * update the methods * add a helper method to combine replicated and local data call counts * add tracking to the endpoint * fix some formatting errors * add unit tests to path encrypt for tracking * add unit tests to decrypt path * fix linter error * add unit tests to test update and store methods for data protection calls * stub fix: do not create separate files * fix the tracking by coordinating replicated and local data, add unit tests * update all reference to the new data struct * revert to previous design with just one global counter for all calls for each cluster * complete external test * no need to check if current count is greater than 0, remove it * feedback: remove unnacassary comments about atomic addition, standardize comments * leave jira id on todo comment, remove unused method * rename mathods by removing HWM and max in names, update jira id in todo comment, update response field key name * feedback: remove explicit counter in cmac tests, instead put in the expected number * feedback: remove explicit tracking in the rest of the tests * feedback: separate transit testing into its own external test * Update vault/consumption_billing_util_test.go * update comment after test name change * fix comments * fix comments in test * another comment fix * feedback: remove incorrect comment * fix a CE test * fix the update method: instead of storing max, increment by the current count value * update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths * update the external test * fix a bug: reset the counter everyime we update the stored counter value to prevent double-counting * update one of the tests * update external test --------- Co-authored-by: Amir Aslamov Co-authored-by: divyaac --- builtin/logical/transit/backend.go | 7 + builtin/logical/transit/path_datakey.go | 5 + builtin/logical/transit/path_datakey_test.go | 11 + builtin/logical/transit/path_decrypt.go | 5 + builtin/logical/transit/path_decrypt_test.go | 14 + builtin/logical/transit/path_encrypt.go | 5 + builtin/logical/transit/path_encrypt_test.go | 87 ++++++ builtin/logical/transit/path_hmac.go | 12 + builtin/logical/transit/path_hmac_test.go | 19 ++ builtin/logical/transit/path_rewrap.go | 5 + builtin/logical/transit/path_rewrap_test.go | 24 +- builtin/logical/transit/path_sign_verify.go | 10 + .../logical/transit/path_sign_verify_test.go | 24 ++ vault/billing/billing_counts.go | 25 +- vault/consumption_billing.go | 12 + vault/consumption_billing_util.go | 64 ++++- vault/consumption_billing_util_oss_test.go | 8 + vault/consumption_billing_util_test.go | 253 ++++++++++++++++++ vault/logical_system_use_case_billing.go | 39 ++- 19 files changed, 612 insertions(+), 17 deletions(-) diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index ff71198a1c..971ac7e2ca 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" ) const ( @@ -146,6 +148,11 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error return size, nil } +// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions +func (b *backend) incrementDataProtectionCounter(count int64) { + atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count) +} + // Update cache size and get policy func (b *backend) GetPolicy(ctx context.Context, polReq keysutil.PolicyRequest, rand io.Reader) (retP *keysutil.Policy, retUpserted bool, retErr error) { // Acquire read lock to read cacheSizeChanged diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index 72e6849edb..f8b851c11e 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -169,6 +169,11 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d resp.Data["plaintext"] = plaintext } + // Increment the counter for successful operations + // Since there are not batched operations, we can add one successful + // request to the transit request counter. + b.incrementDataProtectionCounter(1) + return resp, nil } diff --git a/builtin/logical/transit/path_datakey_test.go b/builtin/logical/transit/path_datakey_test.go index 085e530296..c8150afc69 100644 --- a/builtin/logical/transit/path_datakey_test.go +++ b/builtin/logical/transit/path_datakey_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -15,6 +16,9 @@ import ( // TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme // args for the returned keys func TestDataKeyWithPaddingScheme(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -90,11 +94,16 @@ func TestDataKeyWithPaddingScheme(t *testing.T) { } }) } + // We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation)) + require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) } // TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a // padding_scheme value on an invalid key type (non-RSA) func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -122,4 +131,6 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { require.ErrorContains(t, err, "invalid request") require.NotNil(t, resp, "response should not be nil") require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key") + // We expect 0 successful requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 5b86289d05..f25275d352 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -196,6 +196,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d defer p.Unlock() successesInBatch := false + successfulRequests := 0 for i, item := range batchInputItems { if batchResponseItems[i].Error != "" { continue @@ -252,6 +253,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } successesInBatch = true batchResponseItems[i].Plaintext = plaintext + successfulRequests++ } resp := &logical.Response{} @@ -276,6 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_decrypt_test.go b/builtin/logical/transit/path_decrypt_test.go index 244fbec418..b2f8046f16 100644 --- a/builtin/logical/transit/path_decrypt_test.go +++ b/builtin/logical/transit/path_decrypt_test.go @@ -12,7 +12,9 @@ import ( "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/require" ) func TestTransit_BatchDecryption(t *testing.T) { @@ -21,6 +23,9 @@ func TestTransit_BatchDecryption(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchEncryptionInput := []interface{}{ map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline @@ -69,6 +74,9 @@ func TestTransit_BatchDecryption(t *testing.T) { if err != nil || err == nil && string(jsonResponse) != expectedResult { t.Fatalf("bad: expected json response [%s]", jsonResponse) } + + // We expect 6 successful requests (3 for batch encryption, 3 for batch decryption) + require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) } func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { @@ -78,6 +86,9 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Create a derived key. req = &logical.Request{ Operation: logical.UpdateOperation, @@ -278,4 +289,7 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { } }) } + + // We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption) + require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index aef4a2fafd..4ebaf2e2b1 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -531,6 +531,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d // collection and continue to process other items. warnAboutNonceUsage := false successesInBatch := false + successfulRequests := 0 for i, item := range batchInputItems { if batchResponseItems[i].Error != "" { userErrorInBatch = true @@ -613,6 +614,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d batchResponseItems[i].Ciphertext = ciphertext batchResponseItems[i].KeyVersion = keyVersion + successfulRequests++ } resp := &logical.Response{} @@ -647,6 +649,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d resp.AddWarning("Attempted creation of the key during the encrypt operation, but it was created beforehand") } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index be1c52ab23..05ab5e2fd3 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -15,13 +15,18 @@ import ( uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/require" ) func TestTransit_MissingPlaintext(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) // Create the policy @@ -45,12 +50,17 @@ func TestTransit_MissingPlaintext(t *testing.T) { if resp == nil || !resp.IsError() { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } + // We expect 0 successful calls + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) // Create the policy @@ -81,6 +91,8 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { if err == nil { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } + // We expect 0 successful calls + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case1: Ensure that batch encryption did not affect the normal flow of @@ -89,6 +101,9 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) // Create the policy @@ -143,6 +158,9 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { if resp.Data["plaintext"] != plaintext { t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"]) } + + // We expect 2 successful requests (1 for encrypt, 1 for decrypt) + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) } // Case2: Ensure that batch encryption did not affect the normal flow of @@ -152,6 +170,9 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -205,11 +226,16 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { if resp.Data["plaintext"] != plaintext { t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"]) } + + // We expect 2 successful requests (1 for encrypt, 1 for decrypt) + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) } // Case3: If batch encryption input is not base64 encoded, it should fail. func TestTransit_BatchEncryptionCase3(t *testing.T) { var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 b, s := createBackendWithStorage(t) @@ -228,6 +254,9 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) { if err == nil { t.Fatal("expected an error") } + + // We expect 0 successful requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case4: Test batch encryption with an existing key (and test references) @@ -235,6 +264,9 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) policyReq := &logical.Request{ @@ -297,6 +329,9 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { t.Fatalf("reference mismatch. Expected %s, Actual: %s", inputItem["reference"], item.Reference) } } + + // We expect 4 successful requests (2 batch requests + 2 decrypt requests) + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) } // Case5: Test batch encryption with an existing derived key @@ -304,6 +339,9 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) policyData := map[string]interface{}{ @@ -370,6 +408,8 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"]) } } + // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) } // Case6: Test batch encryption with an upserted non-derived key @@ -377,6 +417,9 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) batchInput := []interface{}{ @@ -430,6 +473,9 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"]) } } + + // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) } // Case7: Test batch encryption with an upserted derived key @@ -439,6 +485,9 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -486,6 +535,8 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { t.Fatalf("bad: plaintext. Expected: %q, Actual: %q", plaintext, resp.Data["plaintext"]) } } + // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) } // Case8: If plaintext is not base64 encoded, encryption should fail @@ -495,6 +546,9 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Create the policy policyReq := &logical.Request{ Operation: logical.UpdateOperation, @@ -539,6 +593,8 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { if err == nil { t.Fatal("expected an error") } + // We expect 0 successful transit requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case9: If both plaintext and batch inputs are supplied, plaintext should be @@ -549,6 +605,9 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, @@ -573,6 +632,9 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { if ok { t.Fatal("ciphertext field should not be set") } + + // We expect 2 successful batch encryptions + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) } // Case10: Inconsistent presence of 'context' in batch input should be caught @@ -581,6 +643,9 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -600,6 +665,8 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { if err == nil { t.Fatalf("expected an error") } + // We expect no successful transit requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case11: Incorrect inputs for context and nonce should not fail the operation @@ -608,6 +675,9 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"}, @@ -626,6 +696,8 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { if err != nil { t.Fatal(err) } + // We expect 1 successful encryption out of the 2-item batch + require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) } // Case12: Invalid batch input @@ -633,6 +705,9 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { var err error b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + batchInput := []interface{}{ map[string]interface{}{}, "unexpected_interface", @@ -651,6 +726,8 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { if err == nil { t.Fatalf("expected an error") } + // We expect no successful requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation @@ -659,6 +736,9 @@ func TestTransit_EncryptionCase13(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Non-batch first data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"} req := &logical.Request{ @@ -693,6 +773,8 @@ func TestTransit_EncryptionCase13(t *testing.T) { if v, ok := resp.Data["http_status_code"]; !ok || v.(int) != http.StatusBadRequest { t.Fatal("expected request error") } + // We expect no successful transit requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Case14: Incorrect input for nonce when we are in convergent version 3 should fail @@ -701,6 +783,9 @@ func TestTransit_EncryptionCase14(t *testing.T) { b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + cReq := &logical.Request{ Operation: logical.UpdateOperation, Path: "keys/my-key", @@ -750,6 +835,8 @@ func TestTransit_EncryptionCase14(t *testing.T) { if v, ok := resp.Data["http_status_code"]; !ok || v.(int) != http.StatusBadRequest { t.Fatal("expected request error") } + // We expect no successful transit requests + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) } // Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem. diff --git a/builtin/logical/transit/path_hmac.go b/builtin/logical/transit/path_hmac.go index 10ae014e0e..316fa608f1 100644 --- a/builtin/logical/transit/path_hmac.go +++ b/builtin/logical/transit/path_hmac.go @@ -189,6 +189,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr response := make([]batchResponseHMACItem, len(batchInputItems)) + // Count successful HMAC operations + successfulRequests := 0 + for i, item := range batchInputItems { rawInput, ok := item["input"] if !ok { @@ -225,6 +228,7 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr retStr := base64.StdEncoding.EncodeToString(retBytes) retStr = fmt.Sprintf("vault:v%s:%s", strconv.Itoa(ver), retStr) response[i].HMAC = retStr + successfulRequests++ } // Generate the response @@ -250,6 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr } } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return resp, nil } @@ -320,6 +327,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f response := make([]batchResponseHMACItem, len(batchInputItems)) + successfulRequests := 0 for i, item := range batchInputItems { rawInput, ok := item["input"] if !ok { @@ -398,6 +406,7 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f hf.Write(input) retBytes := hf.Sum(nil) response[i].Valid = hmac.Equal(retBytes, verBytes) + successfulRequests++ } // Generate the response @@ -423,6 +432,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f } } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return resp, nil } diff --git a/builtin/logical/transit/path_hmac_test.go b/builtin/logical/transit/path_hmac_test.go index 6c456c2401..13c1e5e556 100644 --- a/builtin/logical/transit/path_hmac_test.go +++ b/builtin/logical/transit/path_hmac_test.go @@ -16,10 +16,14 @@ import ( "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" ) func TestTransit_HMAC(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, storage := createBackendWithSysView(t) cases := []struct { @@ -244,9 +248,14 @@ func TestTransit_HMAC(t *testing.T) { t.Fatalf("expected invalid request error, got %v", err) } } + // Verify the total successful transit requests + require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit) } func TestTransit_batchHMAC(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, storage := createBackendWithSysView(t) // First create a key @@ -393,12 +402,16 @@ func TestTransit_batchHMAC(t *testing.T) { if resp == nil { t.Fatal("expected non-nil response") } + // do not increment the counter for failed HMAC verify operations batchHMACVerifyResponseItems = resp.Data["batch_results"].([]batchResponseHMACItem) if batchHMACVerifyResponseItems[0].Valid { t.Fatalf("expected error validating hmac\nreq\n%#v\nresp\n%#v", *req, *resp) } + + // Verify the total successful transit requests + require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) } // TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input @@ -427,6 +440,9 @@ func TestHMACBatchResultsFields(t *testing.T) { }) require.NoError(t, err) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + keyName := "hmac-test-key" _, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32}) require.NoError(t, err) @@ -487,4 +503,7 @@ func TestHMACBatchResultsFields(t *testing.T) { require.Contains(t, result, "reference") require.Contains(t, result, "valid") } + + // We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification) + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index df326ad5d2..1451aa768c 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -202,6 +202,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * defer p.Unlock() warnAboutNonceUsage := false + successfulRequests := 0 for i, item := range batchInputItems { if batchResponseItems[i].Error != "" { continue @@ -281,6 +282,7 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * batchResponseItems[i].Ciphertext = ciphertext batchResponseItems[i].KeyVersion = keyVersion + successfulRequests++ } resp := &logical.Response{} @@ -306,6 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.") } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return resp, nil } diff --git a/builtin/logical/transit/path_rewrap_test.go b/builtin/logical/transit/path_rewrap_test.go index 416f42578b..a3734da6da 100644 --- a/builtin/logical/transit/path_rewrap_test.go +++ b/builtin/logical/transit/path_rewrap_test.go @@ -9,6 +9,8 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" + "github.com/stretchr/testify/require" ) // Check the normal flow of rewrap @@ -17,6 +19,9 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { var err error b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -112,6 +117,9 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { if keyVersion != 2 { t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2) } + + // We expect 2 successful requests (1 for encrypt, 1 for rewrap) + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) } // Check the normal flow of rewrap with upserted key @@ -120,6 +128,9 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -217,6 +228,8 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { if keyVersion != 2 { t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2) } + // We expect 2 successful transit requests (1 for encrypt, 1 for rewrap) + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) } // Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts @@ -224,6 +237,9 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -323,8 +339,9 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { if resp.Data["plaintext"] != plaintext1 && resp.Data["plaintext"] != plaintext2 { t.Fatalf("bad: plaintext. Expected: %q or %q, Actual: %q", plaintext1, plaintext2, resp.Data["plaintext"]) } - } + // We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption) + require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) } // TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes @@ -332,6 +349,9 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { var resp *logical.Response var err error + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -438,4 +458,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { t.Fatalf("bad: plaintext. Expected: %q or %q, Actual: %q", plaintext1, plaintext2, resp.Data["plaintext"]) } } + // We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap) + require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) } diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index d76c9831f3..114fc98209 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -409,6 +409,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr batchInputRaw := d.Raw["batch_input"] var batchInputItems []batchRequestSignItem + successfulRequests := 0 if batchInputRaw != nil { err = mapstructure.Decode(batchInputRaw, &batchInputItems) if err != nil { @@ -458,6 +459,7 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr response[i].Signature = sig.Signature response[i].PublicKey = sig.PublicKey response[i].KeyVersion = keyVersion + successfulRequests++ } } @@ -490,6 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr } } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return resp, nil } @@ -694,6 +699,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * response := make([]batchResponseVerifyItem, len(batchInputItems)) + successfulRequests := 0 for i, item := range batchInputItems { pva, err := b.getPolicyVerifyArgs(ctx, p, apiArgs, item) if err != nil { @@ -716,6 +722,7 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } else { response[i].Valid = valid + successfulRequests++ } } @@ -743,6 +750,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } + // Increment the counter for successful operations + b.incrementDataProtectionCounter(int64(successfulRequests)) + return resp, nil } diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 8949734013..482f6627dc 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" @@ -32,6 +33,9 @@ type signOutcome struct { } func TestTransit_SignVerify_ECDSA(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + t.Run("256", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 256) }) @@ -41,6 +45,9 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) { t.Run("521", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 521) }) + + // Verify the total successful transit requests + require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit) } func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { @@ -373,6 +380,9 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex // TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a // Ed25519ph/ctx signature fail on CE and ENT if invalid func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, storage := createBackendWithSysView(t) // First create a key @@ -465,9 +475,19 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { } }) } + // Verify the total successful transit requests + if constants.IsEnterprise { + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + } else { + // We expect 0 successful calls on CE because we expect the verify to fail + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + } } func TestTransit_SignVerify_ED25519(t *testing.T) { + // Reset the transit counter + billing.CurrentDataProtectionCallCounts.Transit = 0 + b, storage := createBackendWithSysView(t) // First create a key @@ -735,6 +755,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { // Repeat with the other key sig = signRequest(req, false, "bar") verifyRequest(req, false, outcome, "bar", sig, true) + // Now try the v1 verifyRequest(req, true, outcome, "bar", v1sig, true) // Test Batch Signing @@ -809,6 +830,9 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { outcome[1].requestOk = true outcome[1].valid = false verifyRequest(req, false, outcome, "bar", goodsig, true) + + // Verify the total successful transit requests + require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit) } func TestTransit_SignVerify_RSA_PSS(t *testing.T) { diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index c81702076a..d92bfe4d2e 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -10,13 +10,15 @@ import ( ) const ( - BillingSubPath = "billing/" - ReplicatedPrefix = "replicated/" - RoleHWMCountsHWM = "maxRoleCounts/" - KvHWMCountsHWM = "maxKvCounts/" - LocalPrefix = "local/" - ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" - BillingWriteInterval = 10 * time.Minute + BillingSubPath = "billing/" + ReplicatedPrefix = "replicated/" + RoleHWMCountsHWM = "maxRoleCounts/" + KvHWMCountsHWM = "maxKvCounts/" + DataProtectionCallCountsMetric = "dataProtectionCallCounts/" + LocalPrefix = "local/" + ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" + + BillingWriteInterval = 10 * time.Minute ) var BillingMonthStorageFormat = "%s%d/%02d/%s" // e.g replicated/2026/01/maxKvCounts/ @@ -41,3 +43,12 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri month := int(now.Month()) return fmt.Sprintf(BillingMonthStorageFormat, localPrefix, year, month, billingMetric) } + +type DataProtectionCallCounts struct { + Transit int64 `json:"transit,omitempty"` + // TODO: Uncomment when we add support for Transform tracking (VAULT-41205) + // Transform int64 `json:"transform,omitempty"` +} + +// Global counter for all data protection calls on this cluster +var CurrentDataProtectionCallCounts = DataProtectionCallCounts{} diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index bccf164beb..4c7bdd6c53 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -72,6 +72,7 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error { c.UpdateReplicatedHWMMetrics(ctx, currentMonth) } c.UpdateLocalHWMMetrics(ctx, currentMonth) + c.UpdateLocalAggregatedMetrics(ctx, currentMonth) } return nil } @@ -112,5 +113,16 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time } else { c.logger.Info("updated local max external plugin counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) } + + return nil +} + +func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error { + if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil { + c.logger.Error("error updating local max data protection call counts", "error", err) + } else { + c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) + } + return nil } diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index 41398e7d85..e6656a3145 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -6,6 +6,7 @@ package vault import ( "context" "strconv" + "sync/atomic" "time" "github.com/hashicorp/vault/sdk/logical" @@ -66,7 +67,7 @@ func (c *Core) GetStoredThirdPartyPluginCounts(ctx context.Context, month time.T return c.getStoredThirdPartyPluginCountsLocked(ctx, billing.LocalPrefix, month) } -func combineRoleCounts(ctx context.Context, a, b *RoleCounts) *RoleCounts { +func combineRoleCounts(a, b *RoleCounts) *RoleCounts { if a == nil && b == nil { return &RoleCounts{} } @@ -263,3 +264,64 @@ func (c *Core) compareCounts(current, previous int, metricName string) int { func (c *Core) GetBillingSubView() *BarrierView { return c.systemBarrierView.SubView(billing.BillingSubPath) } + +// storeDataProtectionCallCountsLocked must be called with BillingStorageLock held +func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error { + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) + entry, err := logical.StorageEntryJSON(billingPath, maxCounts) + if err != nil { + return err + } + return c.GetBillingSubView().Put(ctx, entry) +} + +// getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held +func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) { + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) + entry, err := c.GetBillingSubView().Get(ctx, billingPath) + if err != nil { + return nil, err + } + if entry == nil { + return &billing.DataProtectionCallCounts{}, nil + } + var maxCounts billing.DataProtectionCallCounts + if err := entry.DecodeJSON(&maxCounts); err != nil { + return nil, err + } + return &maxCounts, nil +} + +func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) { + c.consumptionBilling.BillingStorageLock.RLock() + defer c.consumptionBilling.BillingStorageLock.RUnlock() + return c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, month) +} + +func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth time.Time) (*billing.DataProtectionCallCounts, error) { + c.consumptionBilling.BillingStorageLock.Lock() + defer c.consumptionBilling.BillingStorageLock.Unlock() + + // Read and atomically reset the current counter value + // TODO: Reset Transform call counts too (VAULT-41205) + currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0) + + storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth) + if err != nil { + return nil, err + } + if storedDataProtectionCallCounts == nil { + storedDataProtectionCallCounts = &billing.DataProtectionCallCounts{} + } + + // Sum the current count with the stored count + storedDataProtectionCallCounts.Transit += currentTransitCount + // TODO: Update Transform call counts (VAULT-41205) + + err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth) + if err != nil { + return nil, err + } + + return storedDataProtectionCallCounts, nil +} diff --git a/vault/consumption_billing_util_oss_test.go b/vault/consumption_billing_util_oss_test.go index 39c8c4ee9b..effde53104 100644 --- a/vault/consumption_billing_util_oss_test.go +++ b/vault/consumption_billing_util_oss_test.go @@ -6,6 +6,7 @@ package vault import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -38,3 +39,10 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) { } require.Equal(t, expected, actual) } + +// testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature. +// Returns the current count unchanged. +func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 { + // CMAC is not supported in OSS, so we don't perform any operations + return currentCount +} diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index 59678fe284..a438c1b0c2 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -23,6 +23,7 @@ import ( logicalDatabase "github.com/hashicorp/vault/builtin/logical/database" logicalNomad "github.com/hashicorp/vault/builtin/logical/nomad" logicalRabbitMQ "github.com/hashicorp/vault/builtin/logical/rabbitmq" + "github.com/hashicorp/vault/builtin/logical/transit" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pluginconsts" "github.com/hashicorp/vault/sdk/logical" @@ -397,6 +398,258 @@ func TestHWMKvSecretsCounts(t *testing.T) { require.Equal(t, 5, counts) } +// TestDataProtectionCallCounts tests that we correctly store and track the data protection call counts +func TestDataProtectionCallCounts(t *testing.T) { + t.Parallel() + coreConfig := &CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": transit.Factory, + }, + BillingConfig: billing.BillingConfig{ + MetricsUpdateCadence: 3 * time.Second, + }, + } + + core, _, root := TestCoreUnsealedWithConfig(t, coreConfig) + + // Mount transit backend + req := logical.TestRequest(t, logical.CreateOperation, "sys/mounts/transit") + req.Data["type"] = "transit" + req.ClientToken = root + ctx := namespace.RootContext(context.Background()) + _, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Reset the transit counters + billing.CurrentDataProtectionCallCounts.Transit = 0 + + // Create an encryption key + req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo") + req.Data["type"] = "aes256-gcm96" + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Perform encryption on the key + req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo") + req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.ClientToken = root + resp, err := core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented (replicated mount by default) + require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + + // Get the ciphertext from the encryption response + ciphertext, ok := resp.Data["ciphertext"].(string) + require.True(t, ok) + require.NotEmpty(t, ciphertext) + + // Now perform decryption using the ciphertext + req = logical.TestRequest(t, logical.UpdateOperation, "transit/decrypt/foo") + req.Data["ciphertext"] = ciphertext + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Verify that the transit counter is incremented + require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + + // Test rewrap operation + req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo") + req.Data["ciphertext"] = ciphertext + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + + // Get the new ciphertext from rewrap + newCiphertext, ok := resp.Data["ciphertext"].(string) + require.True(t, ok) + require.NotEmpty(t, newCiphertext) + + // Test datakey generation + req = logical.TestRequest(t, logical.UpdateOperation, "transit/datakey/plaintext/foo") + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + + // Test HMAC generation + req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo") + req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) + + // Get the HMAC value + hmacValue, ok := resp.Data["hmac"].(string) + require.True(t, ok) + require.NotEmpty(t, hmacValue) + + // Test HMAC verification + req = logical.TestRequest(t, logical.UpdateOperation, "transit/verify/foo") + req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.Data["hmac"] = hmacValue + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + + // Verify the HMAC is valid + hmacValid, ok := resp.Data["valid"].(bool) + require.True(t, ok) + require.True(t, hmacValid) + + // Create a signing key for sign/verify operations + req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/signing-key") + req.Data["type"] = "ecdsa-p256" + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Test sign operation + req = logical.TestRequest(t, logical.UpdateOperation, "transit/sign/signing-key") + req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) + + // Get the signature + signature, ok := resp.Data["signature"].(string) + require.True(t, ok) + require.NotEmpty(t, signature) + + // Test verify operation + req = logical.TestRequest(t, logical.UpdateOperation, "transit/verify/signing-key") + req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.Data["signature"] = signature + req.ClientToken = root + resp, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.Data) + + // Verify that the transit counter is incremented + require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) + + // Verify the signature is valid + signatureValid, ok := resp.Data["valid"].(bool) + require.True(t, ok) + require.True(t, signatureValid) + + // Test CMAC operations (ENT only - will be no-op in OSS) + currentCount := billing.CurrentDataProtectionCallCounts.Transit + currentCount = testCMACOperations(t, core, ctx, root, currentCount) + + // Verify that the transit counter matches expected count + require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit) + + // Now test persisting the summed counts - store and retrieve counts + // First, update the data protection call counts (this will sum current counter with stored value) + summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.NotNil(t, summedCounts) + require.Equal(t, currentCount, summedCounts.Transit) + + // Verify the counter was reset after update + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + + // Retrieve the stored counts + storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.NotNil(t, storedCounts) + require.Equal(t, currentCount, storedCounts.Transit) + + // Perform more operations to increase the counter + req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo") + req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + + // Counter should now be 1 (reset + 1 operation) + require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + + // Update counts again - should sum the new count (1) with the stored count (currentCount) + summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + expectedSum := currentCount + 1 + require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current") + + // Verify the counter was reset after update + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + + // Verify stored counts are now the sum + storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, expectedSum, storedCounts.Transit) + + // Add more operations without manually resetting + for i := 0; i < 3; i++ { + req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo") + req.Data["plaintext"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" + req.ClientToken = root + _, err = core.HandleRequest(ctx, req) + require.NoError(t, err) + } + + // Counter should be 3 + require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + + // Update counts - should sum 3 with the previous stored sum + summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + expectedSum = expectedSum + 3 + require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum") + + // Verify the counter was reset after update + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + + // Verify stored counts + storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, expectedSum, storedCounts.Transit) + + // Update again without any new operations + // This verifies we don't double-count + summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred") + + // Verify stored counts haven't changed + storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) + require.NoError(t, err) + require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same") + + // Verify counter is still at 0 + require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0") +} + func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) { raw, ok := core.router.root.Get(mount + "/") if !ok { diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index df43446e8e..5fc4dc1eed 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -31,6 +31,10 @@ func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path { Type: framework.TypeMap, Description: "High watermark (for this month) role counts for this cluster.", }, + "data_protection_call_counts": { + Type: framework.TypeMap, + Description: "Count of data protection calls on this cluster.", + }, }, }}, http.StatusNoContent: {{ @@ -81,10 +85,19 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic return nil, fmt.Errorf("error retrieving local max kv counts: %w", err) } + // Data protection call counts are stored to local path only + // Each cluster tracks its own total requests to avoid double counting + localDataProtectionCallCounts, err := b.Core.UpdateDataProtectionCallCounts(ctx, currentMonth) + if err != nil { + return nil, fmt.Errorf("error retrieving local max data protection call counts: %w", err) + } + // If we are the primary, then combine the replicated and local max role counts. Else just output the local // max role counts. replicatedMaxRoleCounts will be empty if we are not a primary, so this is taken care of for us. - combinedMaxRoleCounts := combineRoleCounts(ctx, replicatedMaxRoleCounts, localMaxRoleCounts) + combinedMaxRoleCounts := combineRoleCounts(replicatedMaxRoleCounts, localMaxRoleCounts) combinedMaxKvCounts := replicatedKvHWMCounts + localKvHWMCounts + // Data protection counts are not combined - each cluster reports its own total + combinedMaxDataProtectionCallCounts := localDataProtectionCallCounts var replicatedPreviousMonthRoleCounts *RoleCounts replicatedPreviousMonthKvHWMCounts := 0 @@ -107,19 +120,29 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic return nil, fmt.Errorf("error retrieving local max kv counts for previous month: %w", err) } - combinedPreviousMonthRoleCounts := combineRoleCounts(ctx, replicatedPreviousMonthRoleCounts, localPreviousMonthRoleCounts) + // Data protection counts for previous month + localPreviousMonthDataProtectionCallCounts, err := b.Core.GetStoredDataProtectionCallCounts(ctx, previousMonth) + if err != nil { + return nil, fmt.Errorf("error retrieving local max data protection call counts for previous month: %w", err) + } + + combinedPreviousMonthRoleCounts := combineRoleCounts(replicatedPreviousMonthRoleCounts, localPreviousMonthRoleCounts) combinedPreviousMonthKvHWMCounts := replicatedPreviousMonthKvHWMCounts + localPreviousMonthKvHWMCounts + // Data protection counts are not combined - each cluster reports its own total + combinedPreviousMonthDataProtectionCallCounts := localPreviousMonthDataProtectionCallCounts resp := map[string]interface{}{ "current_month": map[string]interface{}{ - "timestamp": timeutil.StartOfMonth(currentMonth), - "maximum_role_counts": combinedMaxRoleCounts, - "maximum_kv_counts": combinedMaxKvCounts, + "timestamp": timeutil.StartOfMonth(currentMonth), + "maximum_role_counts": combinedMaxRoleCounts, + "maximum_kv_counts": combinedMaxKvCounts, + "data_protection_call_counts": combinedMaxDataProtectionCallCounts, }, "previous_month": map[string]interface{}{ - "timestamp": previousMonth, - "maximum_role_counts": combinedPreviousMonthRoleCounts, - "maximum_kv_counts": combinedPreviousMonthKvHWMCounts, + "timestamp": previousMonth, + "maximum_role_counts": combinedPreviousMonthRoleCounts, + "maximum_kv_counts": combinedPreviousMonthKvHWMCounts, + "data_protection_call_counts": combinedPreviousMonthDataProtectionCallCounts, }, }