Backport Vault 42177 Add Backend Field into ce/main (#12152)

* Vault 42177 Add Backend Field (#12092)

* add a new struct for the total number of successful requests for transit and transform

* implement tracking for encrypt path

* implement tracking in encrypt path

* add tracking in rewrap

* add tracking to datakey path

* add tracking to  hmac path

* add tracking to sign  path

* add tracking to verify path

* unit tests for verify path

* add tracking to cmac path

* reset the global counter in each unit test

* add tracking to hmac verify

* add methods to retrieve and flush transit count

* modify the methods that store and update data protection call counts

* update the methods

* add a helper method to combine replicated and local data call counts

* add tracking to the endpoint

* fix some formatting errors

* add unit tests to path encrypt for tracking

* add unit tests to decrypt path

* fix linter error

* add unit tests to test update and store methods for data protection calls

* stub fix: do not create separate files

* fix the tracking by coordinating replicated and local data, add unit tests

* update all reference to the new data struct

* revert to previous design with just one global counter for all calls for each cluster

* complete external test

* no need to check if current count is greater than 0, remove it

* feedback: remove unnacassary comments about atomic addition, standardize comments

* leave jira id on todo comment, remove unused method

* rename mathods by removing HWM and max in names, update jira id in todo comment, update response field key name

* feedback: remove explicit counter in cmac tests, instead put in the expected number

* feedback: remove explicit tracking in the rest of the tests

* feedback: separate transit testing into its own external test

* Update vault/consumption_billing_util_test.go

Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>

* update comment after test name change

* fix comments

* fix comments in test

* another comment fix

* feedback: remove incorrect comment

* fix a CE test

* fix the update method: instead of storing max, increment by the current count value

* update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths

* update the external test

* Adds a field to backend to track billing data

removed file

* Changed implementation to use a map instead

* Some more comments

* Add more implementation

* Edited grpc server backend

* Refactored a bit

* Fix one more test

* Modified map:

* Revert "Modified map:"

This reverts commit 1730fe1f358b210e6abae43fbdca09e585aaaaa8.

* Removed some other things

* Edited consumption billing files a bit

* Testing function

* Fix transit stuff and make sure tests pass

* Changes

* More changes

* More changes

* Edited external test

* Edited some more tests

* Edited and fixed tests

* One more fix

* Fix some more tests

* Moved some testing structures around and added error checking

* Fixed some nits

* Update builtin/logical/transit/path_sign_verify.go

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* Edited some errors

* Fixed error logs

* Edited one more thing

* Decorate the error

* Update vault/consumption_billing.go

Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

---------

Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>

* Edited stub function

---------

Co-authored-by: divyaac <divya.chandrasekaran@hashicorp.com>
Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
Co-authored-by: Nick Cabatoff <ncabatoff@hashicorp.com>
Co-authored-by: divyaac <divyaac@berkeley.edu>
This commit is contained in:
Vault Automation 2026-02-03 17:48:12 -05:00 committed by GitHub
parent b3f173756d
commit caf642b7d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 249 additions and 210 deletions

View file

@ -10,7 +10,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-multierror"
@ -113,7 +112,6 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
if err != nil {
return nil, err
}
b.setupEnt()
return &b, nil
@ -124,6 +122,10 @@ type backend struct {
entBackend
lm *keysutil.LockManager
// billingDataCounts tracks successful data protection operations
// for this backend instance. It's intended for test assertions and avoids
// cross-test/package contamination from global counters.
billingDataCounts billing.DataProtectionCallCounts
// Lock to make changes to any of the backend's cache configuration.
configMutex sync.RWMutex
cacheSizeChanged bool
@ -148,9 +150,17 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error
return size, nil
}
// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions
func (b *backend) incrementDataProtectionCounter(count int64) {
atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count)
// incrementBillingCounts atomically increments the transit billing data counts
func (b *backend) incrementBillingCounts(ctx context.Context, count uint64) error {
// If we are a test, we need to increment this testing structure to verify the counts are correct.
if b.billingDataCounts.Transit != nil {
b.billingDataCounts.Transit.Add(count)
}
// Write billling data
return b.ConsumptionBillingManager.WriteBillingData(ctx, "transit", map[string]interface{}{
"count": count,
})
}
// Update cache size and get policy

View file

@ -20,6 +20,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@ -33,6 +34,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
@ -53,6 +55,9 @@ func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) {
if err != nil {
t.Fatal(err)
}
b.billingDataCounts = billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
}
return b, config.StorageView
}
@ -74,6 +79,9 @@ func createBackendWithSysView(t testing.TB) (*backend, logical.Storage) {
if err != nil {
t.Fatal(err)
}
b.billingDataCounts = billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
}
return b, storage
}
@ -95,6 +103,9 @@ func createBackendWithSysViewWithStorage(t testing.TB, s logical.Storage) *backe
if err != nil {
t.Fatal(err)
}
b.billingDataCounts = billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
}
return b
}
@ -117,6 +128,9 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical
if err != nil {
t.Fatal(err)
}
b.billingDataCounts = billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
}
return b
}

View file

@ -172,7 +172,9 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
// Increment the counter for successful operations
// Since there are not batched operations, we can add one successful
// request to the transit request counter.
b.incrementDataProtectionCounter(1)
if err = b.incrementBillingCounts(ctx, 1); err != nil {
b.Logger().Error("failed to track transit data key request count", "error", err.Error())
}
return resp, nil
}

View file

@ -8,7 +8,6 @@ import (
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
@ -16,9 +15,6 @@ import (
// TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme
// args for the returned keys
func TestDataKeyWithPaddingScheme(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
keyName := "test"
createKeyReq := &logical.Request{
@ -95,15 +91,12 @@ func TestDataKeyWithPaddingScheme(t *testing.T) {
})
}
// We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation))
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(8), b.billingDataCounts.Transit.Load())
}
// TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a
// padding_scheme value on an invalid key type (non-RSA)
func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
keyName := "test"
createKeyReq := &logical.Request{
@ -132,5 +125,5 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
require.NotNil(t, resp, "response should not be nil")
require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key")
// We expect 0 successful requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}

View file

@ -278,8 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
}
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit decrypt request count", "error", err.Error())
}
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
}

View file

@ -12,7 +12,6 @@ import (
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
@ -23,9 +22,6 @@ func TestTransit_BatchDecryption(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchEncryptionInput := []interface{}{
map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string
map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline
@ -71,12 +67,15 @@ func TestTransit_BatchDecryption(t *testing.T) {
expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]"
jsonResponse, err := json.Marshal(batchDecryptionResponseItems)
if err != nil || err == nil && string(jsonResponse) != expectedResult {
if err != nil {
t.Fatalf("bad: failed to marshal response items: err=%v json=%s", err, jsonResponse)
}
if string(jsonResponse) != expectedResult {
t.Fatalf("bad: expected json response [%s]", jsonResponse)
}
// We expect 6 successful requests (3 for batch encryption, 3 for batch decryption)
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
}
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
@ -86,9 +85,6 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Create a derived key.
req = &logical.Request{
Operation: logical.UpdateOperation,
@ -291,5 +287,5 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
}
// We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption)
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(7), b.billingDataCounts.Transit.Load())
}

View file

@ -650,7 +650,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit encrypt request count", "error", err.Error())
}
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
}

View file

@ -15,7 +15,6 @@ import (
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
@ -24,9 +23,6 @@ func TestTransit_MissingPlaintext(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
// Create the policy
@ -51,16 +47,13 @@ func TestTransit_MissingPlaintext(t *testing.T) {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
// We expect 0 successful calls
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
// Create the policy
@ -92,7 +85,7 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
// We expect 0 successful calls
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case1: Ensure that batch encryption did not affect the normal flow of
@ -101,9 +94,6 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
// Create the policy
@ -160,7 +150,7 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
}
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Case2: Ensure that batch encryption did not affect the normal flow of
@ -170,9 +160,6 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Upsert the key and encrypt the data
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
@ -228,14 +215,12 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
}
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Case3: If batch encryption input is not base64 encoded, it should fail.
func TestTransit_BatchEncryptionCase3(t *testing.T) {
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
@ -256,7 +241,7 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) {
}
// We expect 0 successful requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case4: Test batch encryption with an existing key (and test references)
@ -264,9 +249,6 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
policyReq := &logical.Request{
@ -331,7 +313,7 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
}
// We expect 4 successful requests (2 batch requests + 2 decrypt requests)
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
}
// Case5: Test batch encryption with an existing derived key
@ -339,9 +321,6 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
policyData := map[string]interface{}{
@ -409,7 +388,7 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
}
}
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
}
// Case6: Test batch encryption with an upserted non-derived key
@ -417,9 +396,6 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
batchInput := []interface{}{
@ -475,7 +451,7 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
}
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
}
// Case7: Test batch encryption with an upserted derived key
@ -485,9 +461,6 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchInput := []interface{}{
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
@ -536,7 +509,7 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
}
}
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
}
// Case8: If plaintext is not base64 encoded, encryption should fail
@ -546,9 +519,6 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
@ -594,7 +564,7 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
t.Fatal("expected an error")
}
// We expect 0 successful transit requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case9: If both plaintext and batch inputs are supplied, plaintext should be
@ -605,9 +575,6 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchInput := []interface{}{
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
@ -634,7 +601,7 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
}
// We expect 2 successful batch encryptions
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Case10: Inconsistent presence of 'context' in batch input should be caught
@ -643,9 +610,6 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchInput := []interface{}{
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
@ -666,7 +630,7 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
t.Fatalf("expected an error")
}
// We expect no successful transit requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case11: Incorrect inputs for context and nonce should not fail the operation
@ -675,9 +639,6 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchInput := []interface{}{
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"},
@ -697,7 +658,7 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
t.Fatal(err)
}
// We expect 1 successful encryption out of the 2-item batch
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(1), b.billingDataCounts.Transit.Load())
}
// Case12: Invalid batch input
@ -705,9 +666,6 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
batchInput := []interface{}{
map[string]interface{}{},
"unexpected_interface",
@ -727,7 +685,7 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
t.Fatalf("expected an error")
}
// We expect no successful requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation
@ -736,9 +694,6 @@ func TestTransit_EncryptionCase13(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Non-batch first
data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"}
req := &logical.Request{
@ -774,7 +729,7 @@ func TestTransit_EncryptionCase13(t *testing.T) {
t.Fatal("expected request error")
}
// We expect no successful transit requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Case14: Incorrect input for nonce when we are in convergent version 3 should fail
@ -783,9 +738,6 @@ func TestTransit_EncryptionCase14(t *testing.T) {
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
cReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/my-key",
@ -836,7 +788,7 @@ func TestTransit_EncryptionCase14(t *testing.T) {
t.Fatal("expected request error")
}
// We expect no successful transit requests
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.

View file

@ -254,8 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
}
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit hmac request count", "error", err.Error())
}
return resp, nil
}
@ -432,8 +433,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
}
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit hmac verify request count", "error", err.Error())
}
return resp, nil
}

View file

@ -16,14 +16,10 @@ import (
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/billing"
"github.com/stretchr/testify/require"
)
func TestTransit_HMAC(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, storage := createBackendWithSysView(t)
cases := []struct {
@ -249,13 +245,10 @@ func TestTransit_HMAC(t *testing.T) {
}
}
// Verify the total successful transit requests
require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(72), b.billingDataCounts.Transit.Load())
}
func TestTransit_batchHMAC(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, storage := createBackendWithSysView(t)
// First create a key
@ -411,7 +404,7 @@ func TestTransit_batchHMAC(t *testing.T) {
}
// Verify the total successful transit requests
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(5), b.billingDataCounts.Transit.Load())
}
// TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input
@ -440,9 +433,6 @@ func TestHMACBatchResultsFields(t *testing.T) {
})
require.NoError(t, err)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
keyName := "hmac-test-key"
_, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32})
require.NoError(t, err)
@ -505,5 +495,5 @@ func TestHMACBatchResultsFields(t *testing.T) {
}
// We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification)
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), cores[0].GetInMemoryTransitDataProtectionCallCounts())
}

View file

@ -308,8 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.")
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit rewrap request count", "error", err.Error())
}
return resp, nil
}

View file

@ -9,7 +9,6 @@ import (
"testing"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/stretchr/testify/require"
)
@ -19,9 +18,6 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Upsert the key and encrypt the data
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
@ -119,7 +115,7 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
}
// We expect 2 successful requests (1 for encrypt, 1 for rewrap)
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Check the normal flow of rewrap with upserted key
@ -128,9 +124,6 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
var err error
b, s := createBackendWithStorage(t)
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
// Upsert the key and encrypt the data
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
@ -229,7 +222,7 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
}
// We expect 2 successful transit requests (1 for encrypt, 1 for rewrap)
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
}
// Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts
@ -237,9 +230,6 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
batchEncryptionInput := []interface{}{
@ -341,7 +331,7 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
}
}
// We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption)
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
}
// TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes
@ -349,9 +339,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
var resp *logical.Response
var err error
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, s := createBackendWithStorage(t)
batchEncryptionInput := []interface{}{
@ -459,5 +446,5 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
}
}
// We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap)
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
}

View file

@ -492,8 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
}
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit sign request count", "error", err.Error())
}
return resp, nil
}
@ -750,8 +751,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
}
}
// Increment the counter for successful operations
b.incrementDataProtectionCounter(int64(successfulRequests))
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
b.Logger().Error("failed to track transit sign verify request count", "error", err.Error())
}
return resp, nil
}

View file

@ -15,7 +15,6 @@ import (
"github.com/hashicorp/vault/helper/constants"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/billing"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519"
@ -33,9 +32,6 @@ type signOutcome struct {
}
func TestTransit_SignVerify_ECDSA(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
t.Run("256", func(t *testing.T) {
testTransit_SignVerify_ECDSA(t, 256)
})
@ -45,9 +41,6 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) {
t.Run("521", func(t *testing.T) {
testTransit_SignVerify_ECDSA(t, 521)
})
// Verify the total successful transit requests
require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit)
}
func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
@ -330,6 +323,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
verifyRequest(req, false, "", sig)
// Now try the v1
verifyRequest(req, true, "", v1sig)
require.Equal(t, uint64(28), b.billingDataCounts.Transit.Load())
}
func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, expectValid bool, postpath string, b *backend) {
@ -380,9 +374,6 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex
// TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a
// Ed25519ph/ctx signature fail on CE and ENT if invalid
func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, storage := createBackendWithSysView(t)
// First create a key
@ -477,17 +468,14 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
}
// Verify the total successful transit requests
if constants.IsEnterprise {
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
} else {
// We expect 0 successful calls on CE because we expect the verify to fail
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
}
}
func TestTransit_SignVerify_ED25519(t *testing.T) {
// Reset the transit counter
billing.CurrentDataProtectionCallCounts.Transit = 0
b, storage := createBackendWithSysView(t)
// First create a key
@ -832,7 +820,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
verifyRequest(req, false, outcome, "bar", goodsig, true)
// Verify the total successful transit requests
require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(24), b.billingDataCounts.Transit.Load())
}
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {

View file

@ -118,6 +118,9 @@ type Backend struct {
// communicate with a plugin to activate a feature.
ActivationFunc func(context.Context, *logical.Request, string) error
// ConsumptionBillingManager is the consumption billing manager the backend can use to write billing data.
ConsumptionBillingManager logical.ConsumptionBillingManager
logger log.Logger
system logical.SystemView
events logical.EventSender
@ -440,6 +443,11 @@ func (b *Backend) Setup(ctx context.Context, config *logical.BackendConfig) erro
b.system = config.System
b.events = config.EventsSender
b.observations = config.ObservationRecorder
if b.System() != nil && b.System().GetConsumptionBillingManager() != nil {
b.ConsumptionBillingManager = b.System().GetConsumptionBillingManager()
} else {
b.ConsumptionBillingManager = logical.NewNullConsumptionBillingManager()
}
return nil
}
@ -546,7 +554,7 @@ func (b *Backend) init() {
for i, p := range b.Paths {
// Detect the coding error of failing to initialise Pattern
if len(p.Pattern) == 0 {
panic(fmt.Sprintf("Routing pattern cannot be blank"))
panic("Routing pattern cannot be blank")
}
// Detect the coding error of attempting to define a CreateOperation without defining an ExistenceCheck

View file

@ -0,0 +1,25 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: MPL-2.0
package logical
import "context"
// billing.ConsumptionBillingManager is an implementation of this interface that the backend can use to write billing data.
type ConsumptionBillingManager interface {
WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error
}
// ================================
// Creates a null consumption billing manager that does nothing
var _ ConsumptionBillingManager = (*nullConsumptionBillingManager)(nil)
func NewNullConsumptionBillingManager() ConsumptionBillingManager {
return &nullConsumptionBillingManager{}
}
type nullConsumptionBillingManager struct{}
func (n *nullConsumptionBillingManager) WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error {
return nil
}

View file

@ -117,6 +117,9 @@ type SystemView interface {
// ExtractVerifyPlugin extracts and verifies the plugin artifact
DownloadExtractVerifyPlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error
// GetConsumptionBillingManager returns the consumption billing manager
GetConsumptionBillingManager() ConsumptionBillingManager
}
type PasswordPolicy interface {
@ -320,6 +323,10 @@ func (d StaticSystemView) DownloadExtractVerifyPlugin(_ context.Context, _ *plug
return errors.New("DownloadExtractVerifyPlugin is not implemented in StaticSystemView")
}
func (d StaticSystemView) GetConsumptionBillingManager() ConsumptionBillingManager {
return nil
}
// PluginLicenseUtil defines the functions needed to request License and PluginEnv
// by the plugin licensing under github.com/hashicorp/vault-licensing
// This only should be used by the plugin to get the license and plugin environment

View file

@ -36,6 +36,11 @@ type gRPCSystemViewClient struct {
client pb.SystemViewClient
}
func (s *gRPCSystemViewClient) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
// Not implemented on pluginbackend
return nil
}
func (s *gRPCSystemViewClient) DefaultLeaseTTL() time.Duration {
reply, err := s.client.DefaultLeaseTTL(context.Background(), &pb.Empty{})
if err != nil {

View file

@ -4,19 +4,24 @@
package billing
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/logical"
)
const (
BillingSubPath = "billing/"
ReplicatedPrefix = "replicated/"
RoleHWMCountsHWM = "maxRoleCounts/"
KvHWMCountsHWM = "maxKvCounts/"
DataProtectionCallCountsMetric = "dataProtectionCallCounts/"
LocalPrefix = "local/"
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
BillingSubPath = "billing/"
ReplicatedPrefix = "replicated/"
RoleHWMCountsHWM = "maxRoleCounts/"
KvHWMCountsHWM = "maxKvCounts/"
TransitDataProtectionCallCountsPrefix = "transitDataProtectionCallCounts/"
LocalPrefix = "local/"
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
BillingWriteInterval = 10 * time.Minute
)
@ -27,7 +32,9 @@ type ConsumptionBilling struct {
// BillingStorageLock controls access to the billing storage paths
BillingStorageLock sync.RWMutex
BillingConfig BillingConfig
BillingConfig BillingConfig
DataProtectionCallCounts DataProtectionCallCounts
Logger log.Logger
}
type BillingConfig struct {
@ -45,10 +52,26 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri
}
type DataProtectionCallCounts struct {
Transit int64 `json:"transit,omitempty"`
Transit *atomic.Uint64 `json:"transit,omitempty"`
// TODO: Uncomment when we add support for Transform tracking (VAULT-41205)
// Transform int64 `json:"transform,omitempty"`
// Transform atomic.int64 `json:"transform,omitempty"`
}
// Global counter for all data protection calls on this cluster
var CurrentDataProtectionCallCounts = DataProtectionCallCounts{}
var _ logical.ConsumptionBillingManager = (*ConsumptionBilling)(nil)
func (s *ConsumptionBilling) WriteBillingData(ctx context.Context, mountType string, data map[string]interface{}) error {
switch mountType {
case "transit":
val, ok := data["count"].(uint64)
if !ok {
err := fmt.Errorf("invalid value type for transit")
return err
}
s.DataProtectionCallCounts.Transit.Add(val)
default:
err := fmt.Errorf("unknown metric type: %s", mountType)
return err
}
return nil
}

View file

@ -5,6 +5,8 @@ package vault
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/hashicorp/vault/helper/timeutil"
@ -15,8 +17,14 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error {
// We need replication (post unseal) to start before we run the consumption billing metrics worker
// This is because there is primary/secondary cluster specific logic
c.consumptionBillingLock.Lock()
logger := c.baseLogger.Named("billing")
c.AddLogger(logger)
c.consumptionBilling = &billing.ConsumptionBilling{
BillingConfig: c.billingConfig,
DataProtectionCallCounts: billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
},
Logger: logger,
}
c.consumptionBillingLock.Unlock()
c.postUnsealFuncs = append(c.postUnsealFuncs, func() {
@ -72,7 +80,12 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error {
c.UpdateReplicatedHWMMetrics(ctx, currentMonth)
}
c.UpdateLocalHWMMetrics(ctx, currentMonth)
c.UpdateLocalAggregatedMetrics(ctx, currentMonth)
if err := c.UpdateLocalAggregatedMetrics(ctx, currentMonth); err != nil {
c.logger.Error("error updating cluster data protection call counts", "error", err)
} else {
c.logger.Info("updated cluster data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
}
}
return nil
}
@ -119,10 +132,7 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time
func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error {
if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil {
c.logger.Error("error updating local max data protection call counts", "error", err)
} else {
c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
return fmt.Errorf("could not store transit data protection call counts: %w", err)
}
return nil
}

View file

@ -0,0 +1,12 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: MPL-2.0
package vault
func (c *Core) ResetInMemoryDataProtectionCallCounts() {
c.consumptionBilling.DataProtectionCallCounts.Transit.Store(0)
}
func (c *Core) GetInMemoryTransitDataProtectionCallCounts() uint64 {
return c.consumptionBilling.DataProtectionCallCounts.Transit.Load()
}

View file

@ -267,29 +267,36 @@ func (c *Core) GetBillingSubView() *BarrierView {
// storeDataProtectionCallCountsLocked must be called with BillingStorageLock held
func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error {
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
entry, err := logical.StorageEntryJSON(billingPath, maxCounts)
if err != nil {
return err
// Store count for each data protection type separately because they are atomic counters
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix)
transitCount := maxCounts.Transit.Load()
entry := &logical.StorageEntry{
Key: billingPath,
Value: []byte(strconv.FormatUint(transitCount, 10)),
}
return c.GetBillingSubView().Put(ctx, entry)
}
// getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held
func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) {
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
// Retrieve count for each data protection type separately because they are atomic counters
ret := &billing.DataProtectionCallCounts{
Transit: &atomic.Uint64{},
}
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix)
entry, err := c.GetBillingSubView().Get(ctx, billingPath)
if err != nil {
return nil, err
}
if entry == nil {
return &billing.DataProtectionCallCounts{}, nil
return ret, nil
}
var maxCounts billing.DataProtectionCallCounts
if err := entry.DecodeJSON(&maxCounts); err != nil {
transitCount, err := strconv.ParseUint(string(entry.Value), 10, 64)
if err != nil {
return nil, err
}
return &maxCounts, nil
ret.Transit.Store(transitCount)
return ret, nil
}
func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) {
@ -302,10 +309,6 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth
c.consumptionBilling.BillingStorageLock.Lock()
defer c.consumptionBilling.BillingStorageLock.Unlock()
// Read and atomically reset the current counter value
// TODO: Reset Transform call counts too (VAULT-41205)
currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0)
storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth)
if err != nil {
return nil, err
@ -315,7 +318,8 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth
}
// Sum the current count with the stored count
storedDataProtectionCallCounts.Transit += currentTransitCount
transitCount := c.consumptionBilling.DataProtectionCallCounts.Transit.Swap(0)
storedDataProtectionCallCounts.Transit.Add(transitCount)
// TODO: Update Transform call counts (VAULT-41205)
err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth)

View file

@ -42,7 +42,7 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) {
// testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature.
// Returns the current count unchanged.
func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 {
func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount uint64) uint64 {
// CMAC is not supported in OSS, so we don't perform any operations
return currentCount
}

View file

@ -420,9 +420,6 @@ func TestDataProtectionCallCounts(t *testing.T) {
_, err := core.HandleRequest(ctx, req)
require.NoError(t, err)
// Reset the transit counters
billing.CurrentDataProtectionCallCounts.Transit = 0
// Create an encryption key
req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo")
req.Data["type"] = "aes256-gcm96"
@ -439,8 +436,8 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp)
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented (replicated mount by default)
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
// Verify that the transit counter is incremented
require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts())
// Get the ciphertext from the encryption response
ciphertext, ok := resp.Data["ciphertext"].(string)
@ -455,7 +452,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NoError(t, err)
// Verify that the transit counter is incremented
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(2), core.GetInMemoryTransitDataProtectionCallCounts())
// Test rewrap operation
req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo")
@ -467,7 +464,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts())
// Get the new ciphertext from rewrap
newCiphertext, ok := resp.Data["ciphertext"].(string)
@ -483,9 +480,9 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(4), core.GetInMemoryTransitDataProtectionCallCounts())
// Test HMAC generation
// Test HMAC operation
req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo")
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
req.ClientToken = root
@ -495,7 +492,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(5), core.GetInMemoryTransitDataProtectionCallCounts())
// Get the HMAC value
hmacValue, ok := resp.Data["hmac"].(string)
@ -513,7 +510,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(6), core.GetInMemoryTransitDataProtectionCallCounts())
// Verify the HMAC is valid
hmacValid, ok := resp.Data["valid"].(bool)
@ -537,7 +534,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(7), core.GetInMemoryTransitDataProtectionCallCounts())
// Get the signature
signature, ok := resp.Data["signature"].(string)
@ -555,7 +552,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NotNil(t, resp.Data)
// Verify that the transit counter is incremented
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(8), core.GetInMemoryTransitDataProtectionCallCounts())
// Verify the signature is valid
signatureValid, ok := resp.Data["valid"].(bool)
@ -563,27 +560,27 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.True(t, signatureValid)
// Test CMAC operations (ENT only - will be no-op in OSS)
currentCount := billing.CurrentDataProtectionCallCounts.Transit
currentCount := core.GetInMemoryTransitDataProtectionCallCounts()
currentCount = testCMACOperations(t, core, ctx, root, currentCount)
// Verify that the transit counter matches expected count
require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, currentCount, core.GetInMemoryTransitDataProtectionCallCounts())
// Now test persisting the summed counts - store and retrieve counts
// First, update the data protection call counts (this will sum current counter with stored value)
summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.NotNil(t, summedCounts)
require.Equal(t, currentCount, summedCounts.Transit)
require.Equal(t, currentCount, summedCounts.Transit.Load())
// Verify the counter was reset after update
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
// Retrieve the stored counts
storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.NotNil(t, storedCounts)
require.Equal(t, currentCount, storedCounts.Transit)
require.Equal(t, currentCount, storedCounts.Transit.Load())
// Perform more operations to increase the counter
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
@ -593,21 +590,21 @@ func TestDataProtectionCallCounts(t *testing.T) {
require.NoError(t, err)
// Counter should now be 1 (reset + 1 operation)
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts())
// Update counts again - should sum the new count (1) with the stored count (currentCount)
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
expectedSum := currentCount + 1
require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current")
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should be sum of stored and current")
// Verify the counter was reset after update
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
// Verify stored counts are now the sum
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, expectedSum, storedCounts.Transit)
require.Equal(t, expectedSum, storedCounts.Transit.Load())
// Add more operations without manually resetting
for i := 0; i < 3; i++ {
@ -619,35 +616,35 @@ func TestDataProtectionCallCounts(t *testing.T) {
}
// Counter should be 3
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts())
// Update counts - should sum 3 with the previous stored sum
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
expectedSum = expectedSum + 3
require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum")
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should continue to sum")
// Verify the counter was reset after update
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
// Verify stored counts
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, expectedSum, storedCounts.Transit)
require.Equal(t, expectedSum, storedCounts.Transit.Load())
// Update again without any new operations
// This verifies we don't double-count
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred")
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should remain the same when no new operations occurred")
// Verify stored counts haven't changed
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
require.NoError(t, err)
require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same")
require.Equal(t, expectedSum, storedCounts.Transit.Load(), "Stored count should remain the same")
// Verify counter is still at 0
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0")
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0")
}
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {

View file

@ -78,3 +78,7 @@ func (c *Core) setupHeaderHMACKey(ctx context.Context, isPerfStandby bool) error
func (c *Core) GetPkiCertificateCounter() logical.CertificateCounter {
return c.pkiCertCountManager
}
func (c *Core) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
return c.consumptionBilling
}

View file

@ -401,3 +401,7 @@ func (d dynamicSystemView) DeregisterRotationJob(ctx context.Context, req *rotat
return d.core.DeregisterRotationJob(nsCtx, req)
}
func (d dynamicSystemView) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
return d.core.GetConsumptionBillingManager()
}