diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 971ac7e2ca..662d0bf2b6 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -10,7 +10,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -113,7 +112,6 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) if err != nil { return nil, err } - b.setupEnt() return &b, nil @@ -124,6 +122,10 @@ type backend struct { entBackend lm *keysutil.LockManager + // billingDataCounts tracks successful data protection operations + // for this backend instance. It's intended for test assertions and avoids + // cross-test/package contamination from global counters. + billingDataCounts billing.DataProtectionCallCounts // Lock to make changes to any of the backend's cache configuration. configMutex sync.RWMutex cacheSizeChanged bool @@ -148,9 +150,17 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error return size, nil } -// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions -func (b *backend) incrementDataProtectionCounter(count int64) { - atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count) +// incrementBillingCounts atomically increments the transit billing data counts +func (b *backend) incrementBillingCounts(ctx context.Context, count uint64) error { + // If we are a test, we need to increment this testing structure to verify the counts are correct. + if b.billingDataCounts.Transit != nil { + b.billingDataCounts.Transit.Add(count) + } + + // Write billling data + return b.ConsumptionBillingManager.WriteBillingData(ctx, "transit", map[string]interface{}{ + "count": count, + }) } // Update cache size and get policy diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index f99df68ba1..1d6ffcd0b1 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -33,6 +34,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -53,6 +55,9 @@ func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) { if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b, config.StorageView } @@ -74,6 +79,9 @@ func createBackendWithSysView(t testing.TB) (*backend, logical.Storage) { if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b, storage } @@ -95,6 +103,9 @@ func createBackendWithSysViewWithStorage(t testing.TB, s logical.Storage) *backe if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b } @@ -117,6 +128,9 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b } diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index f8b851c11e..18b6db37ea 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -172,7 +172,9 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d // Increment the counter for successful operations // Since there are not batched operations, we can add one successful // request to the transit request counter. - b.incrementDataProtectionCounter(1) + if err = b.incrementBillingCounts(ctx, 1); err != nil { + b.Logger().Error("failed to track transit data key request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_datakey_test.go b/builtin/logical/transit/path_datakey_test.go index c8150afc69..2ed1c640d0 100644 --- a/builtin/logical/transit/path_datakey_test.go +++ b/builtin/logical/transit/path_datakey_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -16,9 +15,6 @@ import ( // TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme // args for the returned keys func TestDataKeyWithPaddingScheme(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -95,15 +91,12 @@ func TestDataKeyWithPaddingScheme(t *testing.T) { }) } // We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation)) - require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(8), b.billingDataCounts.Transit.Load()) } // TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a // padding_scheme value on an invalid key type (non-RSA) func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -132,5 +125,5 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { require.NotNil(t, resp, "response should not be nil") require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key") // We expect 0 successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index f25275d352..08c20fc0a3 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -278,8 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit decrypt request count", "error", err.Error()) + } return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_decrypt_test.go b/builtin/logical/transit/path_decrypt_test.go index b2f8046f16..ecf0139198 100644 --- a/builtin/logical/transit/path_decrypt_test.go +++ b/builtin/logical/transit/path_decrypt_test.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -23,9 +22,6 @@ func TestTransit_BatchDecryption(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchEncryptionInput := []interface{}{ map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline @@ -71,12 +67,15 @@ func TestTransit_BatchDecryption(t *testing.T) { expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]" jsonResponse, err := json.Marshal(batchDecryptionResponseItems) - if err != nil || err == nil && string(jsonResponse) != expectedResult { + if err != nil { + t.Fatalf("bad: failed to marshal response items: err=%v json=%s", err, jsonResponse) + } + if string(jsonResponse) != expectedResult { t.Fatalf("bad: expected json response [%s]", jsonResponse) } // We expect 6 successful requests (3 for batch encryption, 3 for batch decryption) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { @@ -86,9 +85,6 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create a derived key. req = &logical.Request{ Operation: logical.UpdateOperation, @@ -291,5 +287,5 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { } // We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption) - require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(7), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 4ebaf2e2b1..f568bf64ec 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -650,7 +650,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d } // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit encrypt request count", "error", err.Error()) + } return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index 05ab5e2fd3..526fe9956f 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -15,7 +15,6 @@ import ( uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -24,9 +23,6 @@ func TestTransit_MissingPlaintext(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -51,16 +47,13 @@ func TestTransit_MissingPlaintext(t *testing.T) { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } // We expect 0 successful calls - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -92,7 +85,7 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } // We expect 0 successful calls - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case1: Ensure that batch encryption did not affect the normal flow of @@ -101,9 +94,6 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -160,7 +150,7 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for decrypt) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case2: Ensure that batch encryption did not affect the normal flow of @@ -170,9 +160,6 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -228,14 +215,12 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for decrypt) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case3: If batch encryption input is not base64 encoded, it should fail. func TestTransit_BatchEncryptionCase3(t *testing.T) { var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 b, s := createBackendWithStorage(t) @@ -256,7 +241,7 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) { } // We expect 0 successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case4: Test batch encryption with an existing key (and test references) @@ -264,9 +249,6 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) policyReq := &logical.Request{ @@ -331,7 +313,7 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { } // We expect 4 successful requests (2 batch requests + 2 decrypt requests) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case5: Test batch encryption with an existing derived key @@ -339,9 +321,6 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) policyData := map[string]interface{}{ @@ -409,7 +388,7 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { } } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case6: Test batch encryption with an upserted non-derived key @@ -417,9 +396,6 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchInput := []interface{}{ @@ -475,7 +451,7 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case7: Test batch encryption with an upserted derived key @@ -485,9 +461,6 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -536,7 +509,7 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { } } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case8: If plaintext is not base64 encoded, encryption should fail @@ -546,9 +519,6 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create the policy policyReq := &logical.Request{ Operation: logical.UpdateOperation, @@ -594,7 +564,7 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { t.Fatal("expected an error") } // We expect 0 successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case9: If both plaintext and batch inputs are supplied, plaintext should be @@ -605,9 +575,6 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, @@ -634,7 +601,7 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { } // We expect 2 successful batch encryptions - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case10: Inconsistent presence of 'context' in batch input should be caught @@ -643,9 +610,6 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -666,7 +630,7 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { t.Fatalf("expected an error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case11: Incorrect inputs for context and nonce should not fail the operation @@ -675,9 +639,6 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"}, @@ -697,7 +658,7 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { t.Fatal(err) } // We expect 1 successful encryption out of the 2-item batch - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(1), b.billingDataCounts.Transit.Load()) } // Case12: Invalid batch input @@ -705,9 +666,6 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{}, "unexpected_interface", @@ -727,7 +685,7 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { t.Fatalf("expected an error") } // We expect no successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation @@ -736,9 +694,6 @@ func TestTransit_EncryptionCase13(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Non-batch first data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"} req := &logical.Request{ @@ -774,7 +729,7 @@ func TestTransit_EncryptionCase13(t *testing.T) { t.Fatal("expected request error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case14: Incorrect input for nonce when we are in convergent version 3 should fail @@ -783,9 +738,6 @@ func TestTransit_EncryptionCase14(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - cReq := &logical.Request{ Operation: logical.UpdateOperation, Path: "keys/my-key", @@ -836,7 +788,7 @@ func TestTransit_EncryptionCase14(t *testing.T) { t.Fatal("expected request error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem. diff --git a/builtin/logical/transit/path_hmac.go b/builtin/logical/transit/path_hmac.go index 316fa608f1..47776a3bb7 100644 --- a/builtin/logical/transit/path_hmac.go +++ b/builtin/logical/transit/path_hmac.go @@ -254,8 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit hmac request count", "error", err.Error()) + } return resp, nil } @@ -432,8 +433,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit hmac verify request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_hmac_test.go b/builtin/logical/transit/path_hmac_test.go index 13c1e5e556..29fba94488 100644 --- a/builtin/logical/transit/path_hmac_test.go +++ b/builtin/logical/transit/path_hmac_test.go @@ -16,14 +16,10 @@ import ( "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" - "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" ) func TestTransit_HMAC(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) cases := []struct { @@ -249,13 +245,10 @@ func TestTransit_HMAC(t *testing.T) { } } // Verify the total successful transit requests - require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(72), b.billingDataCounts.Transit.Load()) } func TestTransit_batchHMAC(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -411,7 +404,7 @@ func TestTransit_batchHMAC(t *testing.T) { } // Verify the total successful transit requests - require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(5), b.billingDataCounts.Transit.Load()) } // TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input @@ -440,9 +433,6 @@ func TestHMACBatchResultsFields(t *testing.T) { }) require.NoError(t, err) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - keyName := "hmac-test-key" _, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32}) require.NoError(t, err) @@ -505,5 +495,5 @@ func TestHMACBatchResultsFields(t *testing.T) { } // We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), cores[0].GetInMemoryTransitDataProtectionCallCounts()) } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 1451aa768c..df0dbea991 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -308,8 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.") } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit rewrap request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_rewrap_test.go b/builtin/logical/transit/path_rewrap_test.go index a3734da6da..7c8ddbeab0 100644 --- a/builtin/logical/transit/path_rewrap_test.go +++ b/builtin/logical/transit/path_rewrap_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" ) @@ -19,9 +18,6 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -119,7 +115,7 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for rewrap) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Check the normal flow of rewrap with upserted key @@ -128,9 +124,6 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -229,7 +222,7 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2) } // We expect 2 successful transit requests (1 for encrypt, 1 for rewrap) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts @@ -237,9 +230,6 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -341,7 +331,7 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { } } // We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } // TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes @@ -349,9 +339,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -459,5 +446,5 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { } } // We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index 114fc98209..4ebebec5b3 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -492,8 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit sign request count", "error", err.Error()) + } return resp, nil } @@ -750,8 +751,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit sign verify request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 482f6627dc..c116a2cb3c 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" @@ -33,9 +32,6 @@ type signOutcome struct { } func TestTransit_SignVerify_ECDSA(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - t.Run("256", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 256) }) @@ -45,9 +41,6 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) { t.Run("521", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 521) }) - - // Verify the total successful transit requests - require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit) } func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { @@ -330,6 +323,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { verifyRequest(req, false, "", sig) // Now try the v1 verifyRequest(req, true, "", v1sig) + require.Equal(t, uint64(28), b.billingDataCounts.Transit.Load()) } func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, expectValid bool, postpath string, b *backend) { @@ -380,9 +374,6 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex // TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a // Ed25519ph/ctx signature fail on CE and ENT if invalid func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -477,17 +468,14 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { } // Verify the total successful transit requests if constants.IsEnterprise { - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } else { // We expect 0 successful calls on CE because we expect the verify to fail - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } } func TestTransit_SignVerify_ED25519(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -832,7 +820,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { verifyRequest(req, false, outcome, "bar", goodsig, true) // Verify the total successful transit requests - require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(24), b.billingDataCounts.Transit.Load()) } func TestTransit_SignVerify_RSA_PSS(t *testing.T) { diff --git a/sdk/framework/backend.go b/sdk/framework/backend.go index e46ab8b75f..148a65da6a 100644 --- a/sdk/framework/backend.go +++ b/sdk/framework/backend.go @@ -118,6 +118,9 @@ type Backend struct { // communicate with a plugin to activate a feature. ActivationFunc func(context.Context, *logical.Request, string) error + // ConsumptionBillingManager is the consumption billing manager the backend can use to write billing data. + ConsumptionBillingManager logical.ConsumptionBillingManager + logger log.Logger system logical.SystemView events logical.EventSender @@ -440,6 +443,11 @@ func (b *Backend) Setup(ctx context.Context, config *logical.BackendConfig) erro b.system = config.System b.events = config.EventsSender b.observations = config.ObservationRecorder + if b.System() != nil && b.System().GetConsumptionBillingManager() != nil { + b.ConsumptionBillingManager = b.System().GetConsumptionBillingManager() + } else { + b.ConsumptionBillingManager = logical.NewNullConsumptionBillingManager() + } return nil } @@ -546,7 +554,7 @@ func (b *Backend) init() { for i, p := range b.Paths { // Detect the coding error of failing to initialise Pattern if len(p.Pattern) == 0 { - panic(fmt.Sprintf("Routing pattern cannot be blank")) + panic("Routing pattern cannot be blank") } // Detect the coding error of attempting to define a CreateOperation without defining an ExistenceCheck diff --git a/sdk/logical/billing_system_view.go b/sdk/logical/billing_system_view.go new file mode 100644 index 0000000000..caa7b83e8c --- /dev/null +++ b/sdk/logical/billing_system_view.go @@ -0,0 +1,25 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package logical + +import "context" + +// billing.ConsumptionBillingManager is an implementation of this interface that the backend can use to write billing data. +type ConsumptionBillingManager interface { + WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error +} + +// ================================ +// Creates a null consumption billing manager that does nothing +var _ ConsumptionBillingManager = (*nullConsumptionBillingManager)(nil) + +func NewNullConsumptionBillingManager() ConsumptionBillingManager { + return &nullConsumptionBillingManager{} +} + +type nullConsumptionBillingManager struct{} + +func (n *nullConsumptionBillingManager) WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error { + return nil +} diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index 89e71563a5..98562c3a78 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -117,6 +117,9 @@ type SystemView interface { // ExtractVerifyPlugin extracts and verifies the plugin artifact DownloadExtractVerifyPlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error + + // GetConsumptionBillingManager returns the consumption billing manager + GetConsumptionBillingManager() ConsumptionBillingManager } type PasswordPolicy interface { @@ -320,6 +323,10 @@ func (d StaticSystemView) DownloadExtractVerifyPlugin(_ context.Context, _ *plug return errors.New("DownloadExtractVerifyPlugin is not implemented in StaticSystemView") } +func (d StaticSystemView) GetConsumptionBillingManager() ConsumptionBillingManager { + return nil +} + // PluginLicenseUtil defines the functions needed to request License and PluginEnv // by the plugin licensing under github.com/hashicorp/vault-licensing // This only should be used by the plugin to get the license and plugin environment diff --git a/sdk/plugin/grpc_system.go b/sdk/plugin/grpc_system.go index 726f61262c..b45dfc1a2f 100644 --- a/sdk/plugin/grpc_system.go +++ b/sdk/plugin/grpc_system.go @@ -36,6 +36,11 @@ type gRPCSystemViewClient struct { client pb.SystemViewClient } +func (s *gRPCSystemViewClient) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + // Not implemented on pluginbackend + return nil +} + func (s *gRPCSystemViewClient) DefaultLeaseTTL() time.Duration { reply, err := s.client.DefaultLeaseTTL(context.Background(), &pb.Empty{}) if err != nil { diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index d92bfe4d2e..ced2785d6a 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -4,19 +4,24 @@ package billing import ( + "context" "fmt" "sync" + "sync/atomic" "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" ) const ( - BillingSubPath = "billing/" - ReplicatedPrefix = "replicated/" - RoleHWMCountsHWM = "maxRoleCounts/" - KvHWMCountsHWM = "maxKvCounts/" - DataProtectionCallCountsMetric = "dataProtectionCallCounts/" - LocalPrefix = "local/" - ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" + BillingSubPath = "billing/" + ReplicatedPrefix = "replicated/" + RoleHWMCountsHWM = "maxRoleCounts/" + KvHWMCountsHWM = "maxKvCounts/" + TransitDataProtectionCallCountsPrefix = "transitDataProtectionCallCounts/" + LocalPrefix = "local/" + ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" BillingWriteInterval = 10 * time.Minute ) @@ -27,7 +32,9 @@ type ConsumptionBilling struct { // BillingStorageLock controls access to the billing storage paths BillingStorageLock sync.RWMutex - BillingConfig BillingConfig + BillingConfig BillingConfig + DataProtectionCallCounts DataProtectionCallCounts + Logger log.Logger } type BillingConfig struct { @@ -45,10 +52,26 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri } type DataProtectionCallCounts struct { - Transit int64 `json:"transit,omitempty"` + Transit *atomic.Uint64 `json:"transit,omitempty"` // TODO: Uncomment when we add support for Transform tracking (VAULT-41205) - // Transform int64 `json:"transform,omitempty"` + // Transform atomic.int64 `json:"transform,omitempty"` } -// Global counter for all data protection calls on this cluster -var CurrentDataProtectionCallCounts = DataProtectionCallCounts{} +var _ logical.ConsumptionBillingManager = (*ConsumptionBilling)(nil) + +func (s *ConsumptionBilling) WriteBillingData(ctx context.Context, mountType string, data map[string]interface{}) error { + switch mountType { + case "transit": + val, ok := data["count"].(uint64) + if !ok { + err := fmt.Errorf("invalid value type for transit") + return err + } + + s.DataProtectionCallCounts.Transit.Add(val) + default: + err := fmt.Errorf("unknown metric type: %s", mountType) + return err + } + return nil +} diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index 4c7bdd6c53..ef1c28bdb4 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -5,6 +5,8 @@ package vault import ( "context" + "fmt" + "sync/atomic" "time" "github.com/hashicorp/vault/helper/timeutil" @@ -15,8 +17,14 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error { // We need replication (post unseal) to start before we run the consumption billing metrics worker // This is because there is primary/secondary cluster specific logic c.consumptionBillingLock.Lock() + logger := c.baseLogger.Named("billing") + c.AddLogger(logger) c.consumptionBilling = &billing.ConsumptionBilling{ BillingConfig: c.billingConfig, + DataProtectionCallCounts: billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + }, + Logger: logger, } c.consumptionBillingLock.Unlock() c.postUnsealFuncs = append(c.postUnsealFuncs, func() { @@ -72,7 +80,12 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error { c.UpdateReplicatedHWMMetrics(ctx, currentMonth) } c.UpdateLocalHWMMetrics(ctx, currentMonth) - c.UpdateLocalAggregatedMetrics(ctx, currentMonth) + if err := c.UpdateLocalAggregatedMetrics(ctx, currentMonth); err != nil { + c.logger.Error("error updating cluster data protection call counts", "error", err) + } else { + c.logger.Info("updated cluster data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) + } + } return nil } @@ -119,10 +132,7 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error { if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil { - c.logger.Error("error updating local max data protection call counts", "error", err) - } else { - c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) + return fmt.Errorf("could not store transit data protection call counts: %w", err) } - return nil } diff --git a/vault/consumption_billing_testing_util.go b/vault/consumption_billing_testing_util.go new file mode 100644 index 0000000000..4b4b8beab3 --- /dev/null +++ b/vault/consumption_billing_testing_util.go @@ -0,0 +1,12 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package vault + +func (c *Core) ResetInMemoryDataProtectionCallCounts() { + c.consumptionBilling.DataProtectionCallCounts.Transit.Store(0) +} + +func (c *Core) GetInMemoryTransitDataProtectionCallCounts() uint64 { + return c.consumptionBilling.DataProtectionCallCounts.Transit.Load() +} diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index e6656a3145..09312ccf84 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -267,29 +267,36 @@ func (c *Core) GetBillingSubView() *BarrierView { // storeDataProtectionCallCountsLocked must be called with BillingStorageLock held func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error { - billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) - entry, err := logical.StorageEntryJSON(billingPath, maxCounts) - if err != nil { - return err + // Store count for each data protection type separately because they are atomic counters + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix) + transitCount := maxCounts.Transit.Load() + entry := &logical.StorageEntry{ + Key: billingPath, + Value: []byte(strconv.FormatUint(transitCount, 10)), } return c.GetBillingSubView().Put(ctx, entry) } // getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) { - billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) + // Retrieve count for each data protection type separately because they are atomic counters + ret := &billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix) entry, err := c.GetBillingSubView().Get(ctx, billingPath) if err != nil { return nil, err } if entry == nil { - return &billing.DataProtectionCallCounts{}, nil + return ret, nil } - var maxCounts billing.DataProtectionCallCounts - if err := entry.DecodeJSON(&maxCounts); err != nil { + transitCount, err := strconv.ParseUint(string(entry.Value), 10, 64) + if err != nil { return nil, err } - return &maxCounts, nil + ret.Transit.Store(transitCount) + return ret, nil } func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) { @@ -302,10 +309,6 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth c.consumptionBilling.BillingStorageLock.Lock() defer c.consumptionBilling.BillingStorageLock.Unlock() - // Read and atomically reset the current counter value - // TODO: Reset Transform call counts too (VAULT-41205) - currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0) - storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth) if err != nil { return nil, err @@ -315,7 +318,8 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth } // Sum the current count with the stored count - storedDataProtectionCallCounts.Transit += currentTransitCount + transitCount := c.consumptionBilling.DataProtectionCallCounts.Transit.Swap(0) + storedDataProtectionCallCounts.Transit.Add(transitCount) // TODO: Update Transform call counts (VAULT-41205) err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth) diff --git a/vault/consumption_billing_util_oss_test.go b/vault/consumption_billing_util_oss_test.go index effde53104..0b1a69ae45 100644 --- a/vault/consumption_billing_util_oss_test.go +++ b/vault/consumption_billing_util_oss_test.go @@ -42,7 +42,7 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) { // testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature. // Returns the current count unchanged. -func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 { +func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount uint64) uint64 { // CMAC is not supported in OSS, so we don't perform any operations return currentCount } diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index a438c1b0c2..fd0810bfd8 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -420,9 +420,6 @@ func TestDataProtectionCallCounts(t *testing.T) { _, err := core.HandleRequest(ctx, req) require.NoError(t, err) - // Reset the transit counters - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create an encryption key req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo") req.Data["type"] = "aes256-gcm96" @@ -439,8 +436,8 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp) require.NotNil(t, resp.Data) - // Verify that the transit counter is incremented (replicated mount by default) - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + // Verify that the transit counter is incremented + require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the ciphertext from the encryption response ciphertext, ok := resp.Data["ciphertext"].(string) @@ -455,7 +452,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) // Verify that the transit counter is incremented - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), core.GetInMemoryTransitDataProtectionCallCounts()) // Test rewrap operation req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo") @@ -467,7 +464,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the new ciphertext from rewrap newCiphertext, ok := resp.Data["ciphertext"].(string) @@ -483,9 +480,9 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), core.GetInMemoryTransitDataProtectionCallCounts()) - // Test HMAC generation + // Test HMAC operation req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo") req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" req.ClientToken = root @@ -495,7 +492,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(5), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the HMAC value hmacValue, ok := resp.Data["hmac"].(string) @@ -513,7 +510,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), core.GetInMemoryTransitDataProtectionCallCounts()) // Verify the HMAC is valid hmacValid, ok := resp.Data["valid"].(bool) @@ -537,7 +534,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(7), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the signature signature, ok := resp.Data["signature"].(string) @@ -555,7 +552,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(8), core.GetInMemoryTransitDataProtectionCallCounts()) // Verify the signature is valid signatureValid, ok := resp.Data["valid"].(bool) @@ -563,27 +560,27 @@ func TestDataProtectionCallCounts(t *testing.T) { require.True(t, signatureValid) // Test CMAC operations (ENT only - will be no-op in OSS) - currentCount := billing.CurrentDataProtectionCallCounts.Transit + currentCount := core.GetInMemoryTransitDataProtectionCallCounts() currentCount = testCMACOperations(t, core, ctx, root, currentCount) // Verify that the transit counter matches expected count - require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, currentCount, core.GetInMemoryTransitDataProtectionCallCounts()) // Now test persisting the summed counts - store and retrieve counts // First, update the data protection call counts (this will sum current counter with stored value) summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) require.NotNil(t, summedCounts) - require.Equal(t, currentCount, summedCounts.Transit) + require.Equal(t, currentCount, summedCounts.Transit.Load()) // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Retrieve the stored counts storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) require.NotNil(t, storedCounts) - require.Equal(t, currentCount, storedCounts.Transit) + require.Equal(t, currentCount, storedCounts.Transit.Load()) // Perform more operations to increase the counter req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo") @@ -593,21 +590,21 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) // Counter should now be 1 (reset + 1 operation) - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts()) // Update counts again - should sum the new count (1) with the stored count (currentCount) summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) expectedSum := currentCount + 1 - require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should be sum of stored and current") // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Verify stored counts are now the sum storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit) + require.Equal(t, expectedSum, storedCounts.Transit.Load()) // Add more operations without manually resetting for i := 0; i < 3; i++ { @@ -619,35 +616,35 @@ func TestDataProtectionCallCounts(t *testing.T) { } // Counter should be 3 - require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts()) // Update counts - should sum 3 with the previous stored sum summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) expectedSum = expectedSum + 3 - require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should continue to sum") // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Verify stored counts storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit) + require.Equal(t, expectedSum, storedCounts.Transit.Load()) // Update again without any new operations // This verifies we don't double-count summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should remain the same when no new operations occurred") // Verify stored counts haven't changed storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same") + require.Equal(t, expectedSum, storedCounts.Transit.Load(), "Stored count should remain the same") // Verify counter is still at 0 - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0") } func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) { diff --git a/vault/core_util_common.go b/vault/core_util_common.go index d15f51e0f1..51608ef000 100644 --- a/vault/core_util_common.go +++ b/vault/core_util_common.go @@ -78,3 +78,7 @@ func (c *Core) setupHeaderHMACKey(ctx context.Context, isPerfStandby bool) error func (c *Core) GetPkiCertificateCounter() logical.CertificateCounter { return c.pkiCertCountManager } + +func (c *Core) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + return c.consumptionBilling +} diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 74530d9b09..405ad60585 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -401,3 +401,7 @@ func (d dynamicSystemView) DeregisterRotationJob(ctx context.Context, req *rotat return d.core.DeregisterRotationJob(nsCtx, req) } + +func (d dynamicSystemView) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + return d.core.GetConsumptionBillingManager() +}