diff --git a/builtin/logical/aws/client.go b/builtin/logical/aws/client.go index 802abb3d1d..4891666eae 100644 --- a/builtin/logical/aws/client.go +++ b/builtin/logical/aws/client.go @@ -5,6 +5,7 @@ package aws import ( "context" + "errors" "fmt" "os" "strconv" @@ -23,91 +24,139 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +// Return a slice of *aws.Config, based on descending configuration priority. STS endpoints are the only place this is used. // NOTE: The caller is required to ensure that b.clientMutex is at least read locked -func (b *backend) getRootConfig(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) (*aws.Config, error) { - credsConfig := &awsutil.CredentialsConfig{} - var endpoint string - var maxRetries int = aws.UseServiceDefaultRetries +func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientType string, logger hclog.Logger) ([]*aws.Config, error) { + // set fallback region (we can overwrite later) + fallbackRegion := os.Getenv("AWS_REGION") + if fallbackRegion == "" { + fallbackRegion = os.Getenv("AWS_DEFAULT_REGION") + } + if fallbackRegion == "" { + fallbackRegion = "us-east-1" + } + + maxRetries := aws.UseServiceDefaultRetries entry, err := s.Get(ctx, "config/root") if err != nil { return nil, err } - if entry != nil { - var config rootConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, fmt.Errorf("error reading root configuration: %w", err) + var configs []*aws.Config + + // ensure the nil case uses defaults + if entry == nil { + ccfg := awsutil.CredentialsConfig{ + HTTPClient: cleanhttp.DefaultClient(), + Logger: logger, + Region: fallbackRegion, } - - credsConfig.AccessKey = config.AccessKey - credsConfig.SecretKey = config.SecretKey - credsConfig.Region = config.Region - maxRetries = config.MaxRetries - switch { - case clientType == "iam" && config.IAMEndpoint != "": - endpoint = *aws.String(config.IAMEndpoint) - case clientType == "sts" && config.STSEndpoint != "": - endpoint = *aws.String(config.STSEndpoint) - if config.STSRegion != "" { - credsConfig.Region = config.STSRegion - } + creds, err := ccfg.GenerateCredentialChain() + if err != nil { + return nil, err } + configs = append(configs, &aws.Config{ + Credentials: creds, + Region: aws.String(fallbackRegion), + Endpoint: aws.String(""), + MaxRetries: aws.Int(maxRetries), + }) - if config.IdentityTokenAudience != "" { - ns, err := namespace.FromContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get namespace from context: %w", err) - } - - fetcher := &PluginIdentityTokenFetcher{ - sys: b.System(), - logger: b.Logger(), - ns: ns, - audience: config.IdentityTokenAudience, - ttl: config.IdentityTokenTTL, - } - - sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10) - credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix) - credsConfig.WebIdentityTokenFetcher = fetcher - credsConfig.RoleARN = config.RoleARN - } + return configs, nil } - if credsConfig.Region == "" { - credsConfig.Region = os.Getenv("AWS_REGION") - if credsConfig.Region == "" { - credsConfig.Region = os.Getenv("AWS_DEFAULT_REGION") - if credsConfig.Region == "" { - credsConfig.Region = "us-east-1" - } - } + var config rootConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, fmt.Errorf("error reading root configuration: %w", err) } + var endpoints []string + var regions []string + credsConfig := &awsutil.CredentialsConfig{} + + credsConfig.AccessKey = config.AccessKey + credsConfig.SecretKey = config.SecretKey credsConfig.HTTPClient = cleanhttp.DefaultClient() - credsConfig.Logger = logger - creds, err := credsConfig.GenerateCredentialChain() - if err != nil { - return nil, err + maxRetries = config.MaxRetries + if clientType == "iam" && config.IAMEndpoint != "" { + endpoints = append(endpoints, config.IAMEndpoint) + } else if clientType == "sts" && config.STSEndpoint != "" { + endpoints = append(endpoints, config.STSEndpoint) + if config.STSRegion != "" { + regions = append(regions, config.STSRegion) + } + + if len(config.STSFallbackEndpoints) > 0 { + endpoints = append(endpoints, config.STSFallbackEndpoints...) + } + + if len(config.STSFallbackRegions) > 0 { + regions = append(regions, config.STSFallbackRegions...) + } } - return &aws.Config{ - Credentials: creds, - Region: aws.String(credsConfig.Region), - Endpoint: &endpoint, - HTTPClient: cleanhttp.DefaultClient(), - MaxRetries: aws.Int(maxRetries), - }, nil + if config.IdentityTokenAudience != "" { + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get namespace from context: %w", err) + } + + fetcher := &PluginIdentityTokenFetcher{ + sys: b.System(), + logger: b.Logger(), + ns: ns, + audience: config.IdentityTokenAudience, + ttl: config.IdentityTokenTTL, + } + + sessionSuffix := strconv.FormatInt(time.Now().UnixNano(), 10) + credsConfig.RoleSessionName = fmt.Sprintf("vault-aws-secrets-%s", sessionSuffix) + credsConfig.WebIdentityTokenFetcher = fetcher + credsConfig.RoleARN = config.RoleARN + } + + if len(regions) == 0 { + regions = append(regions, fallbackRegion) + } + + if len(regions) != len(endpoints) { + // this probably can't happen, if the input was checked correctly + return nil, errors.New("number of regions does not match number of endpoints") + } + + for i := 0; i < len(endpoints); i++ { + if len(regions) > i { + credsConfig.Region = regions[i] + } else { + credsConfig.Region = fallbackRegion + } + creds, err := credsConfig.GenerateCredentialChain() + if err != nil { + return nil, err + } + configs = append(configs, &aws.Config{ + Credentials: creds, + Region: aws.String(credsConfig.Region), + Endpoint: aws.String(endpoints[i]), + MaxRetries: aws.Int(maxRetries), + HTTPClient: cleanhttp.DefaultClient(), + }) + } + + return configs, nil } func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) { - awsConfig, err := b.getRootConfig(ctx, s, "iam", logger) + awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger) if err != nil { return nil, err } - sess, err := session.NewSession(awsConfig) + if len(awsConfig) != 1 { + return nil, errors.New("could not obtain aws config") + } + sess, err := session.NewSession(awsConfig[0]) if err != nil { return nil, err } @@ -119,19 +168,33 @@ func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, log } func (b *backend) nonCachedClientSTS(ctx context.Context, s logical.Storage, logger hclog.Logger) (*sts.STS, error) { - awsConfig, err := b.getRootConfig(ctx, s, "sts", logger) + awsConfig, err := b.getRootConfigs(ctx, s, "sts", logger) if err != nil { return nil, err } - sess, err := session.NewSession(awsConfig) - if err != nil { - return nil, err + + var client *sts.STS + + for _, cfg := range awsConfig { + sess, err := session.NewSession(cfg) + if err != nil { + return nil, err + } + client = sts.New(sess) + if client == nil { + return nil, fmt.Errorf("could not obtain sts client") + } + + // ping the client - we only care about errors + _, err = client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err == nil { + return client, nil + } else { + b.Logger().Debug("couldn't connect with config trying next", "failed endpoint", cfg.Endpoint, "failed region", cfg.Region) + } } - client := sts.New(sess) - if client == nil { - return nil, fmt.Errorf("could not obtain sts client") - } - return client, nil + + return nil, fmt.Errorf("could not obtain sts client") } // PluginIdentityTokenFetcher fetches plugin identity tokens from Vault. It is provided diff --git a/builtin/logical/aws/path_config_root.go b/builtin/logical/aws/path_config_root.go index 741c8502d0..84b2f92fa5 100644 --- a/builtin/logical/aws/path_config_root.go +++ b/builtin/logical/aws/path_config_root.go @@ -52,6 +52,14 @@ func pathConfigRoot(b *backend) *framework.Path { Type: framework.TypeString, Description: "Specific region for STS API calls.", }, + "sts_fallback_endpoints": { + Type: framework.TypeCommaStringSlice, + Description: "Fallback endpoints if sts_endpoint is unreachable", + }, + "sts_fallback_regions": { + Type: framework.TypeCommaStringSlice, + Description: "Fallback regions if sts_region is unreachable", + }, "max_retries": { Type: framework.TypeInt, Default: aws.UseServiceDefaultRetries, @@ -110,14 +118,16 @@ func (b *backend) pathConfigRootRead(ctx context.Context, req *logical.Request, } configData := map[string]interface{}{ - "access_key": config.AccessKey, - "region": config.Region, - "iam_endpoint": config.IAMEndpoint, - "sts_endpoint": config.STSEndpoint, - "sts_region": config.STSRegion, - "max_retries": config.MaxRetries, - "username_template": config.UsernameTemplate, - "role_arn": config.RoleARN, + "access_key": config.AccessKey, + "region": config.Region, + "iam_endpoint": config.IAMEndpoint, + "sts_endpoint": config.STSEndpoint, + "sts_region": config.STSRegion, + "sts_fallback_endpoints": config.STSFallbackEndpoints, + "sts_fallback_regions": config.STSFallbackRegions, + "max_retries": config.MaxRetries, + "username_template": config.UsernameTemplate, + "role_arn": config.RoleARN, } config.PopulatePluginIdentityTokenData(configData) @@ -138,19 +148,28 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, usernameTemplate = defaultUserNameTemplate } + stsFallbackEndpoints := data.Get("sts_fallback_endpoints").([]string) + stsFallbackRegions := data.Get("sts_fallback_regions").([]string) + + if len(stsFallbackEndpoints) != len(stsFallbackRegions) { + return logical.ErrorResponse("fallback endpoints and fallback regions must be the same length"), nil + } + b.clientMutex.Lock() defer b.clientMutex.Unlock() rc := rootConfig{ - AccessKey: data.Get("access_key").(string), - SecretKey: data.Get("secret_key").(string), - IAMEndpoint: iamendpoint, - STSEndpoint: stsendpoint, - STSRegion: stsregion, - Region: region, - MaxRetries: maxretries, - UsernameTemplate: usernameTemplate, - RoleARN: roleARN, + AccessKey: data.Get("access_key").(string), + SecretKey: data.Get("secret_key").(string), + IAMEndpoint: iamendpoint, + STSEndpoint: stsendpoint, + STSRegion: stsregion, + STSFallbackEndpoints: stsFallbackEndpoints, + STSFallbackRegions: stsFallbackRegions, + Region: region, + MaxRetries: maxretries, + UsernameTemplate: usernameTemplate, + RoleARN: roleARN, } if err := rc.ParsePluginIdentityTokenFields(data); err != nil { return logical.ErrorResponse(err.Error()), nil @@ -196,15 +215,17 @@ func (b *backend) pathConfigRootWrite(ctx context.Context, req *logical.Request, type rootConfig struct { pluginidentityutil.PluginIdentityTokenParams - AccessKey string `json:"access_key"` - SecretKey string `json:"secret_key"` - IAMEndpoint string `json:"iam_endpoint"` - STSEndpoint string `json:"sts_endpoint"` - STSRegion string `json:"sts_region"` - Region string `json:"region"` - MaxRetries int `json:"max_retries"` - UsernameTemplate string `json:"username_template"` - RoleARN string `json:"role_arn"` + AccessKey string `json:"access_key"` + SecretKey string `json:"secret_key"` + IAMEndpoint string `json:"iam_endpoint"` + STSEndpoint string `json:"sts_endpoint"` + STSRegion string `json:"sts_region"` + STSFallbackEndpoints []string `json:"sts_fallback_endpoints"` + STSFallbackRegions []string `json:"sts_fallback_regions"` + Region string `json:"region"` + MaxRetries int `json:"max_retries"` + UsernameTemplate string `json:"username_template"` + RoleARN string `json:"role_arn"` } const pathConfigRootHelpSyn = ` diff --git a/builtin/logical/aws/path_config_root_test.go b/builtin/logical/aws/path_config_root_test.go index 9c1ed0476f..1439a8b5ce 100644 --- a/builtin/logical/aws/path_config_root_test.go +++ b/builtin/logical/aws/path_config_root_test.go @@ -31,6 +31,8 @@ func TestBackend_PathConfigRoot(t *testing.T) { "iam_endpoint": "https://iam.amazonaws.com", "sts_endpoint": "https://sts.us-west-2.amazonaws.com", "sts_region": "", + "sts_fallback_endpoints": []string{}, + "sts_fallback_regions": []string{}, "max_retries": 10, "username_template": defaultUserNameTemplate, "role_arn": "", @@ -66,6 +68,152 @@ func TestBackend_PathConfigRoot(t *testing.T) { } } +// TestBackend_PathConfigRoot_STSFallback tests valid versions of STS fallback parameters - slice and csv +func TestBackend_PathConfigRoot_STSFallback(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = &testSystemView{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "access_key": "AKIAEXAMPLE", + "secret_key": "RandomData", + "region": "us-west-2", + "iam_endpoint": "https://iam.amazonaws.com", + "sts_endpoint": "https://sts.us-west-2.amazonaws.com", + "sts_region": "", + "sts_fallback_endpoints": []string{"192.168.1.1", "127.0.0.1"}, + "sts_fallback_regions": []string{"my-house-1", "my-house-2"}, + "max_retries": 10, + "username_template": defaultUserNameTemplate, + "role_arn": "", + "identity_token_audience": "", + "identity_token_ttl": int64(0), + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err := b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err) + } + + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Storage: config.StorageView, + Path: "config/root", + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err) + } + + delete(configData, "secret_key") + require.Equal(t, configData, resp.Data) + if !reflect.DeepEqual(resp.Data, configData) { + t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data) + } + + // test we can handle comma separated strings, per CommaStringSlice + configData = map[string]interface{}{ + "access_key": "AKIAEXAMPLE", + "secret_key": "RandomData", + "region": "us-west-2", + "iam_endpoint": "https://iam.amazonaws.com", + "sts_endpoint": "https://sts.us-west-2.amazonaws.com", + "sts_region": "", + "sts_fallback_endpoints": "1.1.1.1,8.8.8.8", + "sts_fallback_regions": "zone-1,zone-2", + "max_retries": 10, + "username_template": defaultUserNameTemplate, + "role_arn": "", + "identity_token_audience": "", + "identity_token_ttl": int64(0), + } + + configReq = &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config writing failed: resp:%#v\n err: %v", resp, err) + } + + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.ReadOperation, + Storage: config.StorageView, + Path: "config/root", + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: config reading failed: resp:%#v\n err: %v", resp, err) + } + + delete(configData, "secret_key") + configData["sts_fallback_endpoints"] = []string{"1.1.1.1", "8.8.8.8"} + configData["sts_fallback_regions"] = []string{"zone-1", "zone-2"} + require.Equal(t, configData, resp.Data) + if !reflect.DeepEqual(resp.Data, configData) { + t.Errorf("bad: expected to read config root as %#v, got %#v instead", configData, resp.Data) + } +} + +// TestBackend_PathConfigRoot_STSFallback_mismatchedfallback ensures configuration writing will fail if the +// region/endpoint entries are different lengths +func TestBackend_PathConfigRoot_STSFallback_mismatchedfallback(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = &testSystemView{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + + // test we can handle comma separated strings, per CommaStringSlice + configData := map[string]interface{}{ + "access_key": "AKIAEXAMPLE", + "secret_key": "RandomData", + "region": "us-west-2", + "iam_endpoint": "https://iam.amazonaws.com", + "sts_endpoint": "https://sts.us-west-2.amazonaws.com", + "sts_region": "", + "sts_fallback_endpoints": "1.1.1.1,8.8.8.8", + "sts_fallback_regions": "zone-1,zone-2", + "max_retries": 10, + "username_template": defaultUserNameTemplate, + "role_arn": "", + "identity_token_audience": "", + "identity_token_ttl": int64(0), + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Storage: config.StorageView, + Path: "config/root", + Data: configData, + } + + resp, err := b.HandleRequest(context.Background(), configReq) + if err != nil { + t.Fatalf("bad: config writing failed: err: %v", err) + } + if resp != nil && !resp.IsError() { + t.Fatalf("expected an error, but it successfully wrote") + } +} + // TestBackend_PathConfigRoot_PluginIdentityToken tests that configuration // of plugin WIF returns an immediate error. func TestBackend_PathConfigRoot_PluginIdentityToken(t *testing.T) { diff --git a/changelog/29051.txt b/changelog/29051.txt new file mode 100644 index 0000000000..13c42006d9 --- /dev/null +++ b/changelog/29051.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/aws: add fallback endpoint and region parameters to sts configuration +``` diff --git a/website/content/api-docs/secret/aws.mdx b/website/content/api-docs/secret/aws.mdx index f1c0959e3c..9b34a2a633 100644 --- a/website/content/api-docs/secret/aws.mdx +++ b/website/content/api-docs/secret/aws.mdx @@ -79,6 +79,12 @@ valid AWS credentials with proper permissions. - `sts_endpoint` `(string: )` – Specifies a custom HTTP STS endpoint to use. +- `sts_region` `(string: )` - Specifies a custom STS region to use (should match `sts_endpoint`) + +- `sts_fallback_endpoints` `(list: )` - Specifies an ordered list of fallback STS endpoints to use + +- `sts_fallback_regions` `(list: )` - Specifies an ordered list of fallback STS regions to use (should match fallback endpoints) + - `username_template` `(string: )` - [Template](/vault/docs/concepts/username-templating) describing how dynamic usernames are generated. The username template is used to generate both IAM usernames (capped at 64 characters) and STS usernames (capped at 32 characters). Longer usernames result in a 500 error.