diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index 2f3e8b3ad1..2bf1bfd21c 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1715,7 +1715,6 @@ func (i *IdentityStore) oidcKeyRotation(ctx context.Context, s logical.Storage) // oidcPeriodFunc is invoked by the backend's periodFunc and runs regular key // rotations and expiration actions. func (i *IdentityStore) oidcPeriodicFunc(ctx context.Context) { - i.Logger().Debug("begin oidcPeriodicFunc") var nextRun time.Time now := time.Now() diff --git a/vault/identity_store_oidc_provider.go b/vault/identity_store_oidc_provider.go index d1b98dd534..980ac8ff03 100644 --- a/vault/identity_store_oidc_provider.go +++ b/vault/identity_store_oidc_provider.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/identitytpl" "github.com/hashicorp/vault/sdk/logical" + "gopkg.in/square/go-jose.v2" ) type assignment struct { @@ -277,9 +278,167 @@ func oidcProviderPaths(i *IdentityStore) []*framework.Path { HelpSynopsis: "Query OIDC configurations", HelpDescription: "Query this path to retrieve the configured OIDC Issuer and Keys endpoints, response types, subject types, and signing algorithms used by the OIDC backend.", }, + { + Pattern: "oidc/provider/" + framework.GenericNameRegex("name") + "/.well-known/keys", + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the provider", + }, + }, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: i.pathOIDCReadProviderPublicKeys, + }, + HelpSynopsis: "Retrieve public keys", + HelpDescription: "Returns the public portion of keys for a named OIDC provider. Clients can use them to validate the authenticity of an ID token.", + }, } } +func (i *IdentityStore) listClients(ctx context.Context, s logical.Storage) ([]*client, error) { + clientNames, err := s.List(ctx, clientPath) + if err != nil { + return nil, err + } + + var clients []*client + for _, name := range clientNames { + entry, err := s.Get(ctx, clientPath+name) + if err != nil { + return nil, err + } + if entry == nil { + continue + } + + var client client + if err := entry.DecodeJSON(&client); err != nil { + return nil, err + } + clients = append(clients, &client) + } + + return clients, nil +} + +// TODO: load clients into memory (go-memdb) to look this up +func (i *IdentityStore) clientByID(ctx context.Context, s logical.Storage, id string) (*client, error) { + clients, err := i.listClients(ctx, s) + if err != nil { + return nil, err + } + + for _, client := range clients { + if client.ClientID == id { + return client, nil + } + } + + return nil, nil +} + +// keyIDsReferencedByTargetClientIDs returns a slice of key IDs that are +// referenced by the clients' targetIDs. +// If targetIDs contains "*" then the IDs for all public keys are returned. +func (i *IdentityStore) keyIDsReferencedByTargetClientIDs(ctx context.Context, s logical.Storage, targetIDs []string) ([]string, error) { + keyNames := make(map[string]bool) + + // Get all key names referenced by clients if wildcard "*" in target client IDs + if strutil.StrListContains(targetIDs, "*") { + clients, err := i.listClients(ctx, s) + if err != nil { + return nil, err + } + + for _, client := range clients { + keyNames[client.Key] = true + } + } + + // Otherwise, get the key names referenced by each target client ID + if len(keyNames) == 0 { + for _, clientID := range targetIDs { + client, err := i.clientByID(ctx, s, clientID) + if err != nil { + return nil, err + } + + if client != nil { + keyNames[client.Key] = true + } + } + } + + // Collect the key IDs + var keyIDs []string + for name, _ := range keyNames { + entry, err := s.Get(ctx, namedKeyConfigPath+name) + if err != nil { + return nil, err + } + + var key namedKey + if err := entry.DecodeJSON(&key); err != nil { + return nil, err + } + for _, expirableKey := range key.KeyRing { + keyIDs = append(keyIDs, expirableKey.KeyID) + } + } + return keyIDs, nil +} + +// pathOIDCReadProviderPublicKeys is used to retrieve all public keys for a +// named provider so that clients can verify the validity of a signed OIDC token. +func (i *IdentityStore) pathOIDCReadProviderPublicKeys(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + providerName := d.Get("name").(string) + + var provider provider + + providerEntry, err := req.Storage.Get(ctx, providerPath+providerName) + if err != nil { + return nil, err + } + if providerEntry == nil { + return nil, nil + } + if err := providerEntry.DecodeJSON(&provider); err != nil { + return nil, err + } + + keyIDs, err := i.keyIDsReferencedByTargetClientIDs(ctx, req.Storage, provider.AllowedClientIDs) + if err != nil { + return nil, err + } + + jwks := &jose.JSONWebKeySet{ + Keys: make([]jose.JSONWebKey, 0, len(keyIDs)), + } + + for _, keyID := range keyIDs { + key, err := loadOIDCPublicKey(ctx, req.Storage, keyID) + if err != nil { + return nil, err + } + jwks.Keys = append(jwks.Keys, *key) + } + + data, err := json.Marshal(jwks) + if err != nil { + return nil, err + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + logical.HTTPStatusCode: 200, + logical.HTTPRawBody: data, + logical.HTTPContentType: "application/json", + }, + } + + return resp, nil +} + func (i *IdentityStore) pathOIDCProviderDiscovery(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) diff --git a/vault/identity_store_oidc_provider_test.go b/vault/identity_store_oidc_provider_test.go index 0ed35e567a..68d99e9ae4 100644 --- a/vault/identity_store_oidc_provider_test.go +++ b/vault/identity_store_oidc_provider_test.go @@ -10,8 +10,151 @@ import ( "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/logical" + "gopkg.in/square/go-jose.v2" ) +// TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist tests that the +// path can handle the read operation when the provider does not exist +func TestOIDC_Path_OIDC_ProviderReadPublicKey_ProviderDoesNotExist(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + storage := &logical.InmemStorage{} + + // Read "test-provider" .well-known keys + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider/.well-known/keys", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectedResp := &logical.Response{} + if resp != expectedResp && err != nil { + t.Fatalf("expected empty response but got success; error:\n%v\nresp: %#v", err, resp) + } +} + +// TestOIDC_Path_OIDC_ProviderReadPublicKey tests the provider .well-known +// keys endpoint read operations +func TestOIDC_Path_OIDC_ProviderReadPublicKey(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + ctx := namespace.RootContext(nil) + storage := &logical.InmemStorage{} + + // Create a test key "test-key-1" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key-1", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "verification_ttl": "2m", + "rotation_period": "2m", + }, + Storage: storage, + }) + + // Create a test client "test-client-1" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client-1", + Operation: logical.CreateOperation, + Storage: storage, + Data: map[string]interface{}{ + "key": "test-key-1", + }, + }) + + // get the clientID + resp, _ := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client-1", + Operation: logical.ReadOperation, + Storage: storage, + }) + clientID := resp.Data["client_id"].(string) + + // Create a test provider "test-provider" and allow all client IDs -- should succeed + resp, err := c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider", + Operation: logical.CreateOperation, + Storage: storage, + Data: map[string]interface{}{ + "issuer": "https://example.com:8200", + "allowed_client_ids": []string{"*"}, + }, + }) + expectSuccess(t, resp, err) + + // Read "test-provider" .well-known keys + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider/.well-known/keys", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + + responseJWKS := &jose.JSONWebKeySet{} + json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public key but instead got %d", len(responseJWKS.Keys)) + } + + // Create a test key "test-key-2" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/key/test-key-2", + Operation: logical.CreateOperation, + Data: map[string]interface{}{ + "verification_ttl": "2m", + "rotation_period": "2m", + }, + Storage: storage, + }) + + // Create a test client "test-client-2" + c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/client/test-client-2", + Operation: logical.CreateOperation, + Storage: storage, + Data: map[string]interface{}{ + "key": "test-key-2", + }, + }) + + // Read "test-provider" .well-known keys + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider/.well-known/keys", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + + responseJWKS = &jose.JSONWebKeySet{} + json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) + if len(responseJWKS.Keys) != 4 { + t.Fatalf("expected 4 public key but instead got %d", len(responseJWKS.Keys)) + } + + // Update the test provider "test-provider" to only allow test-client-1 -- should succeed + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider", + Operation: logical.UpdateOperation, + Storage: storage, + Data: map[string]interface{}{ + "allowed_client_ids": []string{clientID}, + }, + }) + expectSuccess(t, resp, err) + + // Read "test-provider" .well-known keys + resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ + Path: "oidc/provider/test-provider/.well-known/keys", + Operation: logical.ReadOperation, + Storage: storage, + }) + expectSuccess(t, resp, err) + + responseJWKS = &jose.JSONWebKeySet{} + json.Unmarshal(resp.Data["http_raw_body"].([]byte), responseJWKS) + if len(responseJWKS.Keys) != 2 { + t.Fatalf("expected 2 public key but instead got %d", len(responseJWKS.Keys)) + } +} + // TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter tests that a client cannot // be created without a key parameter func TestOIDC_Path_OIDC_ProviderClient_NoKeyParameter(t *testing.T) { @@ -97,7 +240,7 @@ func TestOIDC_Path_OIDC_ProviderClient_UpdateKey(t *testing.T) { }) expectSuccess(t, resp, err) - // Create a test client "test-client" -- should fail + // Update the test client "test-client" -- should fail resp, err = c.identityStore.HandleRequest(ctx, &logical.Request{ Path: "oidc/client/test-client", Operation: logical.UpdateOperation,