diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index 01920dd0c1..2399ec7a62 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -73,6 +73,8 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) b.pathTrim(), b.pathCacheConfig(), b.pathConfigKeys(), + b.pathCreateCsr(), + b.pathImportCertChain(), }, Secrets: []*framework.Secret{}, diff --git a/builtin/logical/transit/path_certificates.go b/builtin/logical/transit/path_certificates.go new file mode 100644 index 0000000000..59871cd337 --- /dev/null +++ b/builtin/logical/transit/path_certificates.go @@ -0,0 +1,300 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package transit + +import ( + "context" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "strings" + + "github.com/hashicorp/vault/sdk/helper/errutil" + + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/keysutil" + "github.com/hashicorp/vault/sdk/logical" +) + +func (b *backend) pathCreateCsr() *framework.Path { + return &framework.Path{ + Pattern: "keys/" + framework.GenericNameRegex("name") + "/csr", + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Required: true, + Description: "Name of the key", + }, + "version": { + Type: framework.TypeInt, + Required: false, + Description: "Optional version of key, 'latest' if not set", + }, + "csr": { + Type: framework.TypeString, + Required: false, + Description: `PEM encoded CSR template. The information attributes +will be used as a basis for the CSR with the key in transit. If not set, an empty CSR is returned.`, + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + // NOTE: Create and Update? + logical.CreateOperation: &framework.PathOperation{ + Callback: b.pathCreateCsrWrite, + DisplayAttrs: &framework.DisplayAttributes{ + OperationVerb: "create", + }, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathCreateCsrWrite, + DisplayAttrs: &framework.DisplayAttributes{ + OperationVerb: "update", + }, + }, + }, + HelpSynopsis: pathCreateCsrHelpSyn, + HelpDescription: pathCreateCsrHelpDesc, + } +} + +func (b *backend) pathImportCertChain() *framework.Path { + return &framework.Path{ + // NOTE: `set-certificate` or `set_certificate`? Paths seem to use different + // case, such as `transit/wrapping_key` and `transit/cache-config`. + Pattern: "keys/" + framework.GenericNameRegex("name") + "/set-certificate", + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Required: true, + Description: "Name of the key", + }, + "version": { + Type: framework.TypeInt, + Required: false, + Description: "Optional version of key, 'latest' if not set", + }, + "certificate_chain": { + Type: framework.TypeString, + Required: true, + Description: `PEM encoded certificate chain. It should be composed +by one or more concatenated PEM blocks and ordered starting from the end-entity certificate.`, + }, + }, + Operations: map[logical.Operation]framework.OperationHandler{ + // NOTE: Create and Update? + logical.CreateOperation: &framework.PathOperation{ + Callback: b.pathImportCertChainWrite, + DisplayAttrs: &framework.DisplayAttributes{ + OperationVerb: "create", + }, + }, + logical.UpdateOperation: &framework.PathOperation{ + Callback: b.pathImportCertChainWrite, + DisplayAttrs: &framework.DisplayAttributes{ + OperationVerb: "update", + }, + }, + }, + HelpSynopsis: pathImportCertChainHelpSyn, + HelpDescription: pathImportCertChainHelpDesc, + } +} + +func (b *backend) pathCreateCsrWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ + Storage: req.Storage, + Name: name, + }, b.GetRandomReader()) + if err != nil { + return nil, err + } + if p == nil { + return logical.ErrorResponse(fmt.Sprintf("key with provided name '%s' not found", name)), logical.ErrInvalidRequest + } + if !b.System().CachingDisabled() { + p.Lock(false) // NOTE: No lock on "read" operations? + } + defer p.Unlock() + + // Check if transit key supports signing + if !p.Type.SigningSupported() { + return logical.ErrorResponse(fmt.Sprintf("key type '%s' does not support signing", p.Type)), logical.ErrInvalidRequest + } + + // Check if key can be derived + if p.Derived { + return logical.ErrorResponse("operation not supported on keys with derivation enabled"), logical.ErrInvalidRequest + } + + // Transit key version + signingKeyVersion := p.LatestVersion + // NOTE: BYOK endpoints seem to remove "v" prefix from version, + // are versions like that also supported? + if version, ok := d.GetOk("version"); ok { + signingKeyVersion = version.(int) + } + + // Read and parse CSR template + pemCsrTemplate := d.Get("csr").(string) + csrTemplate, err := parseCsr(pemCsrTemplate) + if err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + + pemCsr, err := p.CreateCsr(signingKeyVersion, csrTemplate) + if err != nil { + prefixedErr := fmt.Errorf("could not create the csr: %w", err) + switch err.(type) { + case errutil.UserError: + return logical.ErrorResponse(prefixedErr.Error()), logical.ErrInvalidRequest + default: + return nil, prefixedErr + } + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "type": p.Type.String(), + "csr": string(pemCsr), + }, + } + + return resp, nil +} + +func (b *backend) pathImportCertChainWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ + Storage: req.Storage, + Name: name, + }, b.GetRandomReader()) + if err != nil { + return nil, err + } + if p == nil { + return logical.ErrorResponse(fmt.Sprintf("key with provided name '%s' not found", name)), logical.ErrInvalidRequest + } + if !b.System().CachingDisabled() { + p.Lock(true) // NOTE: Lock as we are might write to the policy + } + defer p.Unlock() + + // Check if transit key supports signing + if !p.Type.SigningSupported() { + return logical.ErrorResponse(fmt.Sprintf("key type %s does not support signing", p.Type)), logical.ErrInvalidRequest + } + + // Check if key can be derived + if p.Derived { + return logical.ErrorResponse("operation not supported on keys with derivation enabled"), logical.ErrInvalidRequest + } + + // Transit key version + keyVersion := p.LatestVersion + if version, ok := d.GetOk("version"); ok { + keyVersion = version.(int) + } + + // Get certificate chain + pemCertChain := d.Get("certificate_chain").(string) + certChain, err := parseCertificateChain(pemCertChain) + if err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + + err = p.ValidateAndPersistCertificateChain(ctx, keyVersion, certChain, req.Storage) + if err != nil { + prefixedErr := fmt.Errorf("failed to persist certificate chain: %w", err) + switch err.(type) { + case errutil.UserError: + return logical.ErrorResponse(prefixedErr.Error()), logical.ErrInvalidRequest + default: + return nil, prefixedErr + } + } + + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "type": p.Type.String(), + "certificate-chain": pemCertChain, + }, + } + + return resp, nil +} + +func parseCsr(csrStr string) (*x509.CertificateRequest, error) { + if csrStr == "" { + return &x509.CertificateRequest{}, nil + } + + block, _ := pem.Decode([]byte(csrStr)) + if block == nil { + return nil, errors.New("could not decode PEM certificate request") + } + + csr, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return nil, err + } + + return csr, nil +} + +func parseCertificateChain(certChainString string) ([]*x509.Certificate, error) { + var certificates []*x509.Certificate + + var pemCertBlocks []*pem.Block + pemBytes := []byte(strings.TrimSpace(certChainString)) + for len(pemBytes) > 0 { + var pemCertBlock *pem.Block + pemCertBlock, pemBytes = pem.Decode(pemBytes) + if pemCertBlock == nil { + return nil, errors.New("could not decode PEM block in certificate chain") + } + + switch pemCertBlock.Type { + case "CERTIFICATE", "X05 CERTIFICATE": + pemCertBlocks = append(pemCertBlocks, pemCertBlock) + default: + // Ignore any other entries + } + } + + if len(pemCertBlocks) == 0 { + return nil, errors.New("provided certificate chain did not contain any valid PEM certificate") + } + + for _, certBlock := range pemCertBlocks { + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate in certificate chain: %w", err) + } + + certificates = append(certificates, cert) + } + + return certificates, nil +} + +const pathCreateCsrHelpSyn = `Create a CSR from a key in transit` + +const pathCreateCsrHelpDesc = `This path is used to create a CSR from a key in +transit. If a CSR template is provided, its significant information, expect key +related data, are included in the CSR otherwise an empty CSR is returned. +` + +const pathImportCertChainHelpSyn = `Imports an externally-signed certificate +chain into an existing key version` + +const pathImportCertChainHelpDesc = `This path is used to import an externally- +signed certificate chain into a key in transit. The leaf certificate key has to +match the selected key in transit. +` diff --git a/builtin/logical/transit/path_certificates_test.go b/builtin/logical/transit/path_certificates_test.go new file mode 100644 index 0000000000..13598febde --- /dev/null +++ b/builtin/logical/transit/path_certificates_test.go @@ -0,0 +1,260 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package transit + +import ( + "context" + cryptoRand "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/pki" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/require" +) + +func TestTransit_Certs_CreateCsr(t *testing.T) { + // NOTE: Use an existing CSR or generate one here? + templateCsr := ` +-----BEGIN CERTIFICATE REQUEST----- +MIICRTCCAS0CAQAwADCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAM49 +McW7u3ILuAJfSFLUtGOMGBytHmMFcjTiX+5JcajFj0Uszb+HQ7eIsJJNXhVc/7fg +Z01DZvcCqb9ChEWE3xi4GEkPMXay7p7G1ooSLnQp6Z0lL5CuIFfMVOTvjfhTwRaJ +l9v2mMlm80BeiAUBqeoyGVrIh5fKASxaE0jrhjAxhGzqrXdDnL8A4na6ArprV4iS +aEAziODd2WmplSKgUwEaFdeG1t1bJf3o5ZQRCnKNtQcAk8UmgtvFEO8ohGMln/Fj +O7u7s6iRhOGf1g1NCAP5pGqxNx3bjz5f/CUcTSIGAReEomg41QTIhD9muCTL8qnm +6lS87wkGTv7qbeIGB7sCAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQAfjE+jNqIk +4V1tL3g5XPjxr2+QcwddPf8opmbAzgt0+TiIHcDGBAxsXyi7sC9E5AFfFp7W07Zv +r5+v4i529K9q0BgGtHFswoEnhd4dC8Ye53HtSoEtXkBpZMDrtbS7eZa9WccT6zNx +4taTkpptZVrmvPj+jLLFkpKJJ3d+Gbrp6hiORPadT+igLKkqvTeocnhOdAtt427M +RXTVgN14pV3tqO+5MXzNw5tGNPcwWARWwPH9eCRxLwLUuxE4Qu73pUeEFjDEfGkN +iBnlTsTXBOMqSGryEkmRaZslWDvblvYeObYw+uc3kCbJ7jRy9soVwkbb5FueF/yC +O1aQIm23HrrG +-----END CERTIFICATE REQUEST----- +` + + testTransit_CreateCsr(t, "rsa-2048", templateCsr) + testTransit_CreateCsr(t, "rsa-3072", templateCsr) + testTransit_CreateCsr(t, "rsa-4096", templateCsr) + testTransit_CreateCsr(t, "ecdsa-p256", templateCsr) + testTransit_CreateCsr(t, "ecdsa-p384", templateCsr) + testTransit_CreateCsr(t, "ecdsa-p521", templateCsr) + testTransit_CreateCsr(t, "ed25519", templateCsr) + testTransit_CreateCsr(t, "aes256-gcm96", templateCsr) +} + +func testTransit_CreateCsr(t *testing.T, keyType, pemTemplateCsr string) { + var resp *logical.Response + var err error + b, s := createBackendWithStorage(t) + + // Create the policy + policyReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "keys/test-key", + Storage: s, + Data: map[string]interface{}{ + "type": keyType, + }, + } + resp, err = b.HandleRequest(context.Background(), policyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("resp: %#v\nerr: %v", resp, err) + } + + csrSignReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "keys/test-key/csr", + Storage: s, + Data: map[string]interface{}{ + "csr": pemTemplateCsr, + }, + } + + resp, err = b.HandleRequest(context.Background(), csrSignReq) + + switch keyType { + case "rsa-2048", "rsa-3072", "rsa-4096", "ecdsa-p256", "ecdsa-p384", "ecdsa-p521", "ed25519": + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("failed to sign CSR, err:%v resp:%#v", err, resp) + } + + signedCsrBytes, ok := resp.Data["csr"] + if !ok { + t.Fatal("expected response data to hold a 'csr' field") + } + + signedCsr, err := parseCsr(signedCsrBytes.(string)) + if err != nil { + t.Fatalf("failed to parse returned csr, err:%v", err) + } + + templateCsr, err := parseCsr(pemTemplateCsr) + if err != nil { + t.Fatalf("failed to parse returned template csr, err:%v", err) + } + + // NOTE: Check other fields? + if !reflect.DeepEqual(signedCsr.Subject, templateCsr.Subject) { + t.Fatalf("subjects should have matched, err:%v", err) + } + + default: + if err == nil || (resp != nil && !resp.IsError()) { + t.Fatalf("should have failed to sign CSR, provided key type does not support signing") + } + } +} + +// NOTE: Tests are using two 'different' methods of checking for errors, which one sould we prefer? +func TestTransit_Certs_ImportCertChain(t *testing.T) { + // Create Cluster + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": Factory, + "pki": pki.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + client := cores[0].Client + + // Mount transit backend + err := client.Sys().Mount("transit", &api.MountInput{ + Type: "transit", + }) + require.NoError(t, err) + + // Mount PKI backend + err = client.Sys().Mount("pki", &api.MountInput{ + Type: "pki", + }) + require.NoError(t, err) + + testTransit_ImportCertChain(t, client, "rsa-2048") + testTransit_ImportCertChain(t, client, "rsa-3072") + testTransit_ImportCertChain(t, client, "rsa-4096") + testTransit_ImportCertChain(t, client, "ecdsa-p256") + testTransit_ImportCertChain(t, client, "ecdsa-p384") + testTransit_ImportCertChain(t, client, "ecdsa-p521") + testTransit_ImportCertChain(t, client, "ed25519") +} + +func testTransit_ImportCertChain(t *testing.T, apiClient *api.Client, keyType string) { + keyName := fmt.Sprintf("%s", keyType) + issuerName := fmt.Sprintf("%s-issuer", keyType) + + // Create transit key + _, err := apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{ + "type": keyType, + }) + require.NoError(t, err) + + // Setup a new CSR + privKey, err := rsa.GenerateKey(cryptoRand.Reader, 3072) + require.NoError(t, err) + + var csrTemplate x509.CertificateRequest + csrTemplate.Subject.CommonName = "example.com" + reqCsrBytes, err := x509.CreateCertificateRequest(cryptoRand.Reader, &csrTemplate, privKey) + require.NoError(t, err) + + pemTemplateCsr := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: reqCsrBytes, + }) + t.Logf("csr: %v", string(pemTemplateCsr)) + + // Create CSR from template CSR fields and key in transit + resp, err := apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/csr", keyName), map[string]interface{}{ + "csr": string(pemTemplateCsr), + }) + require.NoError(t, err) + require.NotNil(t, resp) + pemCsr := resp.Data["csr"].(string) + + // Generate PKI root + resp, err = apiClient.Logical().Write("pki/root/generate/internal", map[string]interface{}{ + "issuer_name": issuerName, + "common_name": "PKI Root X1", + }) + require.NoError(t, err) + require.NotNil(t, resp) + + rootCertPEM := resp.Data["certificate"].(string) + pemBlock, _ := pem.Decode([]byte(rootCertPEM)) + require.NotNil(t, pemBlock) + + rootCert, err := x509.ParseCertificate(pemBlock.Bytes) + require.NoError(t, err) + + // Create role to be used in the certificate issuing + resp, err = apiClient.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{ + "issuer_ref": issuerName, + "allowed_domains": "example.com", + "allow_bare_domains": true, + "basic_constraints_valid_for_non_ca": true, + "key_type": "any", + }) + require.NoError(t, err) + + // Sign the CSR + resp, err = apiClient.Logical().Write("pki/sign/example-dot-com", map[string]interface{}{ + "issuer_ref": issuerName, + "csr": pemCsr, + "ttl": "10m", + }) + require.NoError(t, err) + require.NotNil(t, resp) + + leafCertPEM := resp.Data["certificate"].(string) + pemBlock, _ = pem.Decode([]byte(leafCertPEM)) + require.NotNil(t, pemBlock) + + leafCert, err := x509.ParseCertificate(pemBlock.Bytes) + require.NoError(t, err) + + require.NoError(t, leafCert.CheckSignatureFrom(rootCert)) + t.Logf("root: %v", rootCertPEM) + t.Logf("leaf: %v", leafCertPEM) + + certificateChain := strings.Join([]string{leafCertPEM, rootCertPEM}, "\n") + // Import certificate chain to transit key version + resp, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/set-certificate", keyName), map[string]interface{}{ + "certificate_chain": certificateChain, + }) + require.NoError(t, err) + require.NotNil(t, resp) + + resp, err = apiClient.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName)) + require.NotNil(t, resp) + keys, ok := resp.Data["keys"].(map[string]interface{}) + if !ok { + t.Fatalf("could not cast Keys value") + } + keyData, ok := keys["1"].(map[string]interface{}) + if !ok { + t.Fatalf("could not cast key version 1 from keys") + } + _, present := keyData["certificate_chain"] + if !present { + t.Fatalf("certificate chain not present in key version 1") + } +} diff --git a/builtin/logical/transit/path_export.go b/builtin/logical/transit/path_export.go index 250188dd92..afdb6d63cd 100644 --- a/builtin/logical/transit/path_export.go +++ b/builtin/logical/transit/path_export.go @@ -21,10 +21,11 @@ import ( ) const ( - exportTypeEncryptionKey = "encryption-key" - exportTypeSigningKey = "signing-key" - exportTypeHMACKey = "hmac-key" - exportTypePublicKey = "public-key" + exportTypeEncryptionKey = "encryption-key" + exportTypeSigningKey = "signing-key" + exportTypeHMACKey = "hmac-key" + exportTypePublicKey = "public-key" + exportTypeCertificateChain = "certificate-chain" ) func (b *backend) pathExportKeys() *framework.Path { @@ -71,6 +72,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request case exportTypeSigningKey: case exportTypeHMACKey: case exportTypePublicKey: + case exportTypeCertificateChain: default: return logical.ErrorResponse(fmt.Sprintf("invalid export type: %s", exportType)), logical.ErrInvalidRequest } @@ -90,7 +92,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request } defer p.Unlock() - if !p.Exportable && exportType != exportTypePublicKey { + if !p.Exportable && exportType != exportTypePublicKey && exportType != exportTypeCertificateChain { return logical.ErrorResponse("private key material is not exportable"), nil } @@ -103,6 +105,10 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request if !p.Type.SigningSupported() { return logical.ErrorResponse("signing not supported for the key"), logical.ErrInvalidRequest } + case exportTypeCertificateChain: + if !p.Type.SigningSupported() { + return logical.ErrorResponse("certificate chain not supported for keys that do not support signing"), logical.ErrInvalidRequest + } } retKeys := map[string]string{} @@ -241,6 +247,23 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st } return rsaKey, nil } + case exportTypeCertificateChain: + if key.CertificateChain == nil { + return "", errors.New("selected key version does not have a certificate chain imported") + } + + var pemCerts []string + for _, derCertBytes := range key.CertificateChain { + pemCert := strings.TrimSpace(string(pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: derCertBytes, + }))) + pemCerts = append(pemCerts, pemCert) + } + certChain := strings.Join(pemCerts, "\n") + + return certChain, nil } return "", fmt.Errorf("unknown key type %v for export type %v", policy.Type, exportType) diff --git a/builtin/logical/transit/path_export_test.go b/builtin/logical/transit/path_export_test.go index a8d4621d9d..7b0d6feb54 100644 --- a/builtin/logical/transit/path_export_test.go +++ b/builtin/logical/transit/path_export_test.go @@ -5,12 +5,23 @@ package transit import ( "context" + cryptoRand "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" "reflect" "strconv" + "strings" "testing" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/pki" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/vault" + "github.com/hashicorp/vault/sdk/logical" + "github.com/stretchr/testify/require" ) func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) { @@ -381,3 +392,134 @@ func TestTransit_Export_EncryptionKey_DoesNotExportHMACKey(t *testing.T) { t.Fatal("Encryption key data matched hmac key data") } } + +func TestTransit_Export_CertificateChain(t *testing.T) { + generateKeys(t) + + // Create Cluster + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "transit": Factory, + "pki": pki.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + cluster.Start() + defer cluster.Cleanup() + + cores := cluster.Cores + vault.TestWaitActive(t, cores[0].Core) + client := cores[0].Client + + // Mount transit backend + err := client.Sys().Mount("transit", &api.MountInput{ + Type: "transit", + }) + require.NoError(t, err) + + // Mount PKI backend + err = client.Sys().Mount("pki", &api.MountInput{ + Type: "pki", + }) + require.NoError(t, err) + + testTransit_exportCertificateChain(t, client, "rsa-2048") + testTransit_exportCertificateChain(t, client, "rsa-3072") + testTransit_exportCertificateChain(t, client, "rsa-4096") + testTransit_exportCertificateChain(t, client, "ecdsa-p256") + testTransit_exportCertificateChain(t, client, "ecdsa-p384") + testTransit_exportCertificateChain(t, client, "ecdsa-p521") + testTransit_exportCertificateChain(t, client, "ed25519") +} + +func testTransit_exportCertificateChain(t *testing.T, apiClient *api.Client, keyType string) { + keyName := fmt.Sprintf("%s", keyType) + issuerName := fmt.Sprintf("%s-issuer", keyType) + + // Get key to be imported + privKey := getKey(t, keyType) + privKeyBytes, err := x509.MarshalPKCS8PrivateKey(privKey) + require.NoError(t, err, fmt.Sprintf("failed to marshal private key: %s", err)) + + // Create CSR + var csrTemplate x509.CertificateRequest + csrTemplate.Subject.CommonName = "example.com" + csrBytes, err := x509.CreateCertificateRequest(cryptoRand.Reader, &csrTemplate, privKey) + require.NoError(t, err, fmt.Sprintf("failed to create CSR: %s", err)) + + pemCsr := string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + })) + + // Generate PKI root + _, err = apiClient.Logical().Write("pki/root/generate/internal", map[string]interface{}{ + "issuer_name": issuerName, + "common_name": "PKI Root X1", + }) + require.NoError(t, err) + + // Create role to be used in the certificate issuing + _, err = apiClient.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{ + "issuer_ref": issuerName, + "allowed_domains": "example.com", + "allow_bare_domains": true, + "basic_constraints_valid_for_non_ca": true, + "key_type": "any", + }) + require.NoError(t, err) + + // Sign the CSR + resp, err := apiClient.Logical().Write("pki/sign/example-dot-com", map[string]interface{}{ + "issuer_ref": issuerName, + "csr": pemCsr, + "ttl": "10m", + }) + require.NoError(t, err) + require.NotNil(t, resp) + + leafCertPEM := resp.Data["certificate"].(string) + + // Get wrapping key + resp, err = apiClient.Logical().Read("transit/wrapping_key") + require.NoError(t, err) + require.NotNil(t, resp) + + pubWrappingKeyString := strings.TrimSpace(resp.Data["public_key"].(string)) + wrappingKeyPemBlock, _ := pem.Decode([]byte(pubWrappingKeyString)) + + pubWrappingKey, err := x509.ParsePKIXPublicKey(wrappingKeyPemBlock.Bytes) + require.NoError(t, err, "failed to parse wrapping key") + + blob := wrapTargetPKCS8ForImport(t, pubWrappingKey.(*rsa.PublicKey), privKeyBytes, "SHA256") + + // Import key + _, err = apiClient.Logical().Write(fmt.Sprintf("/transit/keys/%s/import", keyName), map[string]interface{}{ + "ciphertext": blob, + "type": keyType, + }) + require.NoError(t, err) + + // NOTE: Should it be possible for the "import" endpoint to also import the cert chain? + // Import cert chain + _, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/set-certificate", keyName), map[string]interface{}{ + "certificate_chain": leafCertPEM, + }) + require.NoError(t, err) + + // Export cert chain + resp, err = apiClient.Logical().Read(fmt.Sprintf("transit/export/certificate-chain/%s", keyName)) + require.NoError(t, err) + require.NotNil(t, resp) + + exportedKeys := resp.Data["keys"].(map[string]interface{}) + exportedCertChainPEM := exportedKeys["1"].(string) + + if exportedCertChainPEM != leafCertPEM { + t.Fatalf("expected exported cert chain to match with imported value") + } +} diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 488be6b2f5..2ba3a7d7f9 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -12,6 +12,7 @@ import ( "encoding/pem" "fmt" "strconv" + "strings" "time" "golang.org/x/crypto/ed25519" @@ -269,9 +270,10 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d * // Built-in helper type for returning asymmetric keys type asymKey struct { - Name string `json:"name" structs:"name" mapstructure:"name"` - PublicKey string `json:"public_key" structs:"public_key" mapstructure:"public_key"` - CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"creation_time"` + Name string `json:"name" structs:"name" mapstructure:"name"` + PublicKey string `json:"public_key" structs:"public_key" mapstructure:"public_key"` + CertificateChain string `json:"certificate_chain" structs:"certificate_chain" mapstructure:"certificate_chain"` + CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"creation_time"` } func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -379,6 +381,18 @@ func (b *backend) formatKeyPolicy(p *keysutil.Policy, context []byte) (*logical. if key.CreationTime.IsZero() { key.CreationTime = time.Unix(v.DeprecatedCreationTime, 0) } + if v.CertificateChain != nil { + var pemCerts []string + for _, derCertBytes := range v.CertificateChain { + pemCert := strings.TrimSpace(string(pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: derCertBytes, + }))) + pemCerts = append(pemCerts, pemCert) + } + key.CertificateChain = strings.Join(pemCerts, "\n") + } switch p.Type { case keysutil.KeyType_ECDSA_P256: diff --git a/changelog/21081.txt b/changelog/21081.txt new file mode 100644 index 0000000000..ecf1713eb6 --- /dev/null +++ b/changelog/21081.txt @@ -0,0 +1,3 @@ +```release-note:improvement +secrets/transit: Add support to create CSRs from keys in transit engine and import/export x509 certificates +``` diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 869733b3e9..6c35c9ac02 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -244,6 +244,10 @@ type KeyEntry struct { DeprecatedCreationTime int64 `json:"creation_time"` ManagedKeyUUID string `json:"managed_key_id,omitempty"` + + // Key entry certificate chain. If set, leaf certificate key matches the + // KeyEntry key + CertificateChain [][]byte `json:"certificate_chain"` } func (ke *KeyEntry) IsPrivateKeyMissing() bool { @@ -2393,3 +2397,195 @@ func wrapTargetPKCS8ForImport(wrappingKey *rsa.PublicKey, preppedTargetKey []byt wrappedKeys := append(ephKeyWrapped, targetKeyWrapped...) return base64.StdEncoding.EncodeToString(wrappedKeys), nil } + +func (p *Policy) CreateCsr(keyVersion int, csrTemplate *x509.CertificateRequest) ([]byte, error) { + if !p.Type.SigningSupported() { + return nil, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)} + } + + keyEntry, err := p.safeGetKeyEntry(keyVersion) + if err != nil { + return nil, err + } + + if keyEntry.IsPrivateKeyMissing() { + return nil, errutil.UserError{Err: "private key not imported for key version selected"} + } + + csrTemplate.Signature = nil + csrTemplate.SignatureAlgorithm = x509.UnknownSignatureAlgorithm + + var key crypto.Signer + switch p.Type { + case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521: + var curve elliptic.Curve + switch p.Type { + case KeyType_ECDSA_P384: + curve = elliptic.P384() + case KeyType_ECDSA_P521: + curve = elliptic.P521() + default: + curve = elliptic.P256() + } + + key = &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: keyEntry.EC_X, + Y: keyEntry.EC_Y, + }, + D: keyEntry.EC_D, + } + + case KeyType_ED25519: + if p.Derived { + return nil, errutil.UserError{Err: "operation not supported on keys with derivation enabled"} + } + key = ed25519.PrivateKey(keyEntry.Key) + + case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: + key = keyEntry.RSAKey + + default: + return nil, errutil.InternalError{Err: fmt.Sprintf("selected key type '%s' does not support signing", p.Type.String())} + } + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key) + if err != nil { + return nil, fmt.Errorf("could not create the cerfificate request: %w", err) + } + + pemCsr := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + return pemCsr, nil +} + +func (p *Policy) ValidateLeafCertKeyMatch(keyVersion int, certPublicKeyAlgorithm x509.PublicKeyAlgorithm, certPublicKey any) (bool, error) { + if !p.Type.SigningSupported() { + return false, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)} + } + + var keyTypeMatches bool + switch p.Type { + case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521: + if certPublicKeyAlgorithm == x509.ECDSA { + keyTypeMatches = true + } + case KeyType_ED25519: + if certPublicKeyAlgorithm == x509.Ed25519 { + keyTypeMatches = true + } + case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: + if certPublicKeyAlgorithm == x509.RSA { + keyTypeMatches = true + } + } + if !keyTypeMatches { + return false, errutil.UserError{Err: fmt.Sprintf("provided leaf certificate public key algorithm '%s' does not match the transit key type '%s'", + certPublicKeyAlgorithm, p.Type)} + } + + keyEntry, err := p.safeGetKeyEntry(keyVersion) + if err != nil { + return false, err + } + + switch certPublicKeyAlgorithm { + case x509.ECDSA: + certPublicKey := certPublicKey.(*ecdsa.PublicKey) + var curve elliptic.Curve + switch p.Type { + case KeyType_ECDSA_P384: + curve = elliptic.P384() + case KeyType_ECDSA_P521: + curve = elliptic.P521() + default: + curve = elliptic.P256() + } + + publicKey := &ecdsa.PublicKey{ + Curve: curve, + X: keyEntry.EC_X, + Y: keyEntry.EC_Y, + } + + // NOTE: Is it worth having this check? + if publicKey.Curve != certPublicKey.Curve { + return false, nil + } + + return publicKey.Equal(certPublicKey), nil + + case x509.Ed25519: + if p.Derived { + return false, errutil.UserError{Err: "operation not supported on keys with derivation enabled"} + } + certPublicKey := certPublicKey.(ed25519.PublicKey) + + raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey) + if err != nil { + return false, err + } + publicKey := ed25519.PublicKey(raw) + + return publicKey.Equal(certPublicKey), nil + + case x509.RSA: + certPublicKey := certPublicKey.(*rsa.PublicKey) + publicKey := keyEntry.RSAKey.PublicKey + return publicKey.Equal(certPublicKey), nil + + case x509.UnknownPublicKeyAlgorithm: + return false, errutil.InternalError{Err: fmt.Sprint("certificate signed with an unknown algorithm")} + } + + return false, nil +} + +func (p *Policy) ValidateAndPersistCertificateChain(ctx context.Context, keyVersion int, certChain []*x509.Certificate, storage logical.Storage) error { + if len(certChain) == 0 { + return errutil.UserError{Err: "expected at least one certificate in the parsed certificate chain"} + } + + if certChain[0].BasicConstraintsValid && certChain[0].IsCA { + return errutil.UserError{Err: "certificate in the first position is not a leaf certificate"} + } + + for _, cert := range certChain[1:] { + if cert.BasicConstraintsValid && !cert.IsCA { + return errutil.UserError{Err: "provided certificate chain contains more than one leaf certificate"} + } + } + + valid, err := p.ValidateLeafCertKeyMatch(keyVersion, certChain[0].PublicKeyAlgorithm, certChain[0].PublicKey) + if err != nil { + prefixedErr := fmt.Errorf("could not validate key match between leaf certificate key and key version in transit: %w", err) + switch err.(type) { + case errutil.UserError: + return errutil.UserError{Err: prefixedErr.Error()} + default: + return prefixedErr + } + } + if !valid { + return fmt.Errorf("leaf certificate public key does match the key version selected") + } + + keyEntry, err := p.safeGetKeyEntry(keyVersion) + if err != nil { + return err + } + + // Convert the certificate chain to DER format + derCertificates := make([][]byte, len(certChain)) + for i, cert := range certChain { + derCertificates[i] = cert.Raw + } + + keyEntry.CertificateChain = derCertificates + + p.Keys[strconv.Itoa(keyVersion)] = keyEntry + return p.Persist(ctx, storage) +}