From b3f173756d12ee9cb7df2b4feb3ecdfafac59f25 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Tue, 3 Feb 2026 17:39:49 -0500 Subject: [PATCH 1/3] actions: pin to latest actions (#12144) (#12146) Update to the latest actions. The primary motivation here is to get the latest action-setup-enos. - actions/cache => v5.0.3: security patches - actions/checkout => v6.0.2: small fixes to git user-agent and tag fetching - hashicorp/action-setup-enos => v1.50: security patches Signed-off-by: Ryan Cragun Co-authored-by: Ryan Cragun --- .github/actions/build-vault/action.yml | 2 +- .github/actions/create-dynamic-config/action.yml | 2 +- .github/actions/install-tools/action.yml | 2 +- .github/actions/set-up-go/action.yml | 2 +- .github/actions/set-up-pipeline/action.yml | 2 +- .github/workflows/build.yml | 2 +- .github/workflows/enos-lint.yml | 2 +- .github/workflows/test-enos-scenario-ui.yml | 2 +- .github/workflows/test-go.yml | 4 ++-- .github/workflows/test-run-enos-scenario-containers.yml | 4 ++-- .github/workflows/test-run-enos-scenario-matrix.yml | 4 ++-- .github/workflows/test-run-enos-scenario.yml | 2 +- 12 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/actions/build-vault/action.yml b/.github/actions/build-vault/action.yml index 7184931297..616631b282 100644 --- a/.github/actions/build-vault/action.yml +++ b/.github/actions/build-vault/action.yml @@ -69,7 +69,7 @@ runs: shell: bash run: git config --global url."https://${{ inputs.github-token }}:@github.com".insteadOf "https://github.com" - name: Restore UI from cache - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: # Restore the UI asset from the UI build workflow. Never use a partial restore key. enableCrossOsArchive: true diff --git a/.github/actions/create-dynamic-config/action.yml b/.github/actions/create-dynamic-config/action.yml index 7c71a915f1..8a7999f895 100644 --- a/.github/actions/create-dynamic-config/action.yml +++ b/.github/actions/create-dynamic-config/action.yml @@ -39,7 +39,7 @@ runs: } | tee -a "$GITHUB_ENV" - name: Try to restore dynamic config from cache id: dyn-cfg-cache - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: path: ${{ env.DYNAMIC_CONFIG_PATH }} key: dyn-cfg-${{ env.DYNAMIC_CONFIG_KEY }} diff --git a/.github/actions/install-tools/action.yml b/.github/actions/install-tools/action.yml index 0b786060a7..7274562fa2 100644 --- a/.github/actions/install-tools/action.yml +++ b/.github/actions/install-tools/action.yml @@ -69,7 +69,7 @@ runs: echo "VAULT_TOOLS_CACHE_KEY=${cache_key}" } | tee -a "$GITHUB_ENV" - id: cache-tools - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: lookup-only: ${{ inputs.no-restore }} path: ${{ env.VAULT_TOOLS_PATH }} diff --git a/.github/actions/set-up-go/action.yml b/.github/actions/set-up-go/action.yml index f0d2be66a6..f30bc3ea7c 100644 --- a/.github/actions/set-up-go/action.yml +++ b/.github/actions/set-up-go/action.yml @@ -63,7 +63,7 @@ runs: echo "cache-key=go-modules-${wd_hash}-${{ hashFiles('**/go.sum') }}" } | tee -a "$GITHUB_OUTPUT" - id: cache-modules - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: enableCrossOsArchive: true lookup-only: ${{ inputs.no-restore }} diff --git a/.github/actions/set-up-pipeline/action.yml b/.github/actions/set-up-pipeline/action.yml index 9ece8f1f61..2206fc7d75 100644 --- a/.github/actions/set-up-pipeline/action.yml +++ b/.github/actions/set-up-pipeline/action.yml @@ -33,7 +33,7 @@ runs: } | tee -a "$GITHUB_ENV" - name: Try to restore pipeline from cache id: pipeline-cache - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: path: ${{ env.PIPELINE_PATH }} key: pipeline-${{ env.PIPELINE_HASH }} diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 61db9c7a78..6808cbc484 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -303,7 +303,7 @@ jobs: run: echo "ui-hash=$(git ls-tree HEAD ui --object-only)" | tee -a "$GITHUB_OUTPUT" - name: Set up UI asset cache id: cache-ui-assets - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: enableCrossOsArchive: true lookup-only: true diff --git a/.github/workflows/enos-lint.yml b/.github/workflows/enos-lint.yml index ea44e07d0f..4dad11e45d 100644 --- a/.github/workflows/enos-lint.yml +++ b/.github/workflows/enos-lint.yml @@ -45,7 +45,7 @@ jobs: - uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2 with: terraform_wrapper: false - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 - name: Ensure shellcheck is available for linting run: which shellcheck || (sudo apt update && sudo apt install -y shellcheck) - name: lint diff --git a/.github/workflows/test-enos-scenario-ui.yml b/.github/workflows/test-enos-scenario-ui.yml index 53312b8c65..a43b20d1fb 100644 --- a/.github/workflows/test-enos-scenario-ui.yml +++ b/.github/workflows/test-enos-scenario-ui.yml @@ -82,7 +82,7 @@ jobs: - uses: ./.github/actions/set-up-go with: github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }} - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }} - name: Set Up Git diff --git a/.github/workflows/test-go.yml b/.github/workflows/test-go.yml index 687ba95c4f..4b51d22a10 100644 --- a/.github/workflows/test-go.yml +++ b/.github/workflows/test-go.yml @@ -145,7 +145,7 @@ jobs: - uses: ./.github/actions/install-tools # for gotestsum - run: mkdir -p ${{ steps.local-metadata.outputs.go-test-dir }} - if: inputs.test-timing-cache-restore || inputs.test-timing-cache-save - uses: actions/cache/restore@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: path: ${{ steps.local-metadata.outputs.go-test-dir }} key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }} @@ -647,7 +647,7 @@ jobs: } | tee -a "$GITHUB_OUTPUT" # Aggregate, prune, and cache our timing data - if: ${{ ! cancelled() && needs.test-go.result == 'success' && inputs.test-timing-cache-save }} - uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2 + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 with: path: ${{ needs.test-matrix.outputs.go-test-dir }} key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }} diff --git a/.github/workflows/test-run-enos-scenario-containers.yml b/.github/workflows/test-run-enos-scenario-containers.yml index 16d7f866af..4e8dd5fc21 100644 --- a/.github/workflows/test-run-enos-scenario-containers.yml +++ b/.github/workflows/test-run-enos-scenario-containers.yml @@ -44,7 +44,7 @@ jobs: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ inputs.vault-revision }} - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }} - uses: ./.github/actions/metadata @@ -87,7 +87,7 @@ jobs: # the Terraform wrapper will break Terraform execution in Enos because # it changes the output to text when we expect it to be JSON. terraform_wrapper: false - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }} - name: Download Docker Image diff --git a/.github/workflows/test-run-enos-scenario-matrix.yml b/.github/workflows/test-run-enos-scenario-matrix.yml index 47df34a0ef..f3ebc3517b 100644 --- a/.github/workflows/test-run-enos-scenario-matrix.yml +++ b/.github/workflows/test-run-enos-scenario-matrix.yml @@ -70,7 +70,7 @@ jobs: token: ${{ steps.vault-auth.outputs.token }} secrets: | kv/data/github/${{ github.repository }}/github-token token | ELEVATED_GITHUB_TOKEN; - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ github.repository == 'hashicorp/vault' && secrets.ELEVATED_GITHUB_TOKEN || steps.vault-secrets.outputs.ELEVATED_GITHUB_TOKEN }} - uses: ./.github/actions/create-dynamic-config @@ -214,7 +214,7 @@ jobs: role-to-assume: ${{ steps.secrets.outputs.aws-role-arn }} role-skip-session-tagging: true role-duration-seconds: 3600 - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ steps.secrets.outputs.github-token }} - uses: ./.github/actions/create-dynamic-config diff --git a/.github/workflows/test-run-enos-scenario.yml b/.github/workflows/test-run-enos-scenario.yml index efa70493a3..5d8949d5bc 100644 --- a/.github/workflows/test-run-enos-scenario.yml +++ b/.github/workflows/test-run-enos-scenario.yml @@ -91,7 +91,7 @@ jobs: role-to-assume: ${{ secrets.AWS_ROLE_ARN_CI }} role-skip-session-tagging: true role-duration-seconds: 3600 - - uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40 + - uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50 with: github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }} - name: Prepare scenario dependencies From caf642b7d21173ecd3b62137317551e7cf8941a2 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Tue, 3 Feb 2026 17:48:12 -0500 Subject: [PATCH 2/3] Backport Vault 42177 Add Backend Field into ce/main (#12152) * Vault 42177 Add Backend Field (#12092) * add a new struct for the total number of successful requests for transit and transform * implement tracking for encrypt path * implement tracking in encrypt path * add tracking in rewrap * add tracking to datakey path * add tracking to hmac path * add tracking to sign path * add tracking to verify path * unit tests for verify path * add tracking to cmac path * reset the global counter in each unit test * add tracking to hmac verify * add methods to retrieve and flush transit count * modify the methods that store and update data protection call counts * update the methods * add a helper method to combine replicated and local data call counts * add tracking to the endpoint * fix some formatting errors * add unit tests to path encrypt for tracking * add unit tests to decrypt path * fix linter error * add unit tests to test update and store methods for data protection calls * stub fix: do not create separate files * fix the tracking by coordinating replicated and local data, add unit tests * update all reference to the new data struct * revert to previous design with just one global counter for all calls for each cluster * complete external test * no need to check if current count is greater than 0, remove it * feedback: remove unnacassary comments about atomic addition, standardize comments * leave jira id on todo comment, remove unused method * rename mathods by removing HWM and max in names, update jira id in todo comment, update response field key name * feedback: remove explicit counter in cmac tests, instead put in the expected number * feedback: remove explicit tracking in the rest of the tests * feedback: separate transit testing into its own external test * Update vault/consumption_billing_util_test.go Co-authored-by: divyaac * update comment after test name change * fix comments * fix comments in test * another comment fix * feedback: remove incorrect comment * fix a CE test * fix the update method: instead of storing max, increment by the current count value * update the unit test, remove local prefix as argument to the methods since we store only to non-replicated paths * update the external test * Adds a field to backend to track billing data removed file * Changed implementation to use a map instead * Some more comments * Add more implementation * Edited grpc server backend * Refactored a bit * Fix one more test * Modified map: * Revert "Modified map:" This reverts commit 1730fe1f358b210e6abae43fbdca09e585aaaaa8. * Removed some other things * Edited consumption billing files a bit * Testing function * Fix transit stuff and make sure tests pass * Changes * More changes * More changes * Edited external test * Edited some more tests * Edited and fixed tests * One more fix * Fix some more tests * Moved some testing structures around and added error checking * Fixed some nits * Update builtin/logical/transit/path_sign_verify.go Co-authored-by: Nick Cabatoff * Edited some errors * Fixed error logs * Edited one more thing * Decorate the error * Update vault/consumption_billing.go Co-authored-by: Nick Cabatoff --------- Co-authored-by: Amir Aslamov Co-authored-by: Nick Cabatoff * Edited stub function --------- Co-authored-by: divyaac Co-authored-by: Amir Aslamov Co-authored-by: Nick Cabatoff Co-authored-by: divyaac --- builtin/logical/transit/backend.go | 20 +++-- builtin/logical/transit/backend_test.go | 14 ++++ builtin/logical/transit/path_datakey.go | 4 +- builtin/logical/transit/path_datakey_test.go | 11 +-- builtin/logical/transit/path_decrypt.go | 5 +- builtin/logical/transit/path_decrypt_test.go | 16 ++-- builtin/logical/transit/path_encrypt.go | 4 +- builtin/logical/transit/path_encrypt_test.go | 80 ++++--------------- builtin/logical/transit/path_hmac.go | 10 ++- builtin/logical/transit/path_hmac_test.go | 16 +--- builtin/logical/transit/path_rewrap.go | 5 +- builtin/logical/transit/path_rewrap_test.go | 21 +---- builtin/logical/transit/path_sign_verify.go | 10 ++- .../logical/transit/path_sign_verify_test.go | 20 +---- sdk/framework/backend.go | 10 ++- sdk/logical/billing_system_view.go | 25 ++++++ sdk/logical/system_view.go | 7 ++ sdk/plugin/grpc_system.go | 5 ++ vault/billing/billing_counts.go | 47 ++++++++--- vault/consumption_billing.go | 20 +++-- vault/consumption_billing_testing_util.go | 12 +++ vault/consumption_billing_util.go | 32 ++++---- vault/consumption_billing_util_oss_test.go | 2 +- vault/consumption_billing_util_test.go | 55 ++++++------- vault/core_util_common.go | 4 + vault/dynamic_system_view.go | 4 + 26 files changed, 249 insertions(+), 210 deletions(-) create mode 100644 sdk/logical/billing_system_view.go create mode 100644 vault/consumption_billing_testing_util.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 971ac7e2ca..662d0bf2b6 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -10,7 +10,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -113,7 +112,6 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) if err != nil { return nil, err } - b.setupEnt() return &b, nil @@ -124,6 +122,10 @@ type backend struct { entBackend lm *keysutil.LockManager + // billingDataCounts tracks successful data protection operations + // for this backend instance. It's intended for test assertions and avoids + // cross-test/package contamination from global counters. + billingDataCounts billing.DataProtectionCallCounts // Lock to make changes to any of the backend's cache configuration. configMutex sync.RWMutex cacheSizeChanged bool @@ -148,9 +150,17 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error return size, nil } -// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions -func (b *backend) incrementDataProtectionCounter(count int64) { - atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count) +// incrementBillingCounts atomically increments the transit billing data counts +func (b *backend) incrementBillingCounts(ctx context.Context, count uint64) error { + // If we are a test, we need to increment this testing structure to verify the counts are correct. + if b.billingDataCounts.Transit != nil { + b.billingDataCounts.Transit.Add(count) + } + + // Write billling data + return b.ConsumptionBillingManager.WriteBillingData(ctx, "transit", map[string]interface{}{ + "count": count, + }) } // Update cache size and get policy diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index f99df68ba1..1d6ffcd0b1 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -33,6 +34,7 @@ import ( "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -53,6 +55,9 @@ func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) { if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b, config.StorageView } @@ -74,6 +79,9 @@ func createBackendWithSysView(t testing.TB) (*backend, logical.Storage) { if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b, storage } @@ -95,6 +103,9 @@ func createBackendWithSysViewWithStorage(t testing.TB, s logical.Storage) *backe if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b } @@ -117,6 +128,9 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical if err != nil { t.Fatal(err) } + b.billingDataCounts = billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } return b } diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go index f8b851c11e..18b6db37ea 100644 --- a/builtin/logical/transit/path_datakey.go +++ b/builtin/logical/transit/path_datakey.go @@ -172,7 +172,9 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d // Increment the counter for successful operations // Since there are not batched operations, we can add one successful // request to the transit request counter. - b.incrementDataProtectionCounter(1) + if err = b.incrementBillingCounts(ctx, 1); err != nil { + b.Logger().Error("failed to track transit data key request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_datakey_test.go b/builtin/logical/transit/path_datakey_test.go index c8150afc69..2ed1c640d0 100644 --- a/builtin/logical/transit/path_datakey_test.go +++ b/builtin/logical/transit/path_datakey_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -16,9 +15,6 @@ import ( // TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme // args for the returned keys func TestDataKeyWithPaddingScheme(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -95,15 +91,12 @@ func TestDataKeyWithPaddingScheme(t *testing.T) { }) } // We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation)) - require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(8), b.billingDataCounts.Transit.Load()) } // TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a // padding_scheme value on an invalid key type (non-RSA) func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) keyName := "test" createKeyReq := &logical.Request{ @@ -132,5 +125,5 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) { require.NotNil(t, resp, "response should not be nil") require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key") // We expect 0 successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index f25275d352..08c20fc0a3 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -278,8 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit decrypt request count", "error", err.Error()) + } return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_decrypt_test.go b/builtin/logical/transit/path_decrypt_test.go index b2f8046f16..ecf0139198 100644 --- a/builtin/logical/transit/path_decrypt_test.go +++ b/builtin/logical/transit/path_decrypt_test.go @@ -12,7 +12,6 @@ import ( "github.com/hashicorp/vault/sdk/helper/jsonutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -23,9 +22,6 @@ func TestTransit_BatchDecryption(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchEncryptionInput := []interface{}{ map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline @@ -71,12 +67,15 @@ func TestTransit_BatchDecryption(t *testing.T) { expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]" jsonResponse, err := json.Marshal(batchDecryptionResponseItems) - if err != nil || err == nil && string(jsonResponse) != expectedResult { + if err != nil { + t.Fatalf("bad: failed to marshal response items: err=%v json=%s", err, jsonResponse) + } + if string(jsonResponse) != expectedResult { t.Fatalf("bad: expected json response [%s]", jsonResponse) } // We expect 6 successful requests (3 for batch encryption, 3 for batch decryption) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { @@ -86,9 +85,6 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create a derived key. req = &logical.Request{ Operation: logical.UpdateOperation, @@ -291,5 +287,5 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) { } // We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption) - require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(7), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 4ebaf2e2b1..f568bf64ec 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -650,7 +650,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d } // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit encrypt request count", "error", err.Error()) + } return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch) } diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index 05ab5e2fd3..526fe9956f 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -15,7 +15,6 @@ import ( uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" ) @@ -24,9 +23,6 @@ func TestTransit_MissingPlaintext(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -51,16 +47,13 @@ func TestTransit_MissingPlaintext(t *testing.T) { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } // We expect 0 successful calls - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -92,7 +85,7 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) } // We expect 0 successful calls - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case1: Ensure that batch encryption did not affect the normal flow of @@ -101,9 +94,6 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) // Create the policy @@ -160,7 +150,7 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for decrypt) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case2: Ensure that batch encryption did not affect the normal flow of @@ -170,9 +160,6 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -228,14 +215,12 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for decrypt) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case3: If batch encryption input is not base64 encoded, it should fail. func TestTransit_BatchEncryptionCase3(t *testing.T) { var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 b, s := createBackendWithStorage(t) @@ -256,7 +241,7 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) { } // We expect 0 successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case4: Test batch encryption with an existing key (and test references) @@ -264,9 +249,6 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) policyReq := &logical.Request{ @@ -331,7 +313,7 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) { } // We expect 4 successful requests (2 batch requests + 2 decrypt requests) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case5: Test batch encryption with an existing derived key @@ -339,9 +321,6 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) policyData := map[string]interface{}{ @@ -409,7 +388,7 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) { } } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case6: Test batch encryption with an upserted non-derived key @@ -417,9 +396,6 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchInput := []interface{}{ @@ -475,7 +451,7 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) { } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case7: Test batch encryption with an upserted derived key @@ -485,9 +461,6 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -536,7 +509,7 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) { } } // We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } // Case8: If plaintext is not base64 encoded, encryption should fail @@ -546,9 +519,6 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create the policy policyReq := &logical.Request{ Operation: logical.UpdateOperation, @@ -594,7 +564,7 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) { t.Fatal("expected an error") } // We expect 0 successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case9: If both plaintext and batch inputs are supplied, plaintext should be @@ -605,9 +575,6 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, @@ -634,7 +601,7 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) { } // We expect 2 successful batch encryptions - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Case10: Inconsistent presence of 'context' in batch input should be caught @@ -643,9 +610,6 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, @@ -666,7 +630,7 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) { t.Fatalf("expected an error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case11: Incorrect inputs for context and nonce should not fail the operation @@ -675,9 +639,6 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="}, map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"}, @@ -697,7 +658,7 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) { t.Fatal(err) } // We expect 1 successful encryption out of the 2-item batch - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(1), b.billingDataCounts.Transit.Load()) } // Case12: Invalid batch input @@ -705,9 +666,6 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - batchInput := []interface{}{ map[string]interface{}{}, "unexpected_interface", @@ -727,7 +685,7 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { t.Fatalf("expected an error") } // We expect no successful requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation @@ -736,9 +694,6 @@ func TestTransit_EncryptionCase13(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Non-batch first data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"} req := &logical.Request{ @@ -774,7 +729,7 @@ func TestTransit_EncryptionCase13(t *testing.T) { t.Fatal("expected request error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Case14: Incorrect input for nonce when we are in convergent version 3 should fail @@ -783,9 +738,6 @@ func TestTransit_EncryptionCase14(t *testing.T) { b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - cReq := &logical.Request{ Operation: logical.UpdateOperation, Path: "keys/my-key", @@ -836,7 +788,7 @@ func TestTransit_EncryptionCase14(t *testing.T) { t.Fatal("expected request error") } // We expect no successful transit requests - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } // Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem. diff --git a/builtin/logical/transit/path_hmac.go b/builtin/logical/transit/path_hmac.go index 316fa608f1..47776a3bb7 100644 --- a/builtin/logical/transit/path_hmac.go +++ b/builtin/logical/transit/path_hmac.go @@ -254,8 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit hmac request count", "error", err.Error()) + } return resp, nil } @@ -432,8 +433,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit hmac verify request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_hmac_test.go b/builtin/logical/transit/path_hmac_test.go index 13c1e5e556..29fba94488 100644 --- a/builtin/logical/transit/path_hmac_test.go +++ b/builtin/logical/transit/path_hmac_test.go @@ -16,14 +16,10 @@ import ( "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" - "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" ) func TestTransit_HMAC(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) cases := []struct { @@ -249,13 +245,10 @@ func TestTransit_HMAC(t *testing.T) { } } // Verify the total successful transit requests - require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(72), b.billingDataCounts.Transit.Load()) } func TestTransit_batchHMAC(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -411,7 +404,7 @@ func TestTransit_batchHMAC(t *testing.T) { } // Verify the total successful transit requests - require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(5), b.billingDataCounts.Transit.Load()) } // TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input @@ -440,9 +433,6 @@ func TestHMACBatchResultsFields(t *testing.T) { }) require.NoError(t, err) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - keyName := "hmac-test-key" _, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32}) require.NoError(t, err) @@ -505,5 +495,5 @@ func TestHMACBatchResultsFields(t *testing.T) { } // We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification) - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), cores[0].GetInMemoryTransitDataProtectionCallCounts()) } diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go index 1451aa768c..df0dbea991 100644 --- a/builtin/logical/transit/path_rewrap.go +++ b/builtin/logical/transit/path_rewrap.go @@ -308,8 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d * resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.") } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit rewrap request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_rewrap_test.go b/builtin/logical/transit/path_rewrap_test.go index a3734da6da..7c8ddbeab0 100644 --- a/builtin/logical/transit/path_rewrap_test.go +++ b/builtin/logical/transit/path_rewrap_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" ) @@ -19,9 +18,6 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -119,7 +115,7 @@ func TestTransit_BatchRewrapCase1(t *testing.T) { } // We expect 2 successful requests (1 for encrypt, 1 for rewrap) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Check the normal flow of rewrap with upserted key @@ -128,9 +124,6 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { var err error b, s := createBackendWithStorage(t) - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Upsert the key and encrypt the data plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA==" @@ -229,7 +222,7 @@ func TestTransit_BatchRewrapCase2(t *testing.T) { t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2) } // We expect 2 successful transit requests (1 for encrypt, 1 for rewrap) - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load()) } // Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts @@ -237,9 +230,6 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -341,7 +331,7 @@ func TestTransit_BatchRewrapCase3(t *testing.T) { } } // We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } // TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes @@ -349,9 +339,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { var resp *logical.Response var err error - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, s := createBackendWithStorage(t) batchEncryptionInput := []interface{}{ @@ -459,5 +446,5 @@ func TestTransit_BatchRewrapCase4(t *testing.T) { } } // We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap) - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load()) } diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index 114fc98209..4ebebec5b3 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -492,8 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit sign request count", "error", err.Error()) + } return resp, nil } @@ -750,8 +751,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } - // Increment the counter for successful operations - b.incrementDataProtectionCounter(int64(successfulRequests)) + if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil { + b.Logger().Error("failed to track transit sign verify request count", "error", err.Error()) + } return resp, nil } diff --git a/builtin/logical/transit/path_sign_verify_test.go b/builtin/logical/transit/path_sign_verify_test.go index 482f6627dc..c116a2cb3c 100644 --- a/builtin/logical/transit/path_sign_verify_test.go +++ b/builtin/logical/transit/path_sign_verify_test.go @@ -15,7 +15,6 @@ import ( "github.com/hashicorp/vault/helper/constants" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/vault/billing" "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" @@ -33,9 +32,6 @@ type signOutcome struct { } func TestTransit_SignVerify_ECDSA(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - t.Run("256", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 256) }) @@ -45,9 +41,6 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) { t.Run("521", func(t *testing.T) { testTransit_SignVerify_ECDSA(t, 521) }) - - // Verify the total successful transit requests - require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit) } func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { @@ -330,6 +323,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) { verifyRequest(req, false, "", sig) // Now try the v1 verifyRequest(req, true, "", v1sig) + require.Equal(t, uint64(28), b.billingDataCounts.Transit.Load()) } func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, expectValid bool, postpath string, b *backend) { @@ -380,9 +374,6 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex // TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a // Ed25519ph/ctx signature fail on CE and ENT if invalid func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -477,17 +468,14 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) { } // Verify the total successful transit requests if constants.IsEnterprise { - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load()) } else { // We expect 0 successful calls on CE because we expect the verify to fail - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load()) } } func TestTransit_SignVerify_ED25519(t *testing.T) { - // Reset the transit counter - billing.CurrentDataProtectionCallCounts.Transit = 0 - b, storage := createBackendWithSysView(t) // First create a key @@ -832,7 +820,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) { verifyRequest(req, false, outcome, "bar", goodsig, true) // Verify the total successful transit requests - require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(24), b.billingDataCounts.Transit.Load()) } func TestTransit_SignVerify_RSA_PSS(t *testing.T) { diff --git a/sdk/framework/backend.go b/sdk/framework/backend.go index e46ab8b75f..148a65da6a 100644 --- a/sdk/framework/backend.go +++ b/sdk/framework/backend.go @@ -118,6 +118,9 @@ type Backend struct { // communicate with a plugin to activate a feature. ActivationFunc func(context.Context, *logical.Request, string) error + // ConsumptionBillingManager is the consumption billing manager the backend can use to write billing data. + ConsumptionBillingManager logical.ConsumptionBillingManager + logger log.Logger system logical.SystemView events logical.EventSender @@ -440,6 +443,11 @@ func (b *Backend) Setup(ctx context.Context, config *logical.BackendConfig) erro b.system = config.System b.events = config.EventsSender b.observations = config.ObservationRecorder + if b.System() != nil && b.System().GetConsumptionBillingManager() != nil { + b.ConsumptionBillingManager = b.System().GetConsumptionBillingManager() + } else { + b.ConsumptionBillingManager = logical.NewNullConsumptionBillingManager() + } return nil } @@ -546,7 +554,7 @@ func (b *Backend) init() { for i, p := range b.Paths { // Detect the coding error of failing to initialise Pattern if len(p.Pattern) == 0 { - panic(fmt.Sprintf("Routing pattern cannot be blank")) + panic("Routing pattern cannot be blank") } // Detect the coding error of attempting to define a CreateOperation without defining an ExistenceCheck diff --git a/sdk/logical/billing_system_view.go b/sdk/logical/billing_system_view.go new file mode 100644 index 0000000000..caa7b83e8c --- /dev/null +++ b/sdk/logical/billing_system_view.go @@ -0,0 +1,25 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package logical + +import "context" + +// billing.ConsumptionBillingManager is an implementation of this interface that the backend can use to write billing data. +type ConsumptionBillingManager interface { + WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error +} + +// ================================ +// Creates a null consumption billing manager that does nothing +var _ ConsumptionBillingManager = (*nullConsumptionBillingManager)(nil) + +func NewNullConsumptionBillingManager() ConsumptionBillingManager { + return &nullConsumptionBillingManager{} +} + +type nullConsumptionBillingManager struct{} + +func (n *nullConsumptionBillingManager) WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error { + return nil +} diff --git a/sdk/logical/system_view.go b/sdk/logical/system_view.go index 89e71563a5..98562c3a78 100644 --- a/sdk/logical/system_view.go +++ b/sdk/logical/system_view.go @@ -117,6 +117,9 @@ type SystemView interface { // ExtractVerifyPlugin extracts and verifies the plugin artifact DownloadExtractVerifyPlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error + + // GetConsumptionBillingManager returns the consumption billing manager + GetConsumptionBillingManager() ConsumptionBillingManager } type PasswordPolicy interface { @@ -320,6 +323,10 @@ func (d StaticSystemView) DownloadExtractVerifyPlugin(_ context.Context, _ *plug return errors.New("DownloadExtractVerifyPlugin is not implemented in StaticSystemView") } +func (d StaticSystemView) GetConsumptionBillingManager() ConsumptionBillingManager { + return nil +} + // PluginLicenseUtil defines the functions needed to request License and PluginEnv // by the plugin licensing under github.com/hashicorp/vault-licensing // This only should be used by the plugin to get the license and plugin environment diff --git a/sdk/plugin/grpc_system.go b/sdk/plugin/grpc_system.go index 726f61262c..b45dfc1a2f 100644 --- a/sdk/plugin/grpc_system.go +++ b/sdk/plugin/grpc_system.go @@ -36,6 +36,11 @@ type gRPCSystemViewClient struct { client pb.SystemViewClient } +func (s *gRPCSystemViewClient) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + // Not implemented on pluginbackend + return nil +} + func (s *gRPCSystemViewClient) DefaultLeaseTTL() time.Duration { reply, err := s.client.DefaultLeaseTTL(context.Background(), &pb.Empty{}) if err != nil { diff --git a/vault/billing/billing_counts.go b/vault/billing/billing_counts.go index d92bfe4d2e..ced2785d6a 100644 --- a/vault/billing/billing_counts.go +++ b/vault/billing/billing_counts.go @@ -4,19 +4,24 @@ package billing import ( + "context" "fmt" "sync" + "sync/atomic" "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/logical" ) const ( - BillingSubPath = "billing/" - ReplicatedPrefix = "replicated/" - RoleHWMCountsHWM = "maxRoleCounts/" - KvHWMCountsHWM = "maxKvCounts/" - DataProtectionCallCountsMetric = "dataProtectionCallCounts/" - LocalPrefix = "local/" - ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" + BillingSubPath = "billing/" + ReplicatedPrefix = "replicated/" + RoleHWMCountsHWM = "maxRoleCounts/" + KvHWMCountsHWM = "maxKvCounts/" + TransitDataProtectionCallCountsPrefix = "transitDataProtectionCallCounts/" + LocalPrefix = "local/" + ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/" BillingWriteInterval = 10 * time.Minute ) @@ -27,7 +32,9 @@ type ConsumptionBilling struct { // BillingStorageLock controls access to the billing storage paths BillingStorageLock sync.RWMutex - BillingConfig BillingConfig + BillingConfig BillingConfig + DataProtectionCallCounts DataProtectionCallCounts + Logger log.Logger } type BillingConfig struct { @@ -45,10 +52,26 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri } type DataProtectionCallCounts struct { - Transit int64 `json:"transit,omitempty"` + Transit *atomic.Uint64 `json:"transit,omitempty"` // TODO: Uncomment when we add support for Transform tracking (VAULT-41205) - // Transform int64 `json:"transform,omitempty"` + // Transform atomic.int64 `json:"transform,omitempty"` } -// Global counter for all data protection calls on this cluster -var CurrentDataProtectionCallCounts = DataProtectionCallCounts{} +var _ logical.ConsumptionBillingManager = (*ConsumptionBilling)(nil) + +func (s *ConsumptionBilling) WriteBillingData(ctx context.Context, mountType string, data map[string]interface{}) error { + switch mountType { + case "transit": + val, ok := data["count"].(uint64) + if !ok { + err := fmt.Errorf("invalid value type for transit") + return err + } + + s.DataProtectionCallCounts.Transit.Add(val) + default: + err := fmt.Errorf("unknown metric type: %s", mountType) + return err + } + return nil +} diff --git a/vault/consumption_billing.go b/vault/consumption_billing.go index 4c7bdd6c53..ef1c28bdb4 100644 --- a/vault/consumption_billing.go +++ b/vault/consumption_billing.go @@ -5,6 +5,8 @@ package vault import ( "context" + "fmt" + "sync/atomic" "time" "github.com/hashicorp/vault/helper/timeutil" @@ -15,8 +17,14 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error { // We need replication (post unseal) to start before we run the consumption billing metrics worker // This is because there is primary/secondary cluster specific logic c.consumptionBillingLock.Lock() + logger := c.baseLogger.Named("billing") + c.AddLogger(logger) c.consumptionBilling = &billing.ConsumptionBilling{ BillingConfig: c.billingConfig, + DataProtectionCallCounts: billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + }, + Logger: logger, } c.consumptionBillingLock.Unlock() c.postUnsealFuncs = append(c.postUnsealFuncs, func() { @@ -72,7 +80,12 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error { c.UpdateReplicatedHWMMetrics(ctx, currentMonth) } c.UpdateLocalHWMMetrics(ctx, currentMonth) - c.UpdateLocalAggregatedMetrics(ctx, currentMonth) + if err := c.UpdateLocalAggregatedMetrics(ctx, currentMonth); err != nil { + c.logger.Error("error updating cluster data protection call counts", "error", err) + } else { + c.logger.Info("updated cluster data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) + } + } return nil } @@ -119,10 +132,7 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error { if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil { - c.logger.Error("error updating local max data protection call counts", "error", err) - } else { - c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth) + return fmt.Errorf("could not store transit data protection call counts: %w", err) } - return nil } diff --git a/vault/consumption_billing_testing_util.go b/vault/consumption_billing_testing_util.go new file mode 100644 index 0000000000..4b4b8beab3 --- /dev/null +++ b/vault/consumption_billing_testing_util.go @@ -0,0 +1,12 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: MPL-2.0 + +package vault + +func (c *Core) ResetInMemoryDataProtectionCallCounts() { + c.consumptionBilling.DataProtectionCallCounts.Transit.Store(0) +} + +func (c *Core) GetInMemoryTransitDataProtectionCallCounts() uint64 { + return c.consumptionBilling.DataProtectionCallCounts.Transit.Load() +} diff --git a/vault/consumption_billing_util.go b/vault/consumption_billing_util.go index e6656a3145..09312ccf84 100644 --- a/vault/consumption_billing_util.go +++ b/vault/consumption_billing_util.go @@ -267,29 +267,36 @@ func (c *Core) GetBillingSubView() *BarrierView { // storeDataProtectionCallCountsLocked must be called with BillingStorageLock held func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error { - billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) - entry, err := logical.StorageEntryJSON(billingPath, maxCounts) - if err != nil { - return err + // Store count for each data protection type separately because they are atomic counters + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix) + transitCount := maxCounts.Transit.Load() + entry := &logical.StorageEntry{ + Key: billingPath, + Value: []byte(strconv.FormatUint(transitCount, 10)), } return c.GetBillingSubView().Put(ctx, entry) } // getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) { - billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric) + // Retrieve count for each data protection type separately because they are atomic counters + ret := &billing.DataProtectionCallCounts{ + Transit: &atomic.Uint64{}, + } + billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix) entry, err := c.GetBillingSubView().Get(ctx, billingPath) if err != nil { return nil, err } if entry == nil { - return &billing.DataProtectionCallCounts{}, nil + return ret, nil } - var maxCounts billing.DataProtectionCallCounts - if err := entry.DecodeJSON(&maxCounts); err != nil { + transitCount, err := strconv.ParseUint(string(entry.Value), 10, 64) + if err != nil { return nil, err } - return &maxCounts, nil + ret.Transit.Store(transitCount) + return ret, nil } func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) { @@ -302,10 +309,6 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth c.consumptionBilling.BillingStorageLock.Lock() defer c.consumptionBilling.BillingStorageLock.Unlock() - // Read and atomically reset the current counter value - // TODO: Reset Transform call counts too (VAULT-41205) - currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0) - storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth) if err != nil { return nil, err @@ -315,7 +318,8 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth } // Sum the current count with the stored count - storedDataProtectionCallCounts.Transit += currentTransitCount + transitCount := c.consumptionBilling.DataProtectionCallCounts.Transit.Swap(0) + storedDataProtectionCallCounts.Transit.Add(transitCount) // TODO: Update Transform call counts (VAULT-41205) err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth) diff --git a/vault/consumption_billing_util_oss_test.go b/vault/consumption_billing_util_oss_test.go index effde53104..0b1a69ae45 100644 --- a/vault/consumption_billing_util_oss_test.go +++ b/vault/consumption_billing_util_oss_test.go @@ -42,7 +42,7 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) { // testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature. // Returns the current count unchanged. -func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 { +func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount uint64) uint64 { // CMAC is not supported in OSS, so we don't perform any operations return currentCount } diff --git a/vault/consumption_billing_util_test.go b/vault/consumption_billing_util_test.go index a438c1b0c2..fd0810bfd8 100644 --- a/vault/consumption_billing_util_test.go +++ b/vault/consumption_billing_util_test.go @@ -420,9 +420,6 @@ func TestDataProtectionCallCounts(t *testing.T) { _, err := core.HandleRequest(ctx, req) require.NoError(t, err) - // Reset the transit counters - billing.CurrentDataProtectionCallCounts.Transit = 0 - // Create an encryption key req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo") req.Data["type"] = "aes256-gcm96" @@ -439,8 +436,8 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp) require.NotNil(t, resp.Data) - // Verify that the transit counter is incremented (replicated mount by default) - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + // Verify that the transit counter is incremented + require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the ciphertext from the encryption response ciphertext, ok := resp.Data["ciphertext"].(string) @@ -455,7 +452,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) // Verify that the transit counter is incremented - require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(2), core.GetInMemoryTransitDataProtectionCallCounts()) // Test rewrap operation req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo") @@ -467,7 +464,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the new ciphertext from rewrap newCiphertext, ok := resp.Data["ciphertext"].(string) @@ -483,9 +480,9 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(4), core.GetInMemoryTransitDataProtectionCallCounts()) - // Test HMAC generation + // Test HMAC operation req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo") req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" req.ClientToken = root @@ -495,7 +492,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(5), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the HMAC value hmacValue, ok := resp.Data["hmac"].(string) @@ -513,7 +510,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(6), core.GetInMemoryTransitDataProtectionCallCounts()) // Verify the HMAC is valid hmacValid, ok := resp.Data["valid"].(bool) @@ -537,7 +534,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(7), core.GetInMemoryTransitDataProtectionCallCounts()) // Get the signature signature, ok := resp.Data["signature"].(string) @@ -555,7 +552,7 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NotNil(t, resp.Data) // Verify that the transit counter is incremented - require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(8), core.GetInMemoryTransitDataProtectionCallCounts()) // Verify the signature is valid signatureValid, ok := resp.Data["valid"].(bool) @@ -563,27 +560,27 @@ func TestDataProtectionCallCounts(t *testing.T) { require.True(t, signatureValid) // Test CMAC operations (ENT only - will be no-op in OSS) - currentCount := billing.CurrentDataProtectionCallCounts.Transit + currentCount := core.GetInMemoryTransitDataProtectionCallCounts() currentCount = testCMACOperations(t, core, ctx, root, currentCount) // Verify that the transit counter matches expected count - require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, currentCount, core.GetInMemoryTransitDataProtectionCallCounts()) // Now test persisting the summed counts - store and retrieve counts // First, update the data protection call counts (this will sum current counter with stored value) summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) require.NotNil(t, summedCounts) - require.Equal(t, currentCount, summedCounts.Transit) + require.Equal(t, currentCount, summedCounts.Transit.Load()) // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Retrieve the stored counts storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) require.NotNil(t, storedCounts) - require.Equal(t, currentCount, storedCounts.Transit) + require.Equal(t, currentCount, storedCounts.Transit.Load()) // Perform more operations to increase the counter req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo") @@ -593,21 +590,21 @@ func TestDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) // Counter should now be 1 (reset + 1 operation) - require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts()) // Update counts again - should sum the new count (1) with the stored count (currentCount) summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) expectedSum := currentCount + 1 - require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should be sum of stored and current") // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Verify stored counts are now the sum storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit) + require.Equal(t, expectedSum, storedCounts.Transit.Load()) // Add more operations without manually resetting for i := 0; i < 3; i++ { @@ -619,35 +616,35 @@ func TestDataProtectionCallCounts(t *testing.T) { } // Counter should be 3 - require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit) + require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts()) // Update counts - should sum 3 with the previous stored sum summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) expectedSum = expectedSum + 3 - require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should continue to sum") // Verify the counter was reset after update - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update") // Verify stored counts storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit) + require.Equal(t, expectedSum, storedCounts.Transit.Load()) // Update again without any new operations // This verifies we don't double-count summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred") + require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should remain the same when no new operations occurred") // Verify stored counts haven't changed storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now()) require.NoError(t, err) - require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same") + require.Equal(t, expectedSum, storedCounts.Transit.Load(), "Stored count should remain the same") // Verify counter is still at 0 - require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0") + require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0") } func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) { diff --git a/vault/core_util_common.go b/vault/core_util_common.go index d15f51e0f1..51608ef000 100644 --- a/vault/core_util_common.go +++ b/vault/core_util_common.go @@ -78,3 +78,7 @@ func (c *Core) setupHeaderHMACKey(ctx context.Context, isPerfStandby bool) error func (c *Core) GetPkiCertificateCounter() logical.CertificateCounter { return c.pkiCertCountManager } + +func (c *Core) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + return c.consumptionBilling +} diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 74530d9b09..405ad60585 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -401,3 +401,7 @@ func (d dynamicSystemView) DeregisterRotationJob(ctx context.Context, req *rotat return d.core.DeregisterRotationJob(nsCtx, req) } + +func (d dynamicSystemView) GetConsumptionBillingManager() logical.ConsumptionBillingManager { + return d.core.GetConsumptionBillingManager() +} From 4e78a0bfc51a4aad3d7586aa4ca52731028522bd Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Tue, 3 Feb 2026 17:50:14 -0500 Subject: [PATCH 3/3] UI: Prioritize direct link when multiple mounts are visible (#12142) (#12156) * override auth form with direct link * add changelog Co-authored-by: claire bontempo <68122737+hellobontempo@users.noreply.github.com> --- changelog/_12142.txt | 3 ++ ui/app/components/auth/page.ts | 28 +++++++++---------- ui/app/routes/vault/cluster/auth.js | 1 - .../auth/page/listing-visibility-test.js | 14 ++++++++++ 4 files changed, 30 insertions(+), 16 deletions(-) create mode 100644 changelog/_12142.txt diff --git a/changelog/_12142.txt b/changelog/_12142.txt new file mode 100644 index 0000000000..161cb1e7b6 --- /dev/null +++ b/changelog/_12142.txt @@ -0,0 +1,3 @@ +```release-note:bug +ui: Fixes login form so `?with=` query param correctly displays only the specified mount when multiple mounts of the same auth type are configured with `listing_visibility="unauth"` +``` \ No newline at end of file diff --git a/ui/app/components/auth/page.ts b/ui/app/components/auth/page.ts index a3ce5d58c4..e81f2b791d 100644 --- a/ui/app/components/auth/page.ts +++ b/ui/app/components/auth/page.ts @@ -58,17 +58,6 @@ import type { Task } from 'ember-concurrency'; * 🔀 Multiple visible mounts: * ▸ Path dropdown is shown. * - * @example - * * * @param {object} cluster - the ember data cluster model. contains information such as cluster id, name and boolean for if the cluster is in standby * @param {object} directLinkData - mount data built from the "with" query param. If param is a mount path and maps to a visible mount, the login form defaults to this mount. Otherwise the form preselects the passed auth type. @@ -161,10 +150,14 @@ export default class AuthPage extends Component { get directLinkViews() { const { directLinkData } = this.args; - // If "path" key exists we know the "with" query param references a mount with listing_visibility="unauth" - // Treat it as a preferred method and hide all other tabs. + // If "path" key exists then the "with" query param references a specific mount with listing_visibility="unauth". + // Show only this mount and hide any others for this auth type as well as any tabs for different auth types. if (directLinkData?.path) { - const tabData = this.filterVisibleMountsByType([directLinkData.type]); + const mounts = this.mountsByType(directLinkData.type); + const selectedMount = mounts?.find((m) => m.path === directLinkData?.path); + const tabData: UnauthMountsByType = { + [directLinkData.type]: selectedMount ? [selectedMount] : null, + }; const defaultView = this.constructViews(FormView.TABS, tabData); const alternateView = this.constructViews(FormView.DROPDOWN, null); @@ -263,11 +256,16 @@ export default class AuthPage extends Component { const tabs: UnauthMountsByType = {}; for (const type of authTypes) { // adds visible mounts for each type, if they exist - tabs[type] = this.visibleMountsByType?.[type] || null; + tabs[type] = this.mountsByType(type); } return tabs; } + private mountsByType(type: string) { + // Return null and not an empty array to distinguish between "dropdown mode" and "tabs with no mounts" in downstream components + return this.visibleMountsByType?.[type] || null; + } + private constructViews(view: FormView, tabData: UnauthMountsByType | null) { return { view, tabData }; } diff --git a/ui/app/routes/vault/cluster/auth.js b/ui/app/routes/vault/cluster/auth.js index 8e0e0ddf7d..2d1d3839dc 100644 --- a/ui/app/routes/vault/cluster/auth.js +++ b/ui/app/routes/vault/cluster/auth.js @@ -102,7 +102,6 @@ export default class AuthRoute extends Route { const { default_auth_type, backup_auth_types } = response.data; return { defaultType: default_auth_type, - // TODO WIP backend PR consistently return empty array when no backup_auth_types backupTypes: backup_auth_types?.length ? backup_auth_types : null, }; } diff --git a/ui/tests/integration/components/auth/page/listing-visibility-test.js b/ui/tests/integration/components/auth/page/listing-visibility-test.js index 10a005e027..54eb768ba9 100644 --- a/ui/tests/integration/components/auth/page/listing-visibility-test.js +++ b/ui/tests/integration/components/auth/page/listing-visibility-test.js @@ -132,6 +132,20 @@ module('Integration | Component | auth | page | listing visibility', function (h assert.dom(GENERAL.backButton).doesNotExist(); }); + test('it treats direct link as only mount when multiple mounts are tuned with listing_visibility="unauth"', async function (assert) { + this.directLinkData = { path: 'userpass2/', type: 'userpass' }; + await this.renderComponent(); + assert.dom(AUTH_FORM.authForm('userpass')).exists; + assert.dom(AUTH_FORM.tabBtn('userpass')).hasText('Userpass', 'it renders tab for type'); + assert.dom(AUTH_FORM.tabs).exists({ count: 1 }, 'only one tab renders'); + assert.dom(GENERAL.inputByAttr('path')).hasAttribute('type', 'hidden'); + assert.dom(GENERAL.inputByAttr('path')).hasValue('userpass2/'); + assert.dom(GENERAL.button('Sign in with other methods')).exists('"Sign in with other methods" renders'); + assert.dom(GENERAL.selectByAttr('auth type')).doesNotExist(); + assert.dom(AUTH_FORM.advancedSettings).doesNotExist(); + assert.dom(GENERAL.backButton).doesNotExist(); + }); + test('it prioritizes auth type from canceled mfa instead of direct link for path', async function (assert) { assert.expect(1); this.directLinkData = this.directLinkIsVisibleMount; // type is "oidc"