mirror of
https://github.com/hashicorp/vault.git
synced 2026-02-03 20:40:45 -05:00
Merge remote-tracking branch 'remotes/from/ce/main'
Some checks are pending
build / setup (push) Waiting to run
build / Check ce/* Pull Requests (push) Blocked by required conditions
build / ui (push) Blocked by required conditions
build / artifacts-ce (push) Blocked by required conditions
build / artifacts-ent (push) Blocked by required conditions
build / hcp-image (push) Blocked by required conditions
build / test (push) Blocked by required conditions
build / test-hcp-image (push) Blocked by required conditions
build / completed-successfully (push) Blocked by required conditions
CI / setup (push) Waiting to run
CI / Run Autopilot upgrade tool (push) Blocked by required conditions
CI / Run Go tests (push) Blocked by required conditions
CI / Run Go tests tagged with testonly (push) Blocked by required conditions
CI / Run Go tests with data race detection (push) Blocked by required conditions
CI / Run Go tests with FIPS configuration (push) Blocked by required conditions
CI / Test UI (push) Blocked by required conditions
CI / tests-completed (push) Blocked by required conditions
Run linters / Setup (push) Waiting to run
Run linters / Deprecated functions (push) Blocked by required conditions
Run linters / Code checks (push) Blocked by required conditions
Run linters / Protobuf generate delta (push) Blocked by required conditions
Run linters / Format (push) Blocked by required conditions
Run linters / Semgrep (push) Waiting to run
Check Copywrite Headers / copywrite (push) Waiting to run
Security Scan / scan (push) Waiting to run
Some checks are pending
build / setup (push) Waiting to run
build / Check ce/* Pull Requests (push) Blocked by required conditions
build / ui (push) Blocked by required conditions
build / artifacts-ce (push) Blocked by required conditions
build / artifacts-ent (push) Blocked by required conditions
build / hcp-image (push) Blocked by required conditions
build / test (push) Blocked by required conditions
build / test-hcp-image (push) Blocked by required conditions
build / completed-successfully (push) Blocked by required conditions
CI / setup (push) Waiting to run
CI / Run Autopilot upgrade tool (push) Blocked by required conditions
CI / Run Go tests (push) Blocked by required conditions
CI / Run Go tests tagged with testonly (push) Blocked by required conditions
CI / Run Go tests with data race detection (push) Blocked by required conditions
CI / Run Go tests with FIPS configuration (push) Blocked by required conditions
CI / Test UI (push) Blocked by required conditions
CI / tests-completed (push) Blocked by required conditions
Run linters / Setup (push) Waiting to run
Run linters / Deprecated functions (push) Blocked by required conditions
Run linters / Code checks (push) Blocked by required conditions
Run linters / Protobuf generate delta (push) Blocked by required conditions
Run linters / Format (push) Blocked by required conditions
Run linters / Semgrep (push) Waiting to run
Check Copywrite Headers / copywrite (push) Waiting to run
Security Scan / scan (push) Waiting to run
This commit is contained in:
commit
0c61ce5ac5
42 changed files with 294 additions and 241 deletions
2
.github/actions/build-vault/action.yml
vendored
2
.github/actions/build-vault/action.yml
vendored
|
|
@ -69,7 +69,7 @@ runs:
|
||||||
shell: bash
|
shell: bash
|
||||||
run: git config --global url."https://${{ inputs.github-token }}:@github.com".insteadOf "https://github.com"
|
run: git config --global url."https://${{ inputs.github-token }}:@github.com".insteadOf "https://github.com"
|
||||||
- name: Restore UI from cache
|
- name: Restore UI from cache
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
# Restore the UI asset from the UI build workflow. Never use a partial restore key.
|
# Restore the UI asset from the UI build workflow. Never use a partial restore key.
|
||||||
enableCrossOsArchive: true
|
enableCrossOsArchive: true
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ runs:
|
||||||
} | tee -a "$GITHUB_ENV"
|
} | tee -a "$GITHUB_ENV"
|
||||||
- name: Try to restore dynamic config from cache
|
- name: Try to restore dynamic config from cache
|
||||||
id: dyn-cfg-cache
|
id: dyn-cfg-cache
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
path: ${{ env.DYNAMIC_CONFIG_PATH }}
|
path: ${{ env.DYNAMIC_CONFIG_PATH }}
|
||||||
key: dyn-cfg-${{ env.DYNAMIC_CONFIG_KEY }}
|
key: dyn-cfg-${{ env.DYNAMIC_CONFIG_KEY }}
|
||||||
|
|
|
||||||
2
.github/actions/install-tools/action.yml
vendored
2
.github/actions/install-tools/action.yml
vendored
|
|
@ -69,7 +69,7 @@ runs:
|
||||||
echo "VAULT_TOOLS_CACHE_KEY=${cache_key}"
|
echo "VAULT_TOOLS_CACHE_KEY=${cache_key}"
|
||||||
} | tee -a "$GITHUB_ENV"
|
} | tee -a "$GITHUB_ENV"
|
||||||
- id: cache-tools
|
- id: cache-tools
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
lookup-only: ${{ inputs.no-restore }}
|
lookup-only: ${{ inputs.no-restore }}
|
||||||
path: ${{ env.VAULT_TOOLS_PATH }}
|
path: ${{ env.VAULT_TOOLS_PATH }}
|
||||||
|
|
|
||||||
2
.github/actions/set-up-go/action.yml
vendored
2
.github/actions/set-up-go/action.yml
vendored
|
|
@ -63,7 +63,7 @@ runs:
|
||||||
echo "cache-key=go-modules-${wd_hash}-${{ hashFiles('**/go.sum') }}"
|
echo "cache-key=go-modules-${wd_hash}-${{ hashFiles('**/go.sum') }}"
|
||||||
} | tee -a "$GITHUB_OUTPUT"
|
} | tee -a "$GITHUB_OUTPUT"
|
||||||
- id: cache-modules
|
- id: cache-modules
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
enableCrossOsArchive: true
|
enableCrossOsArchive: true
|
||||||
lookup-only: ${{ inputs.no-restore }}
|
lookup-only: ${{ inputs.no-restore }}
|
||||||
|
|
|
||||||
2
.github/actions/set-up-pipeline/action.yml
vendored
2
.github/actions/set-up-pipeline/action.yml
vendored
|
|
@ -33,7 +33,7 @@ runs:
|
||||||
} | tee -a "$GITHUB_ENV"
|
} | tee -a "$GITHUB_ENV"
|
||||||
- name: Try to restore pipeline from cache
|
- name: Try to restore pipeline from cache
|
||||||
id: pipeline-cache
|
id: pipeline-cache
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
path: ${{ env.PIPELINE_PATH }}
|
path: ${{ env.PIPELINE_PATH }}
|
||||||
key: pipeline-${{ env.PIPELINE_HASH }}
|
key: pipeline-${{ env.PIPELINE_HASH }}
|
||||||
|
|
|
||||||
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
|
|
@ -303,7 +303,7 @@ jobs:
|
||||||
run: echo "ui-hash=$(git ls-tree HEAD ui --object-only)" | tee -a "$GITHUB_OUTPUT"
|
run: echo "ui-hash=$(git ls-tree HEAD ui --object-only)" | tee -a "$GITHUB_OUTPUT"
|
||||||
- name: Set up UI asset cache
|
- name: Set up UI asset cache
|
||||||
id: cache-ui-assets
|
id: cache-ui-assets
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
enableCrossOsArchive: true
|
enableCrossOsArchive: true
|
||||||
lookup-only: true
|
lookup-only: true
|
||||||
|
|
|
||||||
2
.github/workflows/enos-lint.yml
vendored
2
.github/workflows/enos-lint.yml
vendored
|
|
@ -45,7 +45,7 @@ jobs:
|
||||||
- uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
|
- uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
|
||||||
with:
|
with:
|
||||||
terraform_wrapper: false
|
terraform_wrapper: false
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
- name: Ensure shellcheck is available for linting
|
- name: Ensure shellcheck is available for linting
|
||||||
run: which shellcheck || (sudo apt update && sudo apt install -y shellcheck)
|
run: which shellcheck || (sudo apt update && sudo apt install -y shellcheck)
|
||||||
- name: lint
|
- name: lint
|
||||||
|
|
|
||||||
2
.github/workflows/test-enos-scenario-ui.yml
vendored
2
.github/workflows/test-enos-scenario-ui.yml
vendored
|
|
@ -82,7 +82,7 @@ jobs:
|
||||||
- uses: ./.github/actions/set-up-go
|
- uses: ./.github/actions/set-up-go
|
||||||
with:
|
with:
|
||||||
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
||||||
- name: Set Up Git
|
- name: Set Up Git
|
||||||
|
|
|
||||||
4
.github/workflows/test-go.yml
vendored
4
.github/workflows/test-go.yml
vendored
|
|
@ -145,7 +145,7 @@ jobs:
|
||||||
- uses: ./.github/actions/install-tools # for gotestsum
|
- uses: ./.github/actions/install-tools # for gotestsum
|
||||||
- run: mkdir -p ${{ steps.local-metadata.outputs.go-test-dir }}
|
- run: mkdir -p ${{ steps.local-metadata.outputs.go-test-dir }}
|
||||||
- if: inputs.test-timing-cache-restore || inputs.test-timing-cache-save
|
- if: inputs.test-timing-cache-restore || inputs.test-timing-cache-save
|
||||||
uses: actions/cache/restore@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
path: ${{ steps.local-metadata.outputs.go-test-dir }}
|
path: ${{ steps.local-metadata.outputs.go-test-dir }}
|
||||||
key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }}
|
key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }}
|
||||||
|
|
@ -647,7 +647,7 @@ jobs:
|
||||||
} | tee -a "$GITHUB_OUTPUT"
|
} | tee -a "$GITHUB_OUTPUT"
|
||||||
# Aggregate, prune, and cache our timing data
|
# Aggregate, prune, and cache our timing data
|
||||||
- if: ${{ ! cancelled() && needs.test-go.result == 'success' && inputs.test-timing-cache-save }}
|
- if: ${{ ! cancelled() && needs.test-go.result == 'success' && inputs.test-timing-cache-save }}
|
||||||
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5.0.2
|
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
|
||||||
with:
|
with:
|
||||||
path: ${{ needs.test-matrix.outputs.go-test-dir }}
|
path: ${{ needs.test-matrix.outputs.go-test-dir }}
|
||||||
key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }}
|
key: ${{ inputs.test-timing-cache-key }}-${{ github.run_number }}
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ jobs:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.vault-revision }}
|
ref: ${{ inputs.vault-revision }}
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
||||||
- uses: ./.github/actions/metadata
|
- uses: ./.github/actions/metadata
|
||||||
|
|
@ -87,7 +87,7 @@ jobs:
|
||||||
# the Terraform wrapper will break Terraform execution in Enos because
|
# the Terraform wrapper will break Terraform execution in Enos because
|
||||||
# it changes the output to text when we expect it to be JSON.
|
# it changes the output to text when we expect it to be JSON.
|
||||||
terraform_wrapper: false
|
terraform_wrapper: false
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
||||||
- name: Download Docker Image
|
- name: Download Docker Image
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ jobs:
|
||||||
token: ${{ steps.vault-auth.outputs.token }}
|
token: ${{ steps.vault-auth.outputs.token }}
|
||||||
secrets: |
|
secrets: |
|
||||||
kv/data/github/${{ github.repository }}/github-token token | ELEVATED_GITHUB_TOKEN;
|
kv/data/github/${{ github.repository }}/github-token token | ELEVATED_GITHUB_TOKEN;
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ github.repository == 'hashicorp/vault' && secrets.ELEVATED_GITHUB_TOKEN || steps.vault-secrets.outputs.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ github.repository == 'hashicorp/vault' && secrets.ELEVATED_GITHUB_TOKEN || steps.vault-secrets.outputs.ELEVATED_GITHUB_TOKEN }}
|
||||||
- uses: ./.github/actions/create-dynamic-config
|
- uses: ./.github/actions/create-dynamic-config
|
||||||
|
|
@ -214,7 +214,7 @@ jobs:
|
||||||
role-to-assume: ${{ steps.secrets.outputs.aws-role-arn }}
|
role-to-assume: ${{ steps.secrets.outputs.aws-role-arn }}
|
||||||
role-skip-session-tagging: true
|
role-skip-session-tagging: true
|
||||||
role-duration-seconds: 3600
|
role-duration-seconds: 3600
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ steps.secrets.outputs.github-token }}
|
github-token: ${{ steps.secrets.outputs.github-token }}
|
||||||
- uses: ./.github/actions/create-dynamic-config
|
- uses: ./.github/actions/create-dynamic-config
|
||||||
|
|
|
||||||
2
.github/workflows/test-run-enos-scenario.yml
vendored
2
.github/workflows/test-run-enos-scenario.yml
vendored
|
|
@ -91,7 +91,7 @@ jobs:
|
||||||
role-to-assume: ${{ secrets.AWS_ROLE_ARN_CI }}
|
role-to-assume: ${{ secrets.AWS_ROLE_ARN_CI }}
|
||||||
role-skip-session-tagging: true
|
role-skip-session-tagging: true
|
||||||
role-duration-seconds: 3600
|
role-duration-seconds: 3600
|
||||||
- uses: hashicorp/action-setup-enos@80a17fa25605989a7a53199137dae1244e32353f # v1.40
|
- uses: hashicorp/action-setup-enos@17b90fcf9591275b468a94aefb9dc6a93017de8a # v1.50
|
||||||
with:
|
with:
|
||||||
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
github-token: ${{ secrets.ELEVATED_GITHUB_TOKEN }}
|
||||||
- name: Prepare scenario dependencies
|
- name: Prepare scenario dependencies
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
|
@ -113,7 +112,6 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.setupEnt()
|
b.setupEnt()
|
||||||
|
|
||||||
return &b, nil
|
return &b, nil
|
||||||
|
|
@ -124,6 +122,10 @@ type backend struct {
|
||||||
entBackend
|
entBackend
|
||||||
|
|
||||||
lm *keysutil.LockManager
|
lm *keysutil.LockManager
|
||||||
|
// billingDataCounts tracks successful data protection operations
|
||||||
|
// for this backend instance. It's intended for test assertions and avoids
|
||||||
|
// cross-test/package contamination from global counters.
|
||||||
|
billingDataCounts billing.DataProtectionCallCounts
|
||||||
// Lock to make changes to any of the backend's cache configuration.
|
// Lock to make changes to any of the backend's cache configuration.
|
||||||
configMutex sync.RWMutex
|
configMutex sync.RWMutex
|
||||||
cacheSizeChanged bool
|
cacheSizeChanged bool
|
||||||
|
|
@ -148,9 +150,17 @@ func GetCacheSizeFromStorage(ctx context.Context, s logical.Storage) (int, error
|
||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// incrementDataProtectionCounter atomically increments the data protection call counter to avoid race conditions
|
// incrementBillingCounts atomically increments the transit billing data counts
|
||||||
func (b *backend) incrementDataProtectionCounter(count int64) {
|
func (b *backend) incrementBillingCounts(ctx context.Context, count uint64) error {
|
||||||
atomic.AddInt64(&billing.CurrentDataProtectionCallCounts.Transit, count)
|
// If we are a test, we need to increment this testing structure to verify the counts are correct.
|
||||||
|
if b.billingDataCounts.Transit != nil {
|
||||||
|
b.billingDataCounts.Transit.Add(count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write billling data
|
||||||
|
return b.ConsumptionBillingManager.WriteBillingData(ctx, "transit", map[string]interface{}{
|
||||||
|
"count": count,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update cache size and get policy
|
// Update cache size and get policy
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -33,6 +34,7 @@ import (
|
||||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
|
"github.com/hashicorp/vault/vault/billing"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
@ -53,6 +55,9 @@ func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
b.billingDataCounts = billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
}
|
||||||
return b, config.StorageView
|
return b, config.StorageView
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -74,6 +79,9 @@ func createBackendWithSysView(t testing.TB) (*backend, logical.Storage) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
b.billingDataCounts = billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
|
||||||
return b, storage
|
return b, storage
|
||||||
}
|
}
|
||||||
|
|
@ -95,6 +103,9 @@ func createBackendWithSysViewWithStorage(t testing.TB, s logical.Storage) *backe
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
b.billingDataCounts = billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
@ -117,6 +128,9 @@ func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
b.billingDataCounts = billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -172,7 +172,9 @@ func (b *backend) pathDatakeyWrite(ctx context.Context, req *logical.Request, d
|
||||||
// Increment the counter for successful operations
|
// Increment the counter for successful operations
|
||||||
// Since there are not batched operations, we can add one successful
|
// Since there are not batched operations, we can add one successful
|
||||||
// request to the transit request counter.
|
// request to the transit request counter.
|
||||||
b.incrementDataProtectionCounter(1)
|
if err = b.incrementBillingCounts(ctx, 1); err != nil {
|
||||||
|
b.Logger().Error("failed to track transit data key request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
@ -16,9 +15,6 @@ import (
|
||||||
// TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme
|
// TestDataKeyWithPaddingScheme validates that we properly leverage padding scheme
|
||||||
// args for the returned keys
|
// args for the returned keys
|
||||||
func TestDataKeyWithPaddingScheme(t *testing.T) {
|
func TestDataKeyWithPaddingScheme(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
keyName := "test"
|
keyName := "test"
|
||||||
createKeyReq := &logical.Request{
|
createKeyReq := &logical.Request{
|
||||||
|
|
@ -95,15 +91,12 @@ func TestDataKeyWithPaddingScheme(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation))
|
// We expect 8 successful requests ((3 valid cases x 2 operations) + (2 invalid cases x 1 operation))
|
||||||
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(8), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a
|
// TestDataKeyWithPaddingSchemeInvalidKeyType validates we fail when we specify a
|
||||||
// padding_scheme value on an invalid key type (non-RSA)
|
// padding_scheme value on an invalid key type (non-RSA)
|
||||||
func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
|
func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
keyName := "test"
|
keyName := "test"
|
||||||
createKeyReq := &logical.Request{
|
createKeyReq := &logical.Request{
|
||||||
|
|
@ -132,5 +125,5 @@ func TestDataKeyWithPaddingSchemeInvalidKeyType(t *testing.T) {
|
||||||
require.NotNil(t, resp, "response should not be nil")
|
require.NotNil(t, resp, "response should not be nil")
|
||||||
require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key")
|
require.Contains(t, resp.Error().Error(), "padding_scheme argument invalid: unsupported key")
|
||||||
// We expect 0 successful requests
|
// We expect 0 successful requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -278,8 +278,9 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit decrypt request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
"github.com/hashicorp/vault/sdk/helper/jsonutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
@ -23,9 +22,6 @@ func TestTransit_BatchDecryption(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchEncryptionInput := []interface{}{
|
batchEncryptionInput := []interface{}{
|
||||||
map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string
|
map[string]interface{}{"plaintext": "", "reference": "foo"}, // empty string
|
||||||
map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline
|
map[string]interface{}{"plaintext": "Cg==", "reference": "bar"}, // newline
|
||||||
|
|
@ -71,12 +67,15 @@ func TestTransit_BatchDecryption(t *testing.T) {
|
||||||
expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]"
|
expectedResult := "[{\"plaintext\":\"\",\"reference\":\"foo\"},{\"plaintext\":\"Cg==\",\"reference\":\"bar\"},{\"plaintext\":\"dGhlIHF1aWNrIGJyb3duIGZveA==\",\"reference\":\"baz\"}]"
|
||||||
|
|
||||||
jsonResponse, err := json.Marshal(batchDecryptionResponseItems)
|
jsonResponse, err := json.Marshal(batchDecryptionResponseItems)
|
||||||
if err != nil || err == nil && string(jsonResponse) != expectedResult {
|
if err != nil {
|
||||||
|
t.Fatalf("bad: failed to marshal response items: err=%v json=%s", err, jsonResponse)
|
||||||
|
}
|
||||||
|
if string(jsonResponse) != expectedResult {
|
||||||
t.Fatalf("bad: expected json response [%s]", jsonResponse)
|
t.Fatalf("bad: expected json response [%s]", jsonResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 6 successful requests (3 for batch encryption, 3 for batch decryption)
|
// We expect 6 successful requests (3 for batch encryption, 3 for batch decryption)
|
||||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
||||||
|
|
@ -86,9 +85,6 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Create a derived key.
|
// Create a derived key.
|
||||||
req = &logical.Request{
|
req = &logical.Request{
|
||||||
Operation: logical.UpdateOperation,
|
Operation: logical.UpdateOperation,
|
||||||
|
|
@ -291,5 +287,5 @@ func TestTransit_BatchDecryption_DerivedKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption)
|
// We expect 7 successful requests (2 for batch encryption + 1 single-item decryption + 2 batch decryption + 2 batch decryption)
|
||||||
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(7), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -650,7 +650,9 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
// Increment the counter for successful operations
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
|
b.Logger().Error("failed to track transit encrypt request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
return batchRequestResponse(d, resp, req, successesInBatch, userErrorInBatch, internalErrorInBatch)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import (
|
||||||
uuid "github.com/hashicorp/go-uuid"
|
uuid "github.com/hashicorp/go-uuid"
|
||||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
@ -24,9 +23,6 @@ func TestTransit_MissingPlaintext(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Create the policy
|
// Create the policy
|
||||||
|
|
@ -51,16 +47,13 @@ func TestTransit_MissingPlaintext(t *testing.T) {
|
||||||
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
||||||
}
|
}
|
||||||
// We expect 0 successful calls
|
// We expect 0 successful calls
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
|
func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Create the policy
|
// Create the policy
|
||||||
|
|
@ -92,7 +85,7 @@ func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
|
||||||
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
|
||||||
}
|
}
|
||||||
// We expect 0 successful calls
|
// We expect 0 successful calls
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case1: Ensure that batch encryption did not affect the normal flow of
|
// Case1: Ensure that batch encryption did not affect the normal flow of
|
||||||
|
|
@ -101,9 +94,6 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Create the policy
|
// Create the policy
|
||||||
|
|
@ -160,7 +150,7 @@ func TestTransit_BatchEncryptionCase1(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case2: Ensure that batch encryption did not affect the normal flow of
|
// Case2: Ensure that batch encryption did not affect the normal flow of
|
||||||
|
|
@ -170,9 +160,6 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Upsert the key and encrypt the data
|
// Upsert the key and encrypt the data
|
||||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||||
|
|
||||||
|
|
@ -228,14 +215,12 @@ func TestTransit_BatchEncryptionCase2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
// We expect 2 successful requests (1 for encrypt, 1 for decrypt)
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case3: If batch encryption input is not base64 encoded, it should fail.
|
// Case3: If batch encryption input is not base64 encoded, it should fail.
|
||||||
func TestTransit_BatchEncryptionCase3(t *testing.T) {
|
func TestTransit_BatchEncryptionCase3(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
|
|
@ -256,7 +241,7 @@ func TestTransit_BatchEncryptionCase3(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 0 successful requests
|
// We expect 0 successful requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case4: Test batch encryption with an existing key (and test references)
|
// Case4: Test batch encryption with an existing key (and test references)
|
||||||
|
|
@ -264,9 +249,6 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
policyReq := &logical.Request{
|
policyReq := &logical.Request{
|
||||||
|
|
@ -331,7 +313,7 @@ func TestTransit_BatchEncryptionCase4(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 4 successful requests (2 batch requests + 2 decrypt requests)
|
// We expect 4 successful requests (2 batch requests + 2 decrypt requests)
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case5: Test batch encryption with an existing derived key
|
// Case5: Test batch encryption with an existing derived key
|
||||||
|
|
@ -339,9 +321,6 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
policyData := map[string]interface{}{
|
policyData := map[string]interface{}{
|
||||||
|
|
@ -409,7 +388,7 @@ func TestTransit_BatchEncryptionCase5(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case6: Test batch encryption with an upserted non-derived key
|
// Case6: Test batch encryption with an upserted non-derived key
|
||||||
|
|
@ -417,9 +396,6 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
|
|
@ -475,7 +451,7 @@ func TestTransit_BatchEncryptionCase6(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case7: Test batch encryption with an upserted derived key
|
// Case7: Test batch encryption with an upserted derived key
|
||||||
|
|
@ -485,9 +461,6 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||||
|
|
@ -536,7 +509,7 @@ func TestTransit_BatchEncryptionCase7(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
// We expect 4 successful transit requests (2 for batch encryption, 2 for batch decryption)
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case8: If plaintext is not base64 encoded, encryption should fail
|
// Case8: If plaintext is not base64 encoded, encryption should fail
|
||||||
|
|
@ -546,9 +519,6 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Create the policy
|
// Create the policy
|
||||||
policyReq := &logical.Request{
|
policyReq := &logical.Request{
|
||||||
Operation: logical.UpdateOperation,
|
Operation: logical.UpdateOperation,
|
||||||
|
|
@ -594,7 +564,7 @@ func TestTransit_BatchEncryptionCase8(t *testing.T) {
|
||||||
t.Fatal("expected an error")
|
t.Fatal("expected an error")
|
||||||
}
|
}
|
||||||
// We expect 0 successful transit requests
|
// We expect 0 successful transit requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case9: If both plaintext and batch inputs are supplied, plaintext should be
|
// Case9: If both plaintext and batch inputs are supplied, plaintext should be
|
||||||
|
|
@ -605,9 +575,6 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||||
|
|
@ -634,7 +601,7 @@ func TestTransit_BatchEncryptionCase9(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 2 successful batch encryptions
|
// We expect 2 successful batch encryptions
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case10: Inconsistent presence of 'context' in batch input should be caught
|
// Case10: Inconsistent presence of 'context' in batch input should be caught
|
||||||
|
|
@ -643,9 +610,6 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="},
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||||
|
|
@ -666,7 +630,7 @@ func TestTransit_BatchEncryptionCase10(t *testing.T) {
|
||||||
t.Fatalf("expected an error")
|
t.Fatalf("expected an error")
|
||||||
}
|
}
|
||||||
// We expect no successful transit requests
|
// We expect no successful transit requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case11: Incorrect inputs for context and nonce should not fail the operation
|
// Case11: Incorrect inputs for context and nonce should not fail the operation
|
||||||
|
|
@ -675,9 +639,6 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "dmlzaGFsCg=="},
|
||||||
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"},
|
map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "context": "not-encoded"},
|
||||||
|
|
@ -697,7 +658,7 @@ func TestTransit_BatchEncryptionCase11(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// We expect 1 successful encryption out of the 2-item batch
|
// We expect 1 successful encryption out of the 2-item batch
|
||||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(1), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case12: Invalid batch input
|
// Case12: Invalid batch input
|
||||||
|
|
@ -705,9 +666,6 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
batchInput := []interface{}{
|
batchInput := []interface{}{
|
||||||
map[string]interface{}{},
|
map[string]interface{}{},
|
||||||
"unexpected_interface",
|
"unexpected_interface",
|
||||||
|
|
@ -727,7 +685,7 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) {
|
||||||
t.Fatalf("expected an error")
|
t.Fatalf("expected an error")
|
||||||
}
|
}
|
||||||
// We expect no successful requests
|
// We expect no successful requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation
|
// Case13: Incorrect input for nonce when we aren't in convergent encryption should fail the operation
|
||||||
|
|
@ -736,9 +694,6 @@ func TestTransit_EncryptionCase13(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Non-batch first
|
// Non-batch first
|
||||||
data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"}
|
data := map[string]interface{}{"plaintext": "bXkgc2VjcmV0IGRhdGE=", "nonce": "R80hr9eNUIuFV52e"}
|
||||||
req := &logical.Request{
|
req := &logical.Request{
|
||||||
|
|
@ -774,7 +729,7 @@ func TestTransit_EncryptionCase13(t *testing.T) {
|
||||||
t.Fatal("expected request error")
|
t.Fatal("expected request error")
|
||||||
}
|
}
|
||||||
// We expect no successful transit requests
|
// We expect no successful transit requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case14: Incorrect input for nonce when we are in convergent version 3 should fail
|
// Case14: Incorrect input for nonce when we are in convergent version 3 should fail
|
||||||
|
|
@ -783,9 +738,6 @@ func TestTransit_EncryptionCase14(t *testing.T) {
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
cReq := &logical.Request{
|
cReq := &logical.Request{
|
||||||
Operation: logical.UpdateOperation,
|
Operation: logical.UpdateOperation,
|
||||||
Path: "keys/my-key",
|
Path: "keys/my-key",
|
||||||
|
|
@ -836,7 +788,7 @@ func TestTransit_EncryptionCase14(t *testing.T) {
|
||||||
t.Fatal("expected request error")
|
t.Fatal("expected request error")
|
||||||
}
|
}
|
||||||
// We expect no successful transit requests
|
// We expect no successful transit requests
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.
|
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.
|
||||||
|
|
|
||||||
|
|
@ -254,8 +254,9 @@ func (b *backend) pathHMACWrite(ctx context.Context, req *logical.Request, d *fr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit hmac request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
@ -432,8 +433,9 @@ func (b *backend) pathHMACVerify(ctx context.Context, req *logical.Request, d *f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit hmac verify request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,10 @@ import (
|
||||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault"
|
"github.com/hashicorp/vault/vault"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTransit_HMAC(t *testing.T) {
|
func TestTransit_HMAC(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, storage := createBackendWithSysView(t)
|
b, storage := createBackendWithSysView(t)
|
||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
|
|
@ -249,13 +245,10 @@ func TestTransit_HMAC(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Verify the total successful transit requests
|
// Verify the total successful transit requests
|
||||||
require.Equal(t, int64(72), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(72), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_batchHMAC(t *testing.T) {
|
func TestTransit_batchHMAC(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, storage := createBackendWithSysView(t)
|
b, storage := createBackendWithSysView(t)
|
||||||
|
|
||||||
// First create a key
|
// First create a key
|
||||||
|
|
@ -411,7 +404,7 @@ func TestTransit_batchHMAC(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the total successful transit requests
|
// Verify the total successful transit requests
|
||||||
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(5), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input
|
// TestHMACBatchResultsFields checks that responses to HMAC verify requests using batch_input
|
||||||
|
|
@ -440,9 +433,6 @@ func TestHMACBatchResultsFields(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
keyName := "hmac-test-key"
|
keyName := "hmac-test-key"
|
||||||
_, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32})
|
_, err = client.Logical().Write("transit/keys/"+keyName, map[string]interface{}{"type": "hmac", "key_size": 32})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
@ -505,5 +495,5 @@ func TestHMACBatchResultsFields(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification)
|
// We expect 4 successful requests (2 for batch HMAC generation, 2 for batch verification)
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), cores[0].GetInMemoryTransitDataProtectionCallCounts())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -308,8 +308,9 @@ func (b *backend) pathRewrapWrite(ctx context.Context, req *logical.Request, d *
|
||||||
resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.")
|
resp.AddWarning("A provided nonce value was used within FIPS mode, this violates FIPS 140 compliance.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit rewrap request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -19,9 +18,6 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Upsert the key and encrypt the data
|
// Upsert the key and encrypt the data
|
||||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||||
|
|
||||||
|
|
@ -119,7 +115,7 @@ func TestTransit_BatchRewrapCase1(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect 2 successful requests (1 for encrypt, 1 for rewrap)
|
// We expect 2 successful requests (1 for encrypt, 1 for rewrap)
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the normal flow of rewrap with upserted key
|
// Check the normal flow of rewrap with upserted key
|
||||||
|
|
@ -128,9 +124,6 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Upsert the key and encrypt the data
|
// Upsert the key and encrypt the data
|
||||||
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
plaintext := "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||||
|
|
||||||
|
|
@ -229,7 +222,7 @@ func TestTransit_BatchRewrapCase2(t *testing.T) {
|
||||||
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
|
t.Fatalf("unexpected key version; got: %d, expected: %d", keyVersion, 2)
|
||||||
}
|
}
|
||||||
// We expect 2 successful transit requests (1 for encrypt, 1 for rewrap)
|
// We expect 2 successful transit requests (1 for encrypt, 1 for rewrap)
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts
|
// Batch encrypt plaintexts, rotate the keys and rewrap all the ciphertexts
|
||||||
|
|
@ -237,9 +230,6 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
batchEncryptionInput := []interface{}{
|
batchEncryptionInput := []interface{}{
|
||||||
|
|
@ -341,7 +331,7 @@ func TestTransit_BatchRewrapCase3(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption)
|
// We expect 6 successful transit requests (2 for batch encryption, 2 for batch rewrap, and 2 for decryption)
|
||||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes
|
// TestTransit_BatchRewrapCase4 batch rewrap leveraging RSA padding schemes
|
||||||
|
|
@ -349,9 +339,6 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
|
||||||
var resp *logical.Response
|
var resp *logical.Response
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, s := createBackendWithStorage(t)
|
b, s := createBackendWithStorage(t)
|
||||||
|
|
||||||
batchEncryptionInput := []interface{}{
|
batchEncryptionInput := []interface{}{
|
||||||
|
|
@ -459,5 +446,5 @@ func TestTransit_BatchRewrapCase4(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap)
|
// We expect 6 succcessful calls to the transit backend (2 for batch encryption, 2 for batch decryption, and 2 for batch rewrap)
|
||||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(6), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -492,8 +492,9 @@ func (b *backend) pathSignWrite(ctx context.Context, req *logical.Request, d *fr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit sign request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
@ -750,8 +751,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment the counter for successful operations
|
if err = b.incrementBillingCounts(ctx, uint64(successfulRequests)); err != nil {
|
||||||
b.incrementDataProtectionCounter(int64(successfulRequests))
|
b.Logger().Error("failed to track transit sign verify request count", "error", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"github.com/hashicorp/vault/helper/constants"
|
"github.com/hashicorp/vault/helper/constants"
|
||||||
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
"github.com/hashicorp/vault/sdk/helper/keysutil"
|
||||||
"github.com/hashicorp/vault/sdk/logical"
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
"github.com/hashicorp/vault/vault/billing"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/ed25519"
|
"golang.org/x/crypto/ed25519"
|
||||||
|
|
@ -33,9 +32,6 @@ type signOutcome struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_SignVerify_ECDSA(t *testing.T) {
|
func TestTransit_SignVerify_ECDSA(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
t.Run("256", func(t *testing.T) {
|
t.Run("256", func(t *testing.T) {
|
||||||
testTransit_SignVerify_ECDSA(t, 256)
|
testTransit_SignVerify_ECDSA(t, 256)
|
||||||
})
|
})
|
||||||
|
|
@ -45,9 +41,6 @@ func TestTransit_SignVerify_ECDSA(t *testing.T) {
|
||||||
t.Run("521", func(t *testing.T) {
|
t.Run("521", func(t *testing.T) {
|
||||||
testTransit_SignVerify_ECDSA(t, 521)
|
testTransit_SignVerify_ECDSA(t, 521)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Verify the total successful transit requests
|
|
||||||
require.Equal(t, int64(84), billing.CurrentDataProtectionCallCounts.Transit)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
|
func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
|
||||||
|
|
@ -330,6 +323,7 @@ func testTransit_SignVerify_ECDSA(t *testing.T, bits int) {
|
||||||
verifyRequest(req, false, "", sig)
|
verifyRequest(req, false, "", sig)
|
||||||
// Now try the v1
|
// Now try the v1
|
||||||
verifyRequest(req, true, "", v1sig)
|
verifyRequest(req, true, "", v1sig)
|
||||||
|
require.Equal(t, uint64(28), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, expectValid bool, postpath string, b *backend) {
|
func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, expectValid bool, postpath string, b *backend) {
|
||||||
|
|
@ -380,9 +374,6 @@ func validatePublicKey(t *testing.T, in string, sig string, pubKeyRaw []byte, ex
|
||||||
// TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a
|
// TestTransit_SignVerify_Ed25519Behavior makes sure the options on ENT for a
|
||||||
// Ed25519ph/ctx signature fail on CE and ENT if invalid
|
// Ed25519ph/ctx signature fail on CE and ENT if invalid
|
||||||
func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
|
func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, storage := createBackendWithSysView(t)
|
b, storage := createBackendWithSysView(t)
|
||||||
|
|
||||||
// First create a key
|
// First create a key
|
||||||
|
|
@ -477,17 +468,14 @@ func TestTransit_SignVerify_Ed25519Behavior(t *testing.T) {
|
||||||
}
|
}
|
||||||
// Verify the total successful transit requests
|
// Verify the total successful transit requests
|
||||||
if constants.IsEnterprise {
|
if constants.IsEnterprise {
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), b.billingDataCounts.Transit.Load())
|
||||||
} else {
|
} else {
|
||||||
// We expect 0 successful calls on CE because we expect the verify to fail
|
// We expect 0 successful calls on CE because we expect the verify to fail
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(0), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_SignVerify_ED25519(t *testing.T) {
|
func TestTransit_SignVerify_ED25519(t *testing.T) {
|
||||||
// Reset the transit counter
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
b, storage := createBackendWithSysView(t)
|
b, storage := createBackendWithSysView(t)
|
||||||
|
|
||||||
// First create a key
|
// First create a key
|
||||||
|
|
@ -832,7 +820,7 @@ func TestTransit_SignVerify_ED25519(t *testing.T) {
|
||||||
verifyRequest(req, false, outcome, "bar", goodsig, true)
|
verifyRequest(req, false, outcome, "bar", goodsig, true)
|
||||||
|
|
||||||
// Verify the total successful transit requests
|
// Verify the total successful transit requests
|
||||||
require.Equal(t, int64(24), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(24), b.billingDataCounts.Transit.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
|
func TestTransit_SignVerify_RSA_PSS(t *testing.T) {
|
||||||
|
|
|
||||||
3
changelog/_12142.txt
Normal file
3
changelog/_12142.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
```release-note:bug
|
||||||
|
ui: Fixes login form so `?with=<path>` query param correctly displays only the specified mount when multiple mounts of the same auth type are configured with `listing_visibility="unauth"`
|
||||||
|
```
|
||||||
|
|
@ -118,6 +118,9 @@ type Backend struct {
|
||||||
// communicate with a plugin to activate a feature.
|
// communicate with a plugin to activate a feature.
|
||||||
ActivationFunc func(context.Context, *logical.Request, string) error
|
ActivationFunc func(context.Context, *logical.Request, string) error
|
||||||
|
|
||||||
|
// ConsumptionBillingManager is the consumption billing manager the backend can use to write billing data.
|
||||||
|
ConsumptionBillingManager logical.ConsumptionBillingManager
|
||||||
|
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
system logical.SystemView
|
system logical.SystemView
|
||||||
events logical.EventSender
|
events logical.EventSender
|
||||||
|
|
@ -440,6 +443,11 @@ func (b *Backend) Setup(ctx context.Context, config *logical.BackendConfig) erro
|
||||||
b.system = config.System
|
b.system = config.System
|
||||||
b.events = config.EventsSender
|
b.events = config.EventsSender
|
||||||
b.observations = config.ObservationRecorder
|
b.observations = config.ObservationRecorder
|
||||||
|
if b.System() != nil && b.System().GetConsumptionBillingManager() != nil {
|
||||||
|
b.ConsumptionBillingManager = b.System().GetConsumptionBillingManager()
|
||||||
|
} else {
|
||||||
|
b.ConsumptionBillingManager = logical.NewNullConsumptionBillingManager()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -546,7 +554,7 @@ func (b *Backend) init() {
|
||||||
for i, p := range b.Paths {
|
for i, p := range b.Paths {
|
||||||
// Detect the coding error of failing to initialise Pattern
|
// Detect the coding error of failing to initialise Pattern
|
||||||
if len(p.Pattern) == 0 {
|
if len(p.Pattern) == 0 {
|
||||||
panic(fmt.Sprintf("Routing pattern cannot be blank"))
|
panic("Routing pattern cannot be blank")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Detect the coding error of attempting to define a CreateOperation without defining an ExistenceCheck
|
// Detect the coding error of attempting to define a CreateOperation without defining an ExistenceCheck
|
||||||
|
|
|
||||||
25
sdk/logical/billing_system_view.go
Normal file
25
sdk/logical/billing_system_view.go
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
// Copyright IBM Corp. 2016, 2025
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package logical
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// billing.ConsumptionBillingManager is an implementation of this interface that the backend can use to write billing data.
|
||||||
|
type ConsumptionBillingManager interface {
|
||||||
|
WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ================================
|
||||||
|
// Creates a null consumption billing manager that does nothing
|
||||||
|
var _ ConsumptionBillingManager = (*nullConsumptionBillingManager)(nil)
|
||||||
|
|
||||||
|
func NewNullConsumptionBillingManager() ConsumptionBillingManager {
|
||||||
|
return &nullConsumptionBillingManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type nullConsumptionBillingManager struct{}
|
||||||
|
|
||||||
|
func (n *nullConsumptionBillingManager) WriteBillingData(ctx context.Context, pluginType string, data map[string]interface{}) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -117,6 +117,9 @@ type SystemView interface {
|
||||||
|
|
||||||
// ExtractVerifyPlugin extracts and verifies the plugin artifact
|
// ExtractVerifyPlugin extracts and verifies the plugin artifact
|
||||||
DownloadExtractVerifyPlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error
|
DownloadExtractVerifyPlugin(ctx context.Context, plugin *pluginutil.PluginRunner) error
|
||||||
|
|
||||||
|
// GetConsumptionBillingManager returns the consumption billing manager
|
||||||
|
GetConsumptionBillingManager() ConsumptionBillingManager
|
||||||
}
|
}
|
||||||
|
|
||||||
type PasswordPolicy interface {
|
type PasswordPolicy interface {
|
||||||
|
|
@ -320,6 +323,10 @@ func (d StaticSystemView) DownloadExtractVerifyPlugin(_ context.Context, _ *plug
|
||||||
return errors.New("DownloadExtractVerifyPlugin is not implemented in StaticSystemView")
|
return errors.New("DownloadExtractVerifyPlugin is not implemented in StaticSystemView")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d StaticSystemView) GetConsumptionBillingManager() ConsumptionBillingManager {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PluginLicenseUtil defines the functions needed to request License and PluginEnv
|
// PluginLicenseUtil defines the functions needed to request License and PluginEnv
|
||||||
// by the plugin licensing under github.com/hashicorp/vault-licensing
|
// by the plugin licensing under github.com/hashicorp/vault-licensing
|
||||||
// This only should be used by the plugin to get the license and plugin environment
|
// This only should be used by the plugin to get the license and plugin environment
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,11 @@ type gRPCSystemViewClient struct {
|
||||||
client pb.SystemViewClient
|
client pb.SystemViewClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *gRPCSystemViewClient) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
|
||||||
|
// Not implemented on pluginbackend
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *gRPCSystemViewClient) DefaultLeaseTTL() time.Duration {
|
func (s *gRPCSystemViewClient) DefaultLeaseTTL() time.Duration {
|
||||||
reply, err := s.client.DefaultLeaseTTL(context.Background(), &pb.Empty{})
|
reply, err := s.client.DefaultLeaseTTL(context.Background(), &pb.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -58,17 +58,6 @@ import type { Task } from 'ember-concurrency';
|
||||||
* 🔀 Multiple visible mounts:
|
* 🔀 Multiple visible mounts:
|
||||||
* ▸ Path dropdown is shown.
|
* ▸ Path dropdown is shown.
|
||||||
*
|
*
|
||||||
* @example
|
|
||||||
* <Auth::Page
|
|
||||||
* @cluster={{this.model.clusterModel}}
|
|
||||||
* @directLinkData={{this.model.directLinkData}}
|
|
||||||
* @loginSettings={{this.model.loginSettings}}
|
|
||||||
* @namespaceQueryParam={{this.namespaceQueryParam}}
|
|
||||||
* @oidcProviderQueryParam={{this.oidcProvider}}
|
|
||||||
* @loginAndTransition={{this.loginAndTransition}}
|
|
||||||
* @onNamespaceUpdate={{perform this.updateNamespace}}
|
|
||||||
* @visibleAuthMounts={{this.model.visibleAuthMounts}}
|
|
||||||
* />
|
|
||||||
*
|
*
|
||||||
* @param {object} cluster - the ember data cluster model. contains information such as cluster id, name and boolean for if the cluster is in standby
|
* @param {object} cluster - the ember data cluster model. contains information such as cluster id, name and boolean for if the cluster is in standby
|
||||||
* @param {object} directLinkData - mount data built from the "with" query param. If param is a mount path and maps to a visible mount, the login form defaults to this mount. Otherwise the form preselects the passed auth type.
|
* @param {object} directLinkData - mount data built from the "with" query param. If param is a mount path and maps to a visible mount, the login form defaults to this mount. Otherwise the form preselects the passed auth type.
|
||||||
|
|
@ -161,10 +150,14 @@ export default class AuthPage extends Component<Args> {
|
||||||
get directLinkViews() {
|
get directLinkViews() {
|
||||||
const { directLinkData } = this.args;
|
const { directLinkData } = this.args;
|
||||||
|
|
||||||
// If "path" key exists we know the "with" query param references a mount with listing_visibility="unauth"
|
// If "path" key exists then the "with" query param references a specific mount with listing_visibility="unauth".
|
||||||
// Treat it as a preferred method and hide all other tabs.
|
// Show only this mount and hide any others for this auth type as well as any tabs for different auth types.
|
||||||
if (directLinkData?.path) {
|
if (directLinkData?.path) {
|
||||||
const tabData = this.filterVisibleMountsByType([directLinkData.type]);
|
const mounts = this.mountsByType(directLinkData.type);
|
||||||
|
const selectedMount = mounts?.find((m) => m.path === directLinkData?.path);
|
||||||
|
const tabData: UnauthMountsByType = {
|
||||||
|
[directLinkData.type]: selectedMount ? [selectedMount] : null,
|
||||||
|
};
|
||||||
const defaultView = this.constructViews(FormView.TABS, tabData);
|
const defaultView = this.constructViews(FormView.TABS, tabData);
|
||||||
const alternateView = this.constructViews(FormView.DROPDOWN, null);
|
const alternateView = this.constructViews(FormView.DROPDOWN, null);
|
||||||
|
|
||||||
|
|
@ -263,11 +256,16 @@ export default class AuthPage extends Component<Args> {
|
||||||
const tabs: UnauthMountsByType = {};
|
const tabs: UnauthMountsByType = {};
|
||||||
for (const type of authTypes) {
|
for (const type of authTypes) {
|
||||||
// adds visible mounts for each type, if they exist
|
// adds visible mounts for each type, if they exist
|
||||||
tabs[type] = this.visibleMountsByType?.[type] || null;
|
tabs[type] = this.mountsByType(type);
|
||||||
}
|
}
|
||||||
return tabs;
|
return tabs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private mountsByType(type: string) {
|
||||||
|
// Return null and not an empty array to distinguish between "dropdown mode" and "tabs with no mounts" in downstream components
|
||||||
|
return this.visibleMountsByType?.[type] || null;
|
||||||
|
}
|
||||||
|
|
||||||
private constructViews(view: FormView, tabData: UnauthMountsByType | null) {
|
private constructViews(view: FormView, tabData: UnauthMountsByType | null) {
|
||||||
return { view, tabData };
|
return { view, tabData };
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,6 @@ export default class AuthRoute extends Route {
|
||||||
const { default_auth_type, backup_auth_types } = response.data;
|
const { default_auth_type, backup_auth_types } = response.data;
|
||||||
return {
|
return {
|
||||||
defaultType: default_auth_type,
|
defaultType: default_auth_type,
|
||||||
// TODO WIP backend PR consistently return empty array when no backup_auth_types
|
|
||||||
backupTypes: backup_auth_types?.length ? backup_auth_types : null,
|
backupTypes: backup_auth_types?.length ? backup_auth_types : null,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -132,6 +132,20 @@ module('Integration | Component | auth | page | listing visibility', function (h
|
||||||
assert.dom(GENERAL.backButton).doesNotExist();
|
assert.dom(GENERAL.backButton).doesNotExist();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('it treats direct link as only mount when multiple mounts are tuned with listing_visibility="unauth"', async function (assert) {
|
||||||
|
this.directLinkData = { path: 'userpass2/', type: 'userpass' };
|
||||||
|
await this.renderComponent();
|
||||||
|
assert.dom(AUTH_FORM.authForm('userpass')).exists;
|
||||||
|
assert.dom(AUTH_FORM.tabBtn('userpass')).hasText('Userpass', 'it renders tab for type');
|
||||||
|
assert.dom(AUTH_FORM.tabs).exists({ count: 1 }, 'only one tab renders');
|
||||||
|
assert.dom(GENERAL.inputByAttr('path')).hasAttribute('type', 'hidden');
|
||||||
|
assert.dom(GENERAL.inputByAttr('path')).hasValue('userpass2/');
|
||||||
|
assert.dom(GENERAL.button('Sign in with other methods')).exists('"Sign in with other methods" renders');
|
||||||
|
assert.dom(GENERAL.selectByAttr('auth type')).doesNotExist();
|
||||||
|
assert.dom(AUTH_FORM.advancedSettings).doesNotExist();
|
||||||
|
assert.dom(GENERAL.backButton).doesNotExist();
|
||||||
|
});
|
||||||
|
|
||||||
test('it prioritizes auth type from canceled mfa instead of direct link for path', async function (assert) {
|
test('it prioritizes auth type from canceled mfa instead of direct link for path', async function (assert) {
|
||||||
assert.expect(1);
|
assert.expect(1);
|
||||||
this.directLinkData = this.directLinkIsVisibleMount; // type is "oidc"
|
this.directLinkData = this.directLinkIsVisibleMount; // type is "oidc"
|
||||||
|
|
|
||||||
|
|
@ -4,19 +4,24 @@
|
||||||
package billing
|
package billing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
log "github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/hashicorp/vault/sdk/logical"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
BillingSubPath = "billing/"
|
BillingSubPath = "billing/"
|
||||||
ReplicatedPrefix = "replicated/"
|
ReplicatedPrefix = "replicated/"
|
||||||
RoleHWMCountsHWM = "maxRoleCounts/"
|
RoleHWMCountsHWM = "maxRoleCounts/"
|
||||||
KvHWMCountsHWM = "maxKvCounts/"
|
KvHWMCountsHWM = "maxKvCounts/"
|
||||||
DataProtectionCallCountsMetric = "dataProtectionCallCounts/"
|
TransitDataProtectionCallCountsPrefix = "transitDataProtectionCallCounts/"
|
||||||
LocalPrefix = "local/"
|
LocalPrefix = "local/"
|
||||||
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
|
ThirdPartyPluginsPrefix = "thirdPartyPluginCounts/"
|
||||||
|
|
||||||
BillingWriteInterval = 10 * time.Minute
|
BillingWriteInterval = 10 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
@ -27,7 +32,9 @@ type ConsumptionBilling struct {
|
||||||
// BillingStorageLock controls access to the billing storage paths
|
// BillingStorageLock controls access to the billing storage paths
|
||||||
BillingStorageLock sync.RWMutex
|
BillingStorageLock sync.RWMutex
|
||||||
|
|
||||||
BillingConfig BillingConfig
|
BillingConfig BillingConfig
|
||||||
|
DataProtectionCallCounts DataProtectionCallCounts
|
||||||
|
Logger log.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
type BillingConfig struct {
|
type BillingConfig struct {
|
||||||
|
|
@ -45,10 +52,26 @@ func GetMonthlyBillingPath(localPrefix string, now time.Time, billingMetric stri
|
||||||
}
|
}
|
||||||
|
|
||||||
type DataProtectionCallCounts struct {
|
type DataProtectionCallCounts struct {
|
||||||
Transit int64 `json:"transit,omitempty"`
|
Transit *atomic.Uint64 `json:"transit,omitempty"`
|
||||||
// TODO: Uncomment when we add support for Transform tracking (VAULT-41205)
|
// TODO: Uncomment when we add support for Transform tracking (VAULT-41205)
|
||||||
// Transform int64 `json:"transform,omitempty"`
|
// Transform atomic.int64 `json:"transform,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global counter for all data protection calls on this cluster
|
var _ logical.ConsumptionBillingManager = (*ConsumptionBilling)(nil)
|
||||||
var CurrentDataProtectionCallCounts = DataProtectionCallCounts{}
|
|
||||||
|
func (s *ConsumptionBilling) WriteBillingData(ctx context.Context, mountType string, data map[string]interface{}) error {
|
||||||
|
switch mountType {
|
||||||
|
case "transit":
|
||||||
|
val, ok := data["count"].(uint64)
|
||||||
|
if !ok {
|
||||||
|
err := fmt.Errorf("invalid value type for transit")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.DataProtectionCallCounts.Transit.Add(val)
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("unknown metric type: %s", mountType)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ package vault
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/vault/helper/timeutil"
|
"github.com/hashicorp/vault/helper/timeutil"
|
||||||
|
|
@ -15,8 +17,14 @@ func (c *Core) setupConsumptionBilling(ctx context.Context) error {
|
||||||
// We need replication (post unseal) to start before we run the consumption billing metrics worker
|
// We need replication (post unseal) to start before we run the consumption billing metrics worker
|
||||||
// This is because there is primary/secondary cluster specific logic
|
// This is because there is primary/secondary cluster specific logic
|
||||||
c.consumptionBillingLock.Lock()
|
c.consumptionBillingLock.Lock()
|
||||||
|
logger := c.baseLogger.Named("billing")
|
||||||
|
c.AddLogger(logger)
|
||||||
c.consumptionBilling = &billing.ConsumptionBilling{
|
c.consumptionBilling = &billing.ConsumptionBilling{
|
||||||
BillingConfig: c.billingConfig,
|
BillingConfig: c.billingConfig,
|
||||||
|
DataProtectionCallCounts: billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
},
|
||||||
|
Logger: logger,
|
||||||
}
|
}
|
||||||
c.consumptionBillingLock.Unlock()
|
c.consumptionBillingLock.Unlock()
|
||||||
c.postUnsealFuncs = append(c.postUnsealFuncs, func() {
|
c.postUnsealFuncs = append(c.postUnsealFuncs, func() {
|
||||||
|
|
@ -72,7 +80,12 @@ func (c *Core) updateBillingMetrics(ctx context.Context) error {
|
||||||
c.UpdateReplicatedHWMMetrics(ctx, currentMonth)
|
c.UpdateReplicatedHWMMetrics(ctx, currentMonth)
|
||||||
}
|
}
|
||||||
c.UpdateLocalHWMMetrics(ctx, currentMonth)
|
c.UpdateLocalHWMMetrics(ctx, currentMonth)
|
||||||
c.UpdateLocalAggregatedMetrics(ctx, currentMonth)
|
if err := c.UpdateLocalAggregatedMetrics(ctx, currentMonth); err != nil {
|
||||||
|
c.logger.Error("error updating cluster data protection call counts", "error", err)
|
||||||
|
} else {
|
||||||
|
c.logger.Info("updated cluster data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -119,10 +132,7 @@ func (c *Core) UpdateLocalHWMMetrics(ctx context.Context, currentMonth time.Time
|
||||||
|
|
||||||
func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error {
|
func (c *Core) UpdateLocalAggregatedMetrics(ctx context.Context, currentMonth time.Time) error {
|
||||||
if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil {
|
if _, err := c.UpdateDataProtectionCallCounts(ctx, currentMonth); err != nil {
|
||||||
c.logger.Error("error updating local max data protection call counts", "error", err)
|
return fmt.Errorf("could not store transit data protection call counts: %w", err)
|
||||||
} else {
|
|
||||||
c.logger.Info("updated local max data protection call counts", "prefix", billing.LocalPrefix, "currentMonth", currentMonth)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
12
vault/consumption_billing_testing_util.go
Normal file
12
vault/consumption_billing_testing_util.go
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
// Copyright IBM Corp. 2016, 2025
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package vault
|
||||||
|
|
||||||
|
func (c *Core) ResetInMemoryDataProtectionCallCounts() {
|
||||||
|
c.consumptionBilling.DataProtectionCallCounts.Transit.Store(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Core) GetInMemoryTransitDataProtectionCallCounts() uint64 {
|
||||||
|
return c.consumptionBilling.DataProtectionCallCounts.Transit.Load()
|
||||||
|
}
|
||||||
|
|
@ -267,29 +267,36 @@ func (c *Core) GetBillingSubView() *BarrierView {
|
||||||
|
|
||||||
// storeDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
// storeDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
||||||
func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error {
|
func (c *Core) storeDataProtectionCallCountsLocked(ctx context.Context, maxCounts *billing.DataProtectionCallCounts, localPathPrefix string, month time.Time) error {
|
||||||
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
|
// Store count for each data protection type separately because they are atomic counters
|
||||||
entry, err := logical.StorageEntryJSON(billingPath, maxCounts)
|
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix)
|
||||||
if err != nil {
|
transitCount := maxCounts.Transit.Load()
|
||||||
return err
|
entry := &logical.StorageEntry{
|
||||||
|
Key: billingPath,
|
||||||
|
Value: []byte(strconv.FormatUint(transitCount, 10)),
|
||||||
}
|
}
|
||||||
return c.GetBillingSubView().Put(ctx, entry)
|
return c.GetBillingSubView().Put(ctx, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
// getStoredDataProtectionCallCountsLocked must be called with BillingStorageLock held
|
||||||
func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
func (c *Core) getStoredDataProtectionCallCountsLocked(ctx context.Context, localPathPrefix string, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
||||||
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.DataProtectionCallCountsMetric)
|
// Retrieve count for each data protection type separately because they are atomic counters
|
||||||
|
ret := &billing.DataProtectionCallCounts{
|
||||||
|
Transit: &atomic.Uint64{},
|
||||||
|
}
|
||||||
|
billingPath := billing.GetMonthlyBillingPath(localPathPrefix, month, billing.TransitDataProtectionCallCountsPrefix)
|
||||||
entry, err := c.GetBillingSubView().Get(ctx, billingPath)
|
entry, err := c.GetBillingSubView().Get(ctx, billingPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
return &billing.DataProtectionCallCounts{}, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
var maxCounts billing.DataProtectionCallCounts
|
transitCount, err := strconv.ParseUint(string(entry.Value), 10, 64)
|
||||||
if err := entry.DecodeJSON(&maxCounts); err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &maxCounts, nil
|
ret.Transit.Store(transitCount)
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
func (c *Core) GetStoredDataProtectionCallCounts(ctx context.Context, month time.Time) (*billing.DataProtectionCallCounts, error) {
|
||||||
|
|
@ -302,10 +309,6 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth
|
||||||
c.consumptionBilling.BillingStorageLock.Lock()
|
c.consumptionBilling.BillingStorageLock.Lock()
|
||||||
defer c.consumptionBilling.BillingStorageLock.Unlock()
|
defer c.consumptionBilling.BillingStorageLock.Unlock()
|
||||||
|
|
||||||
// Read and atomically reset the current counter value
|
|
||||||
// TODO: Reset Transform call counts too (VAULT-41205)
|
|
||||||
currentTransitCount := atomic.SwapInt64(&billing.CurrentDataProtectionCallCounts.Transit, 0)
|
|
||||||
|
|
||||||
storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth)
|
storedDataProtectionCallCounts, err := c.getStoredDataProtectionCallCountsLocked(ctx, billing.LocalPrefix, currentMonth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -315,7 +318,8 @@ func (c *Core) UpdateDataProtectionCallCounts(ctx context.Context, currentMonth
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sum the current count with the stored count
|
// Sum the current count with the stored count
|
||||||
storedDataProtectionCallCounts.Transit += currentTransitCount
|
transitCount := c.consumptionBilling.DataProtectionCallCounts.Transit.Swap(0)
|
||||||
|
storedDataProtectionCallCounts.Transit.Add(transitCount)
|
||||||
// TODO: Update Transform call counts (VAULT-41205)
|
// TODO: Update Transform call counts (VAULT-41205)
|
||||||
|
|
||||||
err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth)
|
err = c.storeDataProtectionCallCountsLocked(ctx, storedDataProtectionCallCounts, billing.LocalPrefix, currentMonth)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ func verifyExpectedRoleCounts(t *testing.T, actual *RoleCounts, baseCount int) {
|
||||||
|
|
||||||
// testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature.
|
// testCMACOperations is a no-op in OSS since CMAC is an Enterprise-only feature.
|
||||||
// Returns the current count unchanged.
|
// Returns the current count unchanged.
|
||||||
func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount int64) int64 {
|
func testCMACOperations(t *testing.T, core *Core, ctx context.Context, root string, currentCount uint64) uint64 {
|
||||||
// CMAC is not supported in OSS, so we don't perform any operations
|
// CMAC is not supported in OSS, so we don't perform any operations
|
||||||
return currentCount
|
return currentCount
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -420,9 +420,6 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
_, err := core.HandleRequest(ctx, req)
|
_, err := core.HandleRequest(ctx, req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Reset the transit counters
|
|
||||||
billing.CurrentDataProtectionCallCounts.Transit = 0
|
|
||||||
|
|
||||||
// Create an encryption key
|
// Create an encryption key
|
||||||
req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo")
|
req = logical.TestRequest(t, logical.CreateOperation, "transit/keys/foo")
|
||||||
req.Data["type"] = "aes256-gcm96"
|
req.Data["type"] = "aes256-gcm96"
|
||||||
|
|
@ -439,8 +436,8 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented (replicated mount by default)
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Get the ciphertext from the encryption response
|
// Get the ciphertext from the encryption response
|
||||||
ciphertext, ok := resp.Data["ciphertext"].(string)
|
ciphertext, ok := resp.Data["ciphertext"].(string)
|
||||||
|
|
@ -455,7 +452,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(2), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(2), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Test rewrap operation
|
// Test rewrap operation
|
||||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo")
|
req = logical.TestRequest(t, logical.UpdateOperation, "transit/rewrap/foo")
|
||||||
|
|
@ -467,7 +464,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Get the new ciphertext from rewrap
|
// Get the new ciphertext from rewrap
|
||||||
newCiphertext, ok := resp.Data["ciphertext"].(string)
|
newCiphertext, ok := resp.Data["ciphertext"].(string)
|
||||||
|
|
@ -483,9 +480,9 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(4), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(4), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Test HMAC generation
|
// Test HMAC operation
|
||||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo")
|
req = logical.TestRequest(t, logical.UpdateOperation, "transit/hmac/foo")
|
||||||
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
|
||||||
req.ClientToken = root
|
req.ClientToken = root
|
||||||
|
|
@ -495,7 +492,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(5), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(5), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Get the HMAC value
|
// Get the HMAC value
|
||||||
hmacValue, ok := resp.Data["hmac"].(string)
|
hmacValue, ok := resp.Data["hmac"].(string)
|
||||||
|
|
@ -513,7 +510,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(6), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(6), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Verify the HMAC is valid
|
// Verify the HMAC is valid
|
||||||
hmacValid, ok := resp.Data["valid"].(bool)
|
hmacValid, ok := resp.Data["valid"].(bool)
|
||||||
|
|
@ -537,7 +534,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(7), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(7), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Get the signature
|
// Get the signature
|
||||||
signature, ok := resp.Data["signature"].(string)
|
signature, ok := resp.Data["signature"].(string)
|
||||||
|
|
@ -555,7 +552,7 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NotNil(t, resp.Data)
|
require.NotNil(t, resp.Data)
|
||||||
|
|
||||||
// Verify that the transit counter is incremented
|
// Verify that the transit counter is incremented
|
||||||
require.Equal(t, int64(8), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(8), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Verify the signature is valid
|
// Verify the signature is valid
|
||||||
signatureValid, ok := resp.Data["valid"].(bool)
|
signatureValid, ok := resp.Data["valid"].(bool)
|
||||||
|
|
@ -563,27 +560,27 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.True(t, signatureValid)
|
require.True(t, signatureValid)
|
||||||
|
|
||||||
// Test CMAC operations (ENT only - will be no-op in OSS)
|
// Test CMAC operations (ENT only - will be no-op in OSS)
|
||||||
currentCount := billing.CurrentDataProtectionCallCounts.Transit
|
currentCount := core.GetInMemoryTransitDataProtectionCallCounts()
|
||||||
currentCount = testCMACOperations(t, core, ctx, root, currentCount)
|
currentCount = testCMACOperations(t, core, ctx, root, currentCount)
|
||||||
|
|
||||||
// Verify that the transit counter matches expected count
|
// Verify that the transit counter matches expected count
|
||||||
require.Equal(t, currentCount, billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, currentCount, core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Now test persisting the summed counts - store and retrieve counts
|
// Now test persisting the summed counts - store and retrieve counts
|
||||||
// First, update the data protection call counts (this will sum current counter with stored value)
|
// First, update the data protection call counts (this will sum current counter with stored value)
|
||||||
summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
summedCounts, err := core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, summedCounts)
|
require.NotNil(t, summedCounts)
|
||||||
require.Equal(t, currentCount, summedCounts.Transit)
|
require.Equal(t, currentCount, summedCounts.Transit.Load())
|
||||||
|
|
||||||
// Verify the counter was reset after update
|
// Verify the counter was reset after update
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
|
||||||
|
|
||||||
// Retrieve the stored counts
|
// Retrieve the stored counts
|
||||||
storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
storedCounts, err := core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, storedCounts)
|
require.NotNil(t, storedCounts)
|
||||||
require.Equal(t, currentCount, storedCounts.Transit)
|
require.Equal(t, currentCount, storedCounts.Transit.Load())
|
||||||
|
|
||||||
// Perform more operations to increase the counter
|
// Perform more operations to increase the counter
|
||||||
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
|
req = logical.TestRequest(t, logical.UpdateOperation, "transit/encrypt/foo")
|
||||||
|
|
@ -593,21 +590,21 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Counter should now be 1 (reset + 1 operation)
|
// Counter should now be 1 (reset + 1 operation)
|
||||||
require.Equal(t, int64(1), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(1), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Update counts again - should sum the new count (1) with the stored count (currentCount)
|
// Update counts again - should sum the new count (1) with the stored count (currentCount)
|
||||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expectedSum := currentCount + 1
|
expectedSum := currentCount + 1
|
||||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should be sum of stored and current")
|
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should be sum of stored and current")
|
||||||
|
|
||||||
// Verify the counter was reset after update
|
// Verify the counter was reset after update
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
|
||||||
|
|
||||||
// Verify stored counts are now the sum
|
// Verify stored counts are now the sum
|
||||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expectedSum, storedCounts.Transit)
|
require.Equal(t, expectedSum, storedCounts.Transit.Load())
|
||||||
|
|
||||||
// Add more operations without manually resetting
|
// Add more operations without manually resetting
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
|
|
@ -619,35 +616,35 @@ func TestDataProtectionCallCounts(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Counter should be 3
|
// Counter should be 3
|
||||||
require.Equal(t, int64(3), billing.CurrentDataProtectionCallCounts.Transit)
|
require.Equal(t, uint64(3), core.GetInMemoryTransitDataProtectionCallCounts())
|
||||||
|
|
||||||
// Update counts - should sum 3 with the previous stored sum
|
// Update counts - should sum 3 with the previous stored sum
|
||||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expectedSum = expectedSum + 3
|
expectedSum = expectedSum + 3
|
||||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should continue to sum")
|
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should continue to sum")
|
||||||
|
|
||||||
// Verify the counter was reset after update
|
// Verify the counter was reset after update
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should be reset after update")
|
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should be reset after update")
|
||||||
|
|
||||||
// Verify stored counts
|
// Verify stored counts
|
||||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expectedSum, storedCounts.Transit)
|
require.Equal(t, expectedSum, storedCounts.Transit.Load())
|
||||||
|
|
||||||
// Update again without any new operations
|
// Update again without any new operations
|
||||||
// This verifies we don't double-count
|
// This verifies we don't double-count
|
||||||
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
summedCounts, err = core.UpdateDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expectedSum, summedCounts.Transit, "Count should remain the same when no new operations occurred")
|
require.Equal(t, expectedSum, summedCounts.Transit.Load(), "Count should remain the same when no new operations occurred")
|
||||||
|
|
||||||
// Verify stored counts haven't changed
|
// Verify stored counts haven't changed
|
||||||
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
storedCounts, err = core.GetStoredDataProtectionCallCounts(ctx, time.Now())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expectedSum, storedCounts.Transit, "Stored count should remain the same")
|
require.Equal(t, expectedSum, storedCounts.Transit.Load(), "Stored count should remain the same")
|
||||||
|
|
||||||
// Verify counter is still at 0
|
// Verify counter is still at 0
|
||||||
require.Equal(t, int64(0), billing.CurrentDataProtectionCallCounts.Transit, "Counter should still be 0")
|
require.Equal(t, uint64(0), core.GetInMemoryTransitDataProtectionCallCounts(), "Counter should still be 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {
|
func addRoleToStorage(t *testing.T, core *Core, mount string, key string, numberOfKeys int) {
|
||||||
|
|
|
||||||
|
|
@ -78,3 +78,7 @@ func (c *Core) setupHeaderHMACKey(ctx context.Context, isPerfStandby bool) error
|
||||||
func (c *Core) GetPkiCertificateCounter() logical.CertificateCounter {
|
func (c *Core) GetPkiCertificateCounter() logical.CertificateCounter {
|
||||||
return c.pkiCertCountManager
|
return c.pkiCertCountManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Core) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
|
||||||
|
return c.consumptionBilling
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -401,3 +401,7 @@ func (d dynamicSystemView) DeregisterRotationJob(ctx context.Context, req *rotat
|
||||||
|
|
||||||
return d.core.DeregisterRotationJob(nsCtx, req)
|
return d.core.DeregisterRotationJob(nsCtx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d dynamicSystemView) GetConsumptionBillingManager() logical.ConsumptionBillingManager {
|
||||||
|
return d.core.GetConsumptionBillingManager()
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue