Transit backend: Create CSR's from keys in transit and import certificate chains (#21081)

* setup initial boilerplate code for sign csr endpoint

* add function to sign csr

* working version of sign csr endpoint

* improving errors for csr create and sign endpoint

* initial implementation for import leaf certificate endpoint

* check if more than one certificate was provided in the ceritificate chain

* improve validate cert public key matches transit key

* convert provided cert chain from PEM to DER so it can be parsed by
x509.ParseCertificates and fixing other bugs

* fix creation of csr from csrTemplate

* add missing persist of certificate chain after validations in set-certificate endpoint

* allow exporting a certificate-chain

* move function declaration to end of page

* improving variable and function names, removing comments

* fix certificate chain parsing - work in progress

* test for signCsr endpoint

* use Operations instead of Callbacks in framework.Path

* setup test for set-certificate endpoint

fix problems with sign-csr endpoint returning base64

* finish set-certificate endpoint test

* use public key KeyEntry fields instead of retrieving public key from private

* improve error message and make better distinction between client and server error

also moved check of key types before checking if key match to endpoint handler

* check if private key has been imported for key version selected when signing a csr

* improve errors

* add endpoint description and synopsis

* fix functions calls in backend as function names changed

* improve import cert chain test

* trim whitespaces on export certificate chain

* changelog

* pass context from handler function to policy Persist

* make fmt run

* fix: assign returned error from PersistCertificateChain to err so it can be evaluated

* additional validations and improvements to parseCertificateChain function

* add validation to check if there is only one certificate in the certificate chain and it is in the first position

* import cert chain test: move creation of cluster to exported test function

* move check of end-cert pub key algorithm and key transit algorithm match into a separate function

* test export certificate chain

* Update sdk/helper/keysutil/policy.go

Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com>

* fix validateLeafCertPosition

* reject certificate actions on policies that allow key derivation and remove derived checks

* return UserError from CreateCSR SDK function as 400 in transit API handler

* add derived check for ED5519 keys on CreateCSR SDK func

* remove unecessary calls of x509.CreateCertificateRequest

* move validate key type match back into SDK ValidateLeafCertMatch function

* add additional validations (ValidateLeafCertKeyMatch, etc) in SDK PersistCertificateChain function

* remove uncessary call of ValidateLeafCertKeyMatch in parseImportCertChainWrite

* store certificate chain as a [][]byte instead of []*x509.Certificate

* include persisted ca chain in import cert-chain response

* remove NOTE comment

* allow exporting cert-chain even if exportable is set as false

* remove NOTE comment

* add certifcate chain to formatKeyPublic if present

also added an additional check to validate if field is added when
certchain is present

---------

Co-authored-by: Alexander Scheel <alexander.m.scheel@gmail.com>
This commit is contained in:
Gabriel Santos 2023-08-22 13:24:56 +01:00 committed by GitHub
parent d50bd4eb05
commit 1996377b4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 948 additions and 8 deletions

View file

@ -73,6 +73,8 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
b.pathTrim(),
b.pathCacheConfig(),
b.pathConfigKeys(),
b.pathCreateCsr(),
b.pathImportCertChain(),
},
Secrets: []*framework.Secret{},

View file

@ -0,0 +1,300 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transit
import (
"context"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"strings"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
)
func (b *backend) pathCreateCsr() *framework.Path {
return &framework.Path{
Pattern: "keys/" + framework.GenericNameRegex("name") + "/csr",
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Required: true,
Description: "Name of the key",
},
"version": {
Type: framework.TypeInt,
Required: false,
Description: "Optional version of key, 'latest' if not set",
},
"csr": {
Type: framework.TypeString,
Required: false,
Description: `PEM encoded CSR template. The information attributes
will be used as a basis for the CSR with the key in transit. If not set, an empty CSR is returned.`,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
// NOTE: Create and Update?
logical.CreateOperation: &framework.PathOperation{
Callback: b.pathCreateCsrWrite,
DisplayAttrs: &framework.DisplayAttributes{
OperationVerb: "create",
},
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathCreateCsrWrite,
DisplayAttrs: &framework.DisplayAttributes{
OperationVerb: "update",
},
},
},
HelpSynopsis: pathCreateCsrHelpSyn,
HelpDescription: pathCreateCsrHelpDesc,
}
}
func (b *backend) pathImportCertChain() *framework.Path {
return &framework.Path{
// NOTE: `set-certificate` or `set_certificate`? Paths seem to use different
// case, such as `transit/wrapping_key` and `transit/cache-config`.
Pattern: "keys/" + framework.GenericNameRegex("name") + "/set-certificate",
Fields: map[string]*framework.FieldSchema{
"name": {
Type: framework.TypeString,
Required: true,
Description: "Name of the key",
},
"version": {
Type: framework.TypeInt,
Required: false,
Description: "Optional version of key, 'latest' if not set",
},
"certificate_chain": {
Type: framework.TypeString,
Required: true,
Description: `PEM encoded certificate chain. It should be composed
by one or more concatenated PEM blocks and ordered starting from the end-entity certificate.`,
},
},
Operations: map[logical.Operation]framework.OperationHandler{
// NOTE: Create and Update?
logical.CreateOperation: &framework.PathOperation{
Callback: b.pathImportCertChainWrite,
DisplayAttrs: &framework.DisplayAttributes{
OperationVerb: "create",
},
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathImportCertChainWrite,
DisplayAttrs: &framework.DisplayAttributes{
OperationVerb: "update",
},
},
},
HelpSynopsis: pathImportCertChainHelpSyn,
HelpDescription: pathImportCertChainHelpDesc,
}
}
func (b *backend) pathCreateCsrWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
if err != nil {
return nil, err
}
if p == nil {
return logical.ErrorResponse(fmt.Sprintf("key with provided name '%s' not found", name)), logical.ErrInvalidRequest
}
if !b.System().CachingDisabled() {
p.Lock(false) // NOTE: No lock on "read" operations?
}
defer p.Unlock()
// Check if transit key supports signing
if !p.Type.SigningSupported() {
return logical.ErrorResponse(fmt.Sprintf("key type '%s' does not support signing", p.Type)), logical.ErrInvalidRequest
}
// Check if key can be derived
if p.Derived {
return logical.ErrorResponse("operation not supported on keys with derivation enabled"), logical.ErrInvalidRequest
}
// Transit key version
signingKeyVersion := p.LatestVersion
// NOTE: BYOK endpoints seem to remove "v" prefix from version,
// are versions like that also supported?
if version, ok := d.GetOk("version"); ok {
signingKeyVersion = version.(int)
}
// Read and parse CSR template
pemCsrTemplate := d.Get("csr").(string)
csrTemplate, err := parseCsr(pemCsrTemplate)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
pemCsr, err := p.CreateCsr(signingKeyVersion, csrTemplate)
if err != nil {
prefixedErr := fmt.Errorf("could not create the csr: %w", err)
switch err.(type) {
case errutil.UserError:
return logical.ErrorResponse(prefixedErr.Error()), logical.ErrInvalidRequest
default:
return nil, prefixedErr
}
}
resp := &logical.Response{
Data: map[string]interface{}{
"name": p.Name,
"type": p.Type.String(),
"csr": string(pemCsr),
},
}
return resp, nil
}
func (b *backend) pathImportCertChainWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
name := d.Get("name").(string)
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: req.Storage,
Name: name,
}, b.GetRandomReader())
if err != nil {
return nil, err
}
if p == nil {
return logical.ErrorResponse(fmt.Sprintf("key with provided name '%s' not found", name)), logical.ErrInvalidRequest
}
if !b.System().CachingDisabled() {
p.Lock(true) // NOTE: Lock as we are might write to the policy
}
defer p.Unlock()
// Check if transit key supports signing
if !p.Type.SigningSupported() {
return logical.ErrorResponse(fmt.Sprintf("key type %s does not support signing", p.Type)), logical.ErrInvalidRequest
}
// Check if key can be derived
if p.Derived {
return logical.ErrorResponse("operation not supported on keys with derivation enabled"), logical.ErrInvalidRequest
}
// Transit key version
keyVersion := p.LatestVersion
if version, ok := d.GetOk("version"); ok {
keyVersion = version.(int)
}
// Get certificate chain
pemCertChain := d.Get("certificate_chain").(string)
certChain, err := parseCertificateChain(pemCertChain)
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
err = p.ValidateAndPersistCertificateChain(ctx, keyVersion, certChain, req.Storage)
if err != nil {
prefixedErr := fmt.Errorf("failed to persist certificate chain: %w", err)
switch err.(type) {
case errutil.UserError:
return logical.ErrorResponse(prefixedErr.Error()), logical.ErrInvalidRequest
default:
return nil, prefixedErr
}
}
resp := &logical.Response{
Data: map[string]interface{}{
"name": p.Name,
"type": p.Type.String(),
"certificate-chain": pemCertChain,
},
}
return resp, nil
}
func parseCsr(csrStr string) (*x509.CertificateRequest, error) {
if csrStr == "" {
return &x509.CertificateRequest{}, nil
}
block, _ := pem.Decode([]byte(csrStr))
if block == nil {
return nil, errors.New("could not decode PEM certificate request")
}
csr, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil {
return nil, err
}
return csr, nil
}
func parseCertificateChain(certChainString string) ([]*x509.Certificate, error) {
var certificates []*x509.Certificate
var pemCertBlocks []*pem.Block
pemBytes := []byte(strings.TrimSpace(certChainString))
for len(pemBytes) > 0 {
var pemCertBlock *pem.Block
pemCertBlock, pemBytes = pem.Decode(pemBytes)
if pemCertBlock == nil {
return nil, errors.New("could not decode PEM block in certificate chain")
}
switch pemCertBlock.Type {
case "CERTIFICATE", "X05 CERTIFICATE":
pemCertBlocks = append(pemCertBlocks, pemCertBlock)
default:
// Ignore any other entries
}
}
if len(pemCertBlocks) == 0 {
return nil, errors.New("provided certificate chain did not contain any valid PEM certificate")
}
for _, certBlock := range pemCertBlocks {
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate in certificate chain: %w", err)
}
certificates = append(certificates, cert)
}
return certificates, nil
}
const pathCreateCsrHelpSyn = `Create a CSR from a key in transit`
const pathCreateCsrHelpDesc = `This path is used to create a CSR from a key in
transit. If a CSR template is provided, its significant information, expect key
related data, are included in the CSR otherwise an empty CSR is returned.
`
const pathImportCertChainHelpSyn = `Imports an externally-signed certificate
chain into an existing key version`
const pathImportCertChainHelpDesc = `This path is used to import an externally-
signed certificate chain into a key in transit. The leaf certificate key has to
match the selected key in transit.
`

View file

@ -0,0 +1,260 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package transit
import (
"context"
cryptoRand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"reflect"
"strings"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
func TestTransit_Certs_CreateCsr(t *testing.T) {
// NOTE: Use an existing CSR or generate one here?
templateCsr := `
-----BEGIN CERTIFICATE REQUEST-----
MIICRTCCAS0CAQAwADCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAM49
McW7u3ILuAJfSFLUtGOMGBytHmMFcjTiX+5JcajFj0Uszb+HQ7eIsJJNXhVc/7fg
Z01DZvcCqb9ChEWE3xi4GEkPMXay7p7G1ooSLnQp6Z0lL5CuIFfMVOTvjfhTwRaJ
l9v2mMlm80BeiAUBqeoyGVrIh5fKASxaE0jrhjAxhGzqrXdDnL8A4na6ArprV4iS
aEAziODd2WmplSKgUwEaFdeG1t1bJf3o5ZQRCnKNtQcAk8UmgtvFEO8ohGMln/Fj
O7u7s6iRhOGf1g1NCAP5pGqxNx3bjz5f/CUcTSIGAReEomg41QTIhD9muCTL8qnm
6lS87wkGTv7qbeIGB7sCAwEAAaAAMA0GCSqGSIb3DQEBCwUAA4IBAQAfjE+jNqIk
4V1tL3g5XPjxr2+QcwddPf8opmbAzgt0+TiIHcDGBAxsXyi7sC9E5AFfFp7W07Zv
r5+v4i529K9q0BgGtHFswoEnhd4dC8Ye53HtSoEtXkBpZMDrtbS7eZa9WccT6zNx
4taTkpptZVrmvPj+jLLFkpKJJ3d+Gbrp6hiORPadT+igLKkqvTeocnhOdAtt427M
RXTVgN14pV3tqO+5MXzNw5tGNPcwWARWwPH9eCRxLwLUuxE4Qu73pUeEFjDEfGkN
iBnlTsTXBOMqSGryEkmRaZslWDvblvYeObYw+uc3kCbJ7jRy9soVwkbb5FueF/yC
O1aQIm23HrrG
-----END CERTIFICATE REQUEST-----
`
testTransit_CreateCsr(t, "rsa-2048", templateCsr)
testTransit_CreateCsr(t, "rsa-3072", templateCsr)
testTransit_CreateCsr(t, "rsa-4096", templateCsr)
testTransit_CreateCsr(t, "ecdsa-p256", templateCsr)
testTransit_CreateCsr(t, "ecdsa-p384", templateCsr)
testTransit_CreateCsr(t, "ecdsa-p521", templateCsr)
testTransit_CreateCsr(t, "ed25519", templateCsr)
testTransit_CreateCsr(t, "aes256-gcm96", templateCsr)
}
func testTransit_CreateCsr(t *testing.T, keyType, pemTemplateCsr string) {
var resp *logical.Response
var err error
b, s := createBackendWithStorage(t)
// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/test-key",
Storage: s,
Data: map[string]interface{}{
"type": keyType,
},
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("resp: %#v\nerr: %v", resp, err)
}
csrSignReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/test-key/csr",
Storage: s,
Data: map[string]interface{}{
"csr": pemTemplateCsr,
},
}
resp, err = b.HandleRequest(context.Background(), csrSignReq)
switch keyType {
case "rsa-2048", "rsa-3072", "rsa-4096", "ecdsa-p256", "ecdsa-p384", "ecdsa-p521", "ed25519":
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("failed to sign CSR, err:%v resp:%#v", err, resp)
}
signedCsrBytes, ok := resp.Data["csr"]
if !ok {
t.Fatal("expected response data to hold a 'csr' field")
}
signedCsr, err := parseCsr(signedCsrBytes.(string))
if err != nil {
t.Fatalf("failed to parse returned csr, err:%v", err)
}
templateCsr, err := parseCsr(pemTemplateCsr)
if err != nil {
t.Fatalf("failed to parse returned template csr, err:%v", err)
}
// NOTE: Check other fields?
if !reflect.DeepEqual(signedCsr.Subject, templateCsr.Subject) {
t.Fatalf("subjects should have matched, err:%v", err)
}
default:
if err == nil || (resp != nil && !resp.IsError()) {
t.Fatalf("should have failed to sign CSR, provided key type does not support signing")
}
}
}
// NOTE: Tests are using two 'different' methods of checking for errors, which one sould we prefer?
func TestTransit_Certs_ImportCertChain(t *testing.T) {
// Create Cluster
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"transit": Factory,
"pki": pki.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
// Mount transit backend
err := client.Sys().Mount("transit", &api.MountInput{
Type: "transit",
})
require.NoError(t, err)
// Mount PKI backend
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
})
require.NoError(t, err)
testTransit_ImportCertChain(t, client, "rsa-2048")
testTransit_ImportCertChain(t, client, "rsa-3072")
testTransit_ImportCertChain(t, client, "rsa-4096")
testTransit_ImportCertChain(t, client, "ecdsa-p256")
testTransit_ImportCertChain(t, client, "ecdsa-p384")
testTransit_ImportCertChain(t, client, "ecdsa-p521")
testTransit_ImportCertChain(t, client, "ed25519")
}
func testTransit_ImportCertChain(t *testing.T, apiClient *api.Client, keyType string) {
keyName := fmt.Sprintf("%s", keyType)
issuerName := fmt.Sprintf("%s-issuer", keyType)
// Create transit key
_, err := apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s", keyName), map[string]interface{}{
"type": keyType,
})
require.NoError(t, err)
// Setup a new CSR
privKey, err := rsa.GenerateKey(cryptoRand.Reader, 3072)
require.NoError(t, err)
var csrTemplate x509.CertificateRequest
csrTemplate.Subject.CommonName = "example.com"
reqCsrBytes, err := x509.CreateCertificateRequest(cryptoRand.Reader, &csrTemplate, privKey)
require.NoError(t, err)
pemTemplateCsr := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: reqCsrBytes,
})
t.Logf("csr: %v", string(pemTemplateCsr))
// Create CSR from template CSR fields and key in transit
resp, err := apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/csr", keyName), map[string]interface{}{
"csr": string(pemTemplateCsr),
})
require.NoError(t, err)
require.NotNil(t, resp)
pemCsr := resp.Data["csr"].(string)
// Generate PKI root
resp, err = apiClient.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"issuer_name": issuerName,
"common_name": "PKI Root X1",
})
require.NoError(t, err)
require.NotNil(t, resp)
rootCertPEM := resp.Data["certificate"].(string)
pemBlock, _ := pem.Decode([]byte(rootCertPEM))
require.NotNil(t, pemBlock)
rootCert, err := x509.ParseCertificate(pemBlock.Bytes)
require.NoError(t, err)
// Create role to be used in the certificate issuing
resp, err = apiClient.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{
"issuer_ref": issuerName,
"allowed_domains": "example.com",
"allow_bare_domains": true,
"basic_constraints_valid_for_non_ca": true,
"key_type": "any",
})
require.NoError(t, err)
// Sign the CSR
resp, err = apiClient.Logical().Write("pki/sign/example-dot-com", map[string]interface{}{
"issuer_ref": issuerName,
"csr": pemCsr,
"ttl": "10m",
})
require.NoError(t, err)
require.NotNil(t, resp)
leafCertPEM := resp.Data["certificate"].(string)
pemBlock, _ = pem.Decode([]byte(leafCertPEM))
require.NotNil(t, pemBlock)
leafCert, err := x509.ParseCertificate(pemBlock.Bytes)
require.NoError(t, err)
require.NoError(t, leafCert.CheckSignatureFrom(rootCert))
t.Logf("root: %v", rootCertPEM)
t.Logf("leaf: %v", leafCertPEM)
certificateChain := strings.Join([]string{leafCertPEM, rootCertPEM}, "\n")
// Import certificate chain to transit key version
resp, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/set-certificate", keyName), map[string]interface{}{
"certificate_chain": certificateChain,
})
require.NoError(t, err)
require.NotNil(t, resp)
resp, err = apiClient.Logical().Read(fmt.Sprintf("transit/keys/%s", keyName))
require.NotNil(t, resp)
keys, ok := resp.Data["keys"].(map[string]interface{})
if !ok {
t.Fatalf("could not cast Keys value")
}
keyData, ok := keys["1"].(map[string]interface{})
if !ok {
t.Fatalf("could not cast key version 1 from keys")
}
_, present := keyData["certificate_chain"]
if !present {
t.Fatalf("certificate chain not present in key version 1")
}
}

View file

@ -21,10 +21,11 @@ import (
)
const (
exportTypeEncryptionKey = "encryption-key"
exportTypeSigningKey = "signing-key"
exportTypeHMACKey = "hmac-key"
exportTypePublicKey = "public-key"
exportTypeEncryptionKey = "encryption-key"
exportTypeSigningKey = "signing-key"
exportTypeHMACKey = "hmac-key"
exportTypePublicKey = "public-key"
exportTypeCertificateChain = "certificate-chain"
)
func (b *backend) pathExportKeys() *framework.Path {
@ -71,6 +72,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
case exportTypeSigningKey:
case exportTypeHMACKey:
case exportTypePublicKey:
case exportTypeCertificateChain:
default:
return logical.ErrorResponse(fmt.Sprintf("invalid export type: %s", exportType)), logical.ErrInvalidRequest
}
@ -90,7 +92,7 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
}
defer p.Unlock()
if !p.Exportable && exportType != exportTypePublicKey {
if !p.Exportable && exportType != exportTypePublicKey && exportType != exportTypeCertificateChain {
return logical.ErrorResponse("private key material is not exportable"), nil
}
@ -103,6 +105,10 @@ func (b *backend) pathPolicyExportRead(ctx context.Context, req *logical.Request
if !p.Type.SigningSupported() {
return logical.ErrorResponse("signing not supported for the key"), logical.ErrInvalidRequest
}
case exportTypeCertificateChain:
if !p.Type.SigningSupported() {
return logical.ErrorResponse("certificate chain not supported for keys that do not support signing"), logical.ErrInvalidRequest
}
}
retKeys := map[string]string{}
@ -241,6 +247,23 @@ func getExportKey(policy *keysutil.Policy, key *keysutil.KeyEntry, exportType st
}
return rsaKey, nil
}
case exportTypeCertificateChain:
if key.CertificateChain == nil {
return "", errors.New("selected key version does not have a certificate chain imported")
}
var pemCerts []string
for _, derCertBytes := range key.CertificateChain {
pemCert := strings.TrimSpace(string(pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: derCertBytes,
})))
pemCerts = append(pemCerts, pemCert)
}
certChain := strings.Join(pemCerts, "\n")
return certChain, nil
}
return "", fmt.Errorf("unknown key type %v for export type %v", policy.Type, exportType)

View file

@ -5,12 +5,23 @@ package transit
import (
"context"
cryptoRand "crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/pki"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/sdk/logical"
"github.com/stretchr/testify/require"
)
func TestTransit_Export_KeyVersion_ExportsCorrectVersion(t *testing.T) {
@ -381,3 +392,134 @@ func TestTransit_Export_EncryptionKey_DoesNotExportHMACKey(t *testing.T) {
t.Fatal("Encryption key data matched hmac key data")
}
}
func TestTransit_Export_CertificateChain(t *testing.T) {
generateKeys(t)
// Create Cluster
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"transit": Factory,
"pki": pki.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
vault.TestWaitActive(t, cores[0].Core)
client := cores[0].Client
// Mount transit backend
err := client.Sys().Mount("transit", &api.MountInput{
Type: "transit",
})
require.NoError(t, err)
// Mount PKI backend
err = client.Sys().Mount("pki", &api.MountInput{
Type: "pki",
})
require.NoError(t, err)
testTransit_exportCertificateChain(t, client, "rsa-2048")
testTransit_exportCertificateChain(t, client, "rsa-3072")
testTransit_exportCertificateChain(t, client, "rsa-4096")
testTransit_exportCertificateChain(t, client, "ecdsa-p256")
testTransit_exportCertificateChain(t, client, "ecdsa-p384")
testTransit_exportCertificateChain(t, client, "ecdsa-p521")
testTransit_exportCertificateChain(t, client, "ed25519")
}
func testTransit_exportCertificateChain(t *testing.T, apiClient *api.Client, keyType string) {
keyName := fmt.Sprintf("%s", keyType)
issuerName := fmt.Sprintf("%s-issuer", keyType)
// Get key to be imported
privKey := getKey(t, keyType)
privKeyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
require.NoError(t, err, fmt.Sprintf("failed to marshal private key: %s", err))
// Create CSR
var csrTemplate x509.CertificateRequest
csrTemplate.Subject.CommonName = "example.com"
csrBytes, err := x509.CreateCertificateRequest(cryptoRand.Reader, &csrTemplate, privKey)
require.NoError(t, err, fmt.Sprintf("failed to create CSR: %s", err))
pemCsr := string(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrBytes,
}))
// Generate PKI root
_, err = apiClient.Logical().Write("pki/root/generate/internal", map[string]interface{}{
"issuer_name": issuerName,
"common_name": "PKI Root X1",
})
require.NoError(t, err)
// Create role to be used in the certificate issuing
_, err = apiClient.Logical().Write("pki/roles/example-dot-com", map[string]interface{}{
"issuer_ref": issuerName,
"allowed_domains": "example.com",
"allow_bare_domains": true,
"basic_constraints_valid_for_non_ca": true,
"key_type": "any",
})
require.NoError(t, err)
// Sign the CSR
resp, err := apiClient.Logical().Write("pki/sign/example-dot-com", map[string]interface{}{
"issuer_ref": issuerName,
"csr": pemCsr,
"ttl": "10m",
})
require.NoError(t, err)
require.NotNil(t, resp)
leafCertPEM := resp.Data["certificate"].(string)
// Get wrapping key
resp, err = apiClient.Logical().Read("transit/wrapping_key")
require.NoError(t, err)
require.NotNil(t, resp)
pubWrappingKeyString := strings.TrimSpace(resp.Data["public_key"].(string))
wrappingKeyPemBlock, _ := pem.Decode([]byte(pubWrappingKeyString))
pubWrappingKey, err := x509.ParsePKIXPublicKey(wrappingKeyPemBlock.Bytes)
require.NoError(t, err, "failed to parse wrapping key")
blob := wrapTargetPKCS8ForImport(t, pubWrappingKey.(*rsa.PublicKey), privKeyBytes, "SHA256")
// Import key
_, err = apiClient.Logical().Write(fmt.Sprintf("/transit/keys/%s/import", keyName), map[string]interface{}{
"ciphertext": blob,
"type": keyType,
})
require.NoError(t, err)
// NOTE: Should it be possible for the "import" endpoint to also import the cert chain?
// Import cert chain
_, err = apiClient.Logical().Write(fmt.Sprintf("transit/keys/%s/set-certificate", keyName), map[string]interface{}{
"certificate_chain": leafCertPEM,
})
require.NoError(t, err)
// Export cert chain
resp, err = apiClient.Logical().Read(fmt.Sprintf("transit/export/certificate-chain/%s", keyName))
require.NoError(t, err)
require.NotNil(t, resp)
exportedKeys := resp.Data["keys"].(map[string]interface{})
exportedCertChainPEM := exportedKeys["1"].(string)
if exportedCertChainPEM != leafCertPEM {
t.Fatalf("expected exported cert chain to match with imported value")
}
}

View file

@ -12,6 +12,7 @@ import (
"encoding/pem"
"fmt"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ed25519"
@ -269,9 +270,10 @@ func (b *backend) pathPolicyWrite(ctx context.Context, req *logical.Request, d *
// Built-in helper type for returning asymmetric keys
type asymKey struct {
Name string `json:"name" structs:"name" mapstructure:"name"`
PublicKey string `json:"public_key" structs:"public_key" mapstructure:"public_key"`
CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"creation_time"`
Name string `json:"name" structs:"name" mapstructure:"name"`
PublicKey string `json:"public_key" structs:"public_key" mapstructure:"public_key"`
CertificateChain string `json:"certificate_chain" structs:"certificate_chain" mapstructure:"certificate_chain"`
CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"creation_time"`
}
func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
@ -379,6 +381,18 @@ func (b *backend) formatKeyPolicy(p *keysutil.Policy, context []byte) (*logical.
if key.CreationTime.IsZero() {
key.CreationTime = time.Unix(v.DeprecatedCreationTime, 0)
}
if v.CertificateChain != nil {
var pemCerts []string
for _, derCertBytes := range v.CertificateChain {
pemCert := strings.TrimSpace(string(pem.EncodeToMemory(
&pem.Block{
Type: "CERTIFICATE",
Bytes: derCertBytes,
})))
pemCerts = append(pemCerts, pemCert)
}
key.CertificateChain = strings.Join(pemCerts, "\n")
}
switch p.Type {
case keysutil.KeyType_ECDSA_P256:

3
changelog/21081.txt Normal file
View file

@ -0,0 +1,3 @@
```release-note:improvement
secrets/transit: Add support to create CSRs from keys in transit engine and import/export x509 certificates
```

View file

@ -244,6 +244,10 @@ type KeyEntry struct {
DeprecatedCreationTime int64 `json:"creation_time"`
ManagedKeyUUID string `json:"managed_key_id,omitempty"`
// Key entry certificate chain. If set, leaf certificate key matches the
// KeyEntry key
CertificateChain [][]byte `json:"certificate_chain"`
}
func (ke *KeyEntry) IsPrivateKeyMissing() bool {
@ -2393,3 +2397,195 @@ func wrapTargetPKCS8ForImport(wrappingKey *rsa.PublicKey, preppedTargetKey []byt
wrappedKeys := append(ephKeyWrapped, targetKeyWrapped...)
return base64.StdEncoding.EncodeToString(wrappedKeys), nil
}
func (p *Policy) CreateCsr(keyVersion int, csrTemplate *x509.CertificateRequest) ([]byte, error) {
if !p.Type.SigningSupported() {
return nil, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
}
keyEntry, err := p.safeGetKeyEntry(keyVersion)
if err != nil {
return nil, err
}
if keyEntry.IsPrivateKeyMissing() {
return nil, errutil.UserError{Err: "private key not imported for key version selected"}
}
csrTemplate.Signature = nil
csrTemplate.SignatureAlgorithm = x509.UnknownSignatureAlgorithm
var key crypto.Signer
switch p.Type {
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
var curve elliptic.Curve
switch p.Type {
case KeyType_ECDSA_P384:
curve = elliptic.P384()
case KeyType_ECDSA_P521:
curve = elliptic.P521()
default:
curve = elliptic.P256()
}
key = &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
Curve: curve,
X: keyEntry.EC_X,
Y: keyEntry.EC_Y,
},
D: keyEntry.EC_D,
}
case KeyType_ED25519:
if p.Derived {
return nil, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
}
key = ed25519.PrivateKey(keyEntry.Key)
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
key = keyEntry.RSAKey
default:
return nil, errutil.InternalError{Err: fmt.Sprintf("selected key type '%s' does not support signing", p.Type.String())}
}
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, csrTemplate, key)
if err != nil {
return nil, fmt.Errorf("could not create the cerfificate request: %w", err)
}
pemCsr := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrBytes,
})
return pemCsr, nil
}
func (p *Policy) ValidateLeafCertKeyMatch(keyVersion int, certPublicKeyAlgorithm x509.PublicKeyAlgorithm, certPublicKey any) (bool, error) {
if !p.Type.SigningSupported() {
return false, errutil.UserError{Err: fmt.Sprintf("key type '%s' does not support signing", p.Type)}
}
var keyTypeMatches bool
switch p.Type {
case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
if certPublicKeyAlgorithm == x509.ECDSA {
keyTypeMatches = true
}
case KeyType_ED25519:
if certPublicKeyAlgorithm == x509.Ed25519 {
keyTypeMatches = true
}
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
if certPublicKeyAlgorithm == x509.RSA {
keyTypeMatches = true
}
}
if !keyTypeMatches {
return false, errutil.UserError{Err: fmt.Sprintf("provided leaf certificate public key algorithm '%s' does not match the transit key type '%s'",
certPublicKeyAlgorithm, p.Type)}
}
keyEntry, err := p.safeGetKeyEntry(keyVersion)
if err != nil {
return false, err
}
switch certPublicKeyAlgorithm {
case x509.ECDSA:
certPublicKey := certPublicKey.(*ecdsa.PublicKey)
var curve elliptic.Curve
switch p.Type {
case KeyType_ECDSA_P384:
curve = elliptic.P384()
case KeyType_ECDSA_P521:
curve = elliptic.P521()
default:
curve = elliptic.P256()
}
publicKey := &ecdsa.PublicKey{
Curve: curve,
X: keyEntry.EC_X,
Y: keyEntry.EC_Y,
}
// NOTE: Is it worth having this check?
if publicKey.Curve != certPublicKey.Curve {
return false, nil
}
return publicKey.Equal(certPublicKey), nil
case x509.Ed25519:
if p.Derived {
return false, errutil.UserError{Err: "operation not supported on keys with derivation enabled"}
}
certPublicKey := certPublicKey.(ed25519.PublicKey)
raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
if err != nil {
return false, err
}
publicKey := ed25519.PublicKey(raw)
return publicKey.Equal(certPublicKey), nil
case x509.RSA:
certPublicKey := certPublicKey.(*rsa.PublicKey)
publicKey := keyEntry.RSAKey.PublicKey
return publicKey.Equal(certPublicKey), nil
case x509.UnknownPublicKeyAlgorithm:
return false, errutil.InternalError{Err: fmt.Sprint("certificate signed with an unknown algorithm")}
}
return false, nil
}
func (p *Policy) ValidateAndPersistCertificateChain(ctx context.Context, keyVersion int, certChain []*x509.Certificate, storage logical.Storage) error {
if len(certChain) == 0 {
return errutil.UserError{Err: "expected at least one certificate in the parsed certificate chain"}
}
if certChain[0].BasicConstraintsValid && certChain[0].IsCA {
return errutil.UserError{Err: "certificate in the first position is not a leaf certificate"}
}
for _, cert := range certChain[1:] {
if cert.BasicConstraintsValid && !cert.IsCA {
return errutil.UserError{Err: "provided certificate chain contains more than one leaf certificate"}
}
}
valid, err := p.ValidateLeafCertKeyMatch(keyVersion, certChain[0].PublicKeyAlgorithm, certChain[0].PublicKey)
if err != nil {
prefixedErr := fmt.Errorf("could not validate key match between leaf certificate key and key version in transit: %w", err)
switch err.(type) {
case errutil.UserError:
return errutil.UserError{Err: prefixedErr.Error()}
default:
return prefixedErr
}
}
if !valid {
return fmt.Errorf("leaf certificate public key does match the key version selected")
}
keyEntry, err := p.safeGetKeyEntry(keyVersion)
if err != nil {
return err
}
// Convert the certificate chain to DER format
derCertificates := make([][]byte, len(certChain))
for i, cert := range certChain {
derCertificates[i] = cert.Raw
}
keyEntry.CertificateChain = derCertificates
p.Keys[strconv.Itoa(keyVersion)] = keyEntry
return p.Persist(ctx, storage)
}