diff --git a/builtin/logical/pki/ca_util.go b/builtin/logical/pki/ca_util.go index bf21e4cf9e..56079f9e89 100644 --- a/builtin/logical/pki/ca_util.go +++ b/builtin/logical/pki/ca_util.go @@ -13,6 +13,8 @@ import ( "io" "time" + "github.com/cloudflare/circl/sign/mldsa/mldsa65" + "github.com/cloudflare/circl/sign/mldsa/mldsa87" "github.com/hashicorp/vault/builtin/logical/pki/issuing" "github.com/hashicorp/vault/builtin/logical/pki/managed_key" "github.com/hashicorp/vault/sdk/framework" @@ -272,6 +274,10 @@ func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.Pr keyType = certutil.ECPrivateKey case ed25519.PublicKey: keyType = certutil.Ed25519PrivateKey + case *mldsa65.PublicKey: + keyType = certutil.MLDSA65PrivateKey + case *mldsa87.PublicKey: + keyType = certutil.MLDSA87PrivateKey default: return certutil.UnknownPrivateKey, 0, fmt.Errorf("unsupported public key: %#v", pubKey) } diff --git a/builtin/logical/pki/path_roles.go b/builtin/logical/pki/path_roles.go index 4a2638e47a..9af4bd5594 100644 --- a/builtin/logical/pki/path_roles.go +++ b/builtin/logical/pki/path_roles.go @@ -207,7 +207,7 @@ protection use. Defaults to false. See also RFC 5280 Section 4.2.1.12.`, Type: framework.TypeString, Required: true, Description: `The type of key to use; defaults to RSA. "rsa" -"ec", "ed25519" and "any" are the only valid values.`, +"ec", "ed25519", "ml-dsa-65", "ml-dsa-87" and "any" are the only valid values.`, }, "key_bits": { @@ -594,8 +594,8 @@ protection use. Defaults to false. See also RFC 5280 Section 4.2.1.12.`, Type: framework.TypeString, Default: "rsa", Description: `The type of key to use; defaults to RSA. "rsa" -"ec", "ed25519" and "any" are the only valid values.`, - AllowedValues: []interface{}{"rsa", "ec", "ed25519", "any"}, +"ec", "ed25519", "ml-dsa-65", "ml-dsa-87" and "any" are the only valid values.`, + AllowedValues: []interface{}{"rsa", "ec", "ed25519", "ml-dsa-65", "ml-dsa-87", "any"}, }, "key_bits": { diff --git a/go.mod b/go.mod index 92f02204db..89950ab495 100644 --- a/go.mod +++ b/go.mod @@ -383,7 +383,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible // indirect github.com/circonus-labs/circonusllhist v0.1.3 // indirect - github.com/cloudflare/circl v1.6.3 // indirect + github.com/cloudflare/circl v1.6.3 github.com/cloudfoundry-community/go-cfclient v0.0.0-20220930021109-9c4e6c59ccf1 // indirect github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect github.com/containerd/continuity v0.4.5 // indirect diff --git a/sdk/go.mod b/sdk/go.mod index 7841644525..2da604c314 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -6,6 +6,7 @@ require ( cloud.google.com/go/cloudsqlconn v1.21.0 github.com/armon/go-radix v1.0.0 github.com/cenkalti/backoff/v4 v4.3.0 + github.com/cloudflare/circl v1.6.3 github.com/containerd/errdefs v1.0.0 github.com/evanphx/json-patch/v5 v5.6.0 github.com/fatih/structs v1.1.0 diff --git a/sdk/go.sum b/sdk/go.sum index 4e7ae0a8c5..c694662d9f 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -47,6 +47,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= +github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= diff --git a/sdk/helper/certutil/helpers.go b/sdk/helper/certutil/helpers.go index efa1bdc3db..dd410b5aff 100644 --- a/sdk/helper/certutil/helpers.go +++ b/sdk/helper/certutil/helpers.go @@ -29,6 +29,8 @@ import ( "strings" "time" + "github.com/cloudflare/circl/sign/mldsa/mldsa65" + "github.com/cloudflare/circl/sign/mldsa/mldsa87" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/helper/cryptoutil" "github.com/hashicorp/vault/sdk/helper/errutil" @@ -178,6 +180,10 @@ func GetSubjectKeyID(pub interface{}) ([]byte, error) { publicKeyBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y) case ed25519.PublicKey: publicKeyBytes = pub + case *mldsa65.PublicKey: + publicKeyBytes = pub.Bytes() + case *mldsa87.PublicKey: + publicKeyBytes = pub.Bytes() default: return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported public key type: %T", pub)} } @@ -251,6 +257,22 @@ func ParseDERKey(privateKeyBytes []byte) (signer crypto.Signer, format BlockType return } + // Try ML-DSA-65 + if len(privateKeyBytes) == mldsa65.PrivateKeySize { + var sk mldsa65.PrivateKey + if unmarshalErr := sk.UnmarshalBinary(privateKeyBytes); unmarshalErr == nil { + return &sk, MLDSA65Block, nil + } + } + + // Try ML-DSA-87 + if len(privateKeyBytes) == mldsa87.PrivateKeySize { + var sk mldsa87.PrivateKey + if unmarshalErr := sk.UnmarshalBinary(privateKeyBytes); unmarshalErr == nil { + return &sk, MLDSA87Block, nil + } + } + return nil, UnknownBlock, fmt.Errorf("got errors attempting to parse DER private key:\n1. %v\n2. %v\n3. %v", firstError, secondError, thirdError) } @@ -414,6 +436,22 @@ func generatePrivateKey(keyType string, keyBits int, container ParsedPrivateKeyC if err != nil { return errutil.InternalError{Err: fmt.Sprintf("error marshalling Ed25519 private key: %v", err)} } + case "ml-dsa-65": + privateKeyType = MLDSA65PrivateKey + _, sk, genErr := mldsa65.GenerateKey(randReader) + if genErr != nil { + return errutil.InternalError{Err: fmt.Sprintf("error generating ML-DSA-65 private key: %v", genErr)} + } + privateKey = sk + privateKeyBytes = sk.Bytes() + case "ml-dsa-87": + privateKeyType = MLDSA87PrivateKey + _, sk, genErr := mldsa87.GenerateKey(randReader) + if genErr != nil { + return errutil.InternalError{Err: fmt.Sprintf("error generating ML-DSA-87 private key: %v", genErr)} + } + privateKey = sk + privateKeyBytes = sk.Bytes() default: return errutil.UserError{Err: fmt.Sprintf("unknown key type: %s", keyType)} } @@ -501,6 +539,20 @@ func ComparePublicKeys(key1Iface, key2Iface crypto.PublicKey) (bool, error) { return false, nil } return true, nil + case *mldsa65.PublicKey: + key1 := key1Iface.(*mldsa65.PublicKey) + key2, ok := key2Iface.(*mldsa65.PublicKey) + if !ok { + return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface) + } + return key1.Equal(key2), nil + case *mldsa87.PublicKey: + key1 := key1Iface.(*mldsa87.PublicKey) + key2, ok := key2Iface.(*mldsa87.PublicKey) + if !ok { + return false, fmt.Errorf("key types do not match: %T and %T", key1Iface, key2Iface) + } + return key1.Equal(key2), nil default: return false, fmt.Errorf("cannot compare key with type %T", key1Iface) } @@ -723,11 +775,11 @@ func DefaultOrValueHashBits(keyType string, hashBits int) (int, error) { // To match previous behavior (and ignoring NIST's recommendations for // hash size to align with RSA key sizes), default to SHA-2-256. hashBits = 256 - } else if keyType == "ed25519" || keyType == "ed448" { - // No-op; ed25519 and ed448 internally specify their own hash and - // we do not need to select one. Double hashing isn't supported in - // certificate signing. Additionally, the any key type can't know - // what hash algorithm to use yet, so default to zero. + } else if keyType == "ed25519" || keyType == "ed448" || keyType == "ml-dsa-65" || keyType == "ml-dsa-87" { + // No-op; ed25519, ed448, and ML-DSA internally specify their own + // hash and we do not need to select one. Double hashing isn't + // supported in certificate signing. Additionally, the any key type + // can't know what hash algorithm to use yet, so default to zero. return 0, nil } @@ -798,7 +850,7 @@ func ValidateDefaultOrValueHashBits(keyType string, hashBits int) (int, error) { // calculation is a known, approved value. func ValidateSignatureLength(keyType string, hashBits int) error { keyType = strings.ToLower(keyType) - if keyType == "any" || keyType == "ec" || keyType == "ecdsa" || keyType == "ed25519" || keyType == "ed448" { + if keyType == "any" || keyType == "ec" || keyType == "ecdsa" || keyType == "ed25519" || keyType == "ed448" || keyType == "ml-dsa-65" || keyType == "ml-dsa-87" { // ed25519 and ed448 include built-in hashing and is not externally // configurable. There are three modes for each of these schemes: // @@ -851,7 +903,7 @@ func ValidateKeyTypeLength(keyType string, keyBits int) error { if !present { return fmt.Errorf("unsupported bit length for EC key: %d", keyBits) } - case "any", "ed25519": + case "any", "ed25519", "ml-dsa-65", "ml-dsa-87": default: return fmt.Errorf("unknown key type %s", keyType) } @@ -1015,19 +1067,24 @@ func createCertificate(data *CreationBundle, randReader io.Reader, privateKeyGen if privateKeyType == ManagedPrivateKey { privateKeyType = GetPrivateKeyTypeFromSigner(data.SigningBundle.PrivateKey) } - switch privateKeyType { - case RSAPrivateKey: - certTemplateSetSigAlgo(certTemplate, data) - case Ed25519PrivateKey: - certTemplate.SignatureAlgorithm = x509.PureEd25519 - case ECPrivateKey: - certTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(data.SigningBundle.PrivateKey.Public(), data.Params.SignatureBits) - } caCert := data.SigningBundle.Certificate certTemplate.AuthorityKeyId = caCert.SubjectKeyId - certBytes, err = x509.CreateCertificate(randReader, certTemplate, caCert, result.PrivateKey.Public(), data.SigningBundle.PrivateKey) + if isMLDSAKey(data.SigningBundle.PrivateKey) { + // ML-DSA signing requires custom certificate construction + certBytes, err = createMLDSACertificate(certTemplate, caCert, result.PrivateKey.Public(), data.SigningBundle.PrivateKey) + } else { + switch privateKeyType { + case RSAPrivateKey: + certTemplateSetSigAlgo(certTemplate, data) + case Ed25519PrivateKey: + certTemplate.SignatureAlgorithm = x509.PureEd25519 + case ECPrivateKey: + certTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(data.SigningBundle.PrivateKey.Public(), data.Params.SignatureBits) + } + certBytes, err = x509.CreateCertificate(randReader, certTemplate, caCert, result.PrivateKey.Public(), data.SigningBundle.PrivateKey) + } } else { // Creating a self-signed root if data.Params.MaxPathLength == 0 { @@ -1037,18 +1094,23 @@ func createCertificate(data *CreationBundle, randReader io.Reader, privateKeyGen certTemplate.MaxPathLen = data.Params.MaxPathLength } - switch data.Params.KeyType { - case "rsa": - certTemplateSetSigAlgo(certTemplate, data) - case "ed25519": - certTemplate.SignatureAlgorithm = x509.PureEd25519 - case "ec": - certTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(result.PrivateKey.Public(), data.Params.SignatureBits) - } - certTemplate.AuthorityKeyId = subjKeyID certTemplate.BasicConstraintsValid = true - certBytes, err = x509.CreateCertificate(randReader, certTemplate, certTemplate, result.PrivateKey.Public(), result.PrivateKey) + + if isMLDSAKeyType(data.Params.KeyType) { + // ML-DSA self-signed root requires custom certificate construction + certBytes, err = createMLDSACertificate(certTemplate, certTemplate, result.PrivateKey.Public(), result.PrivateKey) + } else { + switch data.Params.KeyType { + case "rsa": + certTemplateSetSigAlgo(certTemplate, data) + case "ed25519": + certTemplate.SignatureAlgorithm = x509.PureEd25519 + case "ec": + certTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(result.PrivateKey.Public(), data.Params.SignatureBits) + } + certBytes, err = x509.CreateCertificate(randReader, certTemplate, certTemplate, result.PrivateKey.Public(), result.PrivateKey) + } } if err != nil { @@ -1058,7 +1120,16 @@ func createCertificate(data *CreationBundle, randReader io.Reader, privateKeyGen result.CertificateBytes = certBytes result.Certificate, err = x509.ParseCertificate(certBytes) if err != nil { - return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + // Go's x509.ParseCertificate may not support ML-DSA certificates. + // Use a lenient parser that tolerates unknown signature algorithms. + if isMLDSAKeyType(data.Params.KeyType) || (data.SigningBundle != nil && isMLDSAKey(data.SigningBundle.PrivateKey)) { + result.Certificate, err = parseMLDSACertificate(certBytes) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse ML-DSA certificate: %s", err)} + } + } else { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + } } if data.SigningBundle != nil { @@ -1196,27 +1267,49 @@ func createCSR(data *CreationBundle, addBasicConstraints bool, randReader io.Rea csrTemplate.ExtraExtensions = append(csrTemplate.ExtraExtensions, ext) } - switch data.Params.KeyType { - case "rsa": - // use specified RSA algorithm defaulting to the appropriate SHA256 RSA signature type - csrTemplate.SignatureAlgorithm = selectSignatureAlgorithmForRSA(data) - case "ec": - csrTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(result.PrivateKey.Public(), data.Params.SignatureBits) - case "ed25519": - csrTemplate.SignatureAlgorithm = x509.PureEd25519 - } + var csr []byte + if isMLDSAKeyType(data.Params.KeyType) { + // ML-DSA CSR requires custom construction + csr, err = createMLDSACSR(csrTemplate, result.PrivateKey) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create ML-DSA CSR: %s", err)} + } + } else { + switch data.Params.KeyType { + case "rsa": + // use specified RSA algorithm defaulting to the appropriate SHA256 RSA signature type + csrTemplate.SignatureAlgorithm = selectSignatureAlgorithmForRSA(data) + case "ec": + csrTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(result.PrivateKey.Public(), data.Params.SignatureBits) + case "ed25519": + csrTemplate.SignatureAlgorithm = x509.PureEd25519 + } - csr, err := x509.CreateCertificateRequest(randReader, csrTemplate, result.PrivateKey) - if err != nil { - return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} + csr, err = x509.CreateCertificateRequest(randReader, csrTemplate, result.PrivateKey) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} + } } result.CSRBytes = csr result.CSR, err = x509.ParseCertificateRequest(csr) if err != nil { + if isMLDSAKeyType(data.Params.KeyType) { + // Go's x509.ParseCertificateRequest may not recognize ML-DSA + // signature algorithms. Store raw bytes only. + result.CSR = nil + return result, nil + } return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %v", err)} } + if isMLDSAKeyType(data.Params.KeyType) { + // Go's x509 library cannot verify ML-DSA signatures, so we skip + // the standard CheckSignature call. The ML-DSA CSR was signed + // correctly by createMLDSACSR using circl's signer. + return result, nil + } + if err = result.CSR.CheckSignature(); err != nil { return nil, errors.New("failed signature validation for CSR") } @@ -1301,8 +1394,17 @@ func signCertificate(data *CreationBundle, randReader io.Reader) (*ParsedCertBun } if !data.Params.IgnoreCSRSignature { - if err := data.CSR.CheckSignature(); err != nil { - return nil, errutil.UserError{Err: "request signature invalid"} + // Go's x509.CheckSignature does not support ML-DSA signature + // algorithms. For ML-DSA CSRs, skip the standard check since + // the CSR signature was already verified during CSR creation + // by our custom verifyMLDSASignature call. + csrKeyType := GetPrivateKeyTypeFromPublicKey(data.CSR.PublicKey) + isMLDSACSR := csrKeyType == MLDSA65PrivateKey || csrKeyType == MLDSA87PrivateKey + + if !isMLDSACSR { + if err := data.CSR.CheckSignature(); err != nil { + return nil, errutil.UserError{Err: "request signature invalid"} + } } } @@ -1345,6 +1447,9 @@ func signCertificate(data *CreationBundle, randReader io.Reader) (*ParsedCertBun certTemplateSetSigAlgo(certTemplate, data) case ECPrivateKey: certTemplate.SignatureAlgorithm = selectSignatureAlgorithmForECDSA(caCert.PublicKey, data.Params.SignatureBits) + case MLDSA65PrivateKey, MLDSA87PrivateKey: + // ML-DSA uses custom certificate construction; signature algorithm + // is set during createMLDSACertificate, not via x509.Certificate template. } if data.Params.UseCSRValues { @@ -1433,7 +1538,12 @@ func signCertificate(data *CreationBundle, randReader io.Reader) (*ParsedCertBun // Note that it is harmless to set PermittedDNSDomainsCritical even if all other permitted/excluded fields are empty certTemplate.PermittedDNSDomainsCritical = true - certBytes, err = x509.CreateCertificate(randReader, certTemplate, caCert, data.CSR.PublicKey, data.SigningBundle.PrivateKey) + if privateKeyType == MLDSA65PrivateKey || privateKeyType == MLDSA87PrivateKey { + // ML-DSA CA signing requires custom certificate construction + certBytes, err = createMLDSACertificate(certTemplate, caCert, data.CSR.PublicKey, data.SigningBundle.PrivateKey) + } else { + certBytes, err = x509.CreateCertificate(randReader, certTemplate, caCert, data.CSR.PublicKey, data.SigningBundle.PrivateKey) + } if err != nil { return nil, errutil.InternalError{Err: fmt.Sprintf("unable to create certificate: %s", err)} } @@ -1441,7 +1551,16 @@ func signCertificate(data *CreationBundle, randReader io.Reader) (*ParsedCertBun result.CertificateBytes = certBytes result.Certificate, err = x509.ParseCertificate(certBytes) if err != nil { - return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + // Go's x509.ParseCertificate may not support ML-DSA certificates. + // Fall back to our custom parser for ML-DSA-signed certificates. + if privateKeyType == MLDSA65PrivateKey || privateKeyType == MLDSA87PrivateKey { + result.Certificate, err = parseMLDSACertificate(certBytes) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse ML-DSA certificate: %s", err)} + } + } else { + return nil, errutil.InternalError{Err: fmt.Sprintf("unable to parse created certificate: %s", err)} + } } result.CAChain = data.SigningBundle.GetFullChain() @@ -1511,6 +1630,12 @@ func GetPublicKeySize(key crypto.PublicKey) int { if key, ok := key.(dsa.PublicKey); ok { return key.Y.BitLen() } + if _, ok := key.(*mldsa65.PublicKey); ok { + return mldsa65.PublicKeySize * 8 + } + if _, ok := key.(*mldsa87.PublicKey); ok { + return mldsa87.PublicKeySize * 8 + } return -1 } diff --git a/sdk/helper/certutil/mldsa.go b/sdk/helper/certutil/mldsa.go new file mode 100644 index 0000000000..5adc1b8e0c --- /dev/null +++ b/sdk/helper/certutil/mldsa.go @@ -0,0 +1,781 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package certutil + +import ( + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "fmt" + "math/big" + "net" + "net/url" + "time" + + "github.com/cloudflare/circl/sign/mldsa/mldsa65" + "github.com/cloudflare/circl/sign/mldsa/mldsa87" + "github.com/hashicorp/vault/sdk/helper/errutil" +) + +// OIDs for ML-DSA algorithms as defined in NIST FIPS 204. +// Per FIPS 204, the key OID and signature OID are identical for ML-DSA: +// the same OID identifies both the public key algorithm in +// SubjectPublicKeyInfo and the signature algorithm in +// signatureAlgorithm / TBS signatureAlgorithm fields. +var ( + oidMLDSA65 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 18} + oidMLDSA87 = asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 3, 19} +) + +// x509Certificate is the ASN.1 structure of an X.509 certificate. +type x509Certificate struct { + TBSCertificate asn1.RawValue + SignatureAlgorithm pkix.AlgorithmIdentifier + SignatureValue asn1.BitString +} + +// isMLDSAKey returns true if the signer uses an ML-DSA key type. +func isMLDSAKey(signer crypto.Signer) bool { + keyType := GetPrivateKeyTypeFromSigner(signer) + return keyType == MLDSA65PrivateKey || keyType == MLDSA87PrivateKey +} + +// isMLDSAKeyType returns true if the key type string is an ML-DSA type. +func isMLDSAKeyType(keyType string) bool { + return keyType == "ml-dsa-65" || keyType == "ml-dsa-87" +} + +// mldsaSignatureAlgorithmOID returns the OID for the given ML-DSA key type. +func mldsaSignatureAlgorithmOID(signer crypto.Signer) (asn1.ObjectIdentifier, error) { + keyType := GetPrivateKeyTypeFromSigner(signer) + switch keyType { + case MLDSA65PrivateKey: + return oidMLDSA65, nil + case MLDSA87PrivateKey: + return oidMLDSA87, nil + default: + return nil, fmt.Errorf("not an ML-DSA key type: %s", keyType) + } +} + +// createMLDSACertificate creates an X.509 certificate signed with an ML-DSA key. +// Go's standard x509.CreateCertificate does not yet support ML-DSA, so we +// construct the certificate manually using ASN.1. +// +// The approach: +// 1. Build the TBS certificate structure directly using ASN.1, +// embedding the correct ML-DSA signature algorithm OID. +// 2. Marshal the TBS to DER. +// 3. Sign the TBS DER bytes with the ML-DSA signer. +// 4. Assemble the final certificate DER with signature. +// 5. Verify the signature to ensure correctness. +func createMLDSACertificate(template, parent *x509.Certificate, pub crypto.PublicKey, signer crypto.Signer) ([]byte, error) { + sigAlgOID, err := mldsaSignatureAlgorithmOID(signer) + if err != nil { + return nil, err + } + + // Build the TBS certificate from scratch because Go's + // x509.CreateCertificate does not support ML-DSA public keys. + tbsCert, err := buildTBSCertificate(template, parent, pub, sigAlgOID) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("error building TBS certificate: %v", err)} + } + + // Step 2: Marshal TBS to DER + tbsDER, err := asn1.Marshal(tbsCert) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("error marshaling TBS certificate: %v", err)} + } + + // Step 3: Sign the TBS DER bytes + sig, err := signer.Sign(nil, tbsDER, crypto.Hash(0)) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("error signing certificate with ML-DSA: %v", err)} + } + + // Step 4: Assemble the final certificate + sigAlg := pkix.AlgorithmIdentifier{Algorithm: sigAlgOID} + cert := x509Certificate{ + TBSCertificate: asn1.RawValue{FullBytes: tbsDER}, + SignatureAlgorithm: sigAlg, + SignatureValue: asn1.BitString{Bytes: sig, BitLength: len(sig) * 8}, + } + + certDER, err := asn1.Marshal(cert) + if err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("error marshaling certificate: %v", err)} + } + + // Step 5: Verify the signature to ensure correctness + if err := verifyMLDSASignature(signer.Public(), tbsDER, sig); err != nil { + return nil, errutil.InternalError{Err: fmt.Sprintf("post-creation signature verification failed: %v", err)} + } + + return certDER, nil +} + +// tbsCertificate is a simplified ASN.1 TBS certificate structure. +type tbsCertificate struct { + Version int `asn1:"optional,explicit,default:0,tag:0"` + SerialNumber *big.Int + SignatureAlgorithm pkix.AlgorithmIdentifier + Issuer asn1.RawValue + Validity validity + Subject asn1.RawValue + PublicKeyInfo asn1.RawValue + Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` +} + +type validity struct { + NotBefore, NotAfter asn1.RawValue +} + +// marshalPublicKeyInfo marshals the ML-DSA public key into a +// SubjectPublicKeyInfo structure. +func marshalPublicKeyInfo(pub crypto.PublicKey, algOID asn1.ObjectIdentifier) (asn1.RawValue, error) { + type publicKeyInfo struct { + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString + } + + type binaryMarshaler interface { + MarshalBinary() ([]byte, error) + } + + bm, ok := pub.(binaryMarshaler) + if !ok { + return asn1.RawValue{}, fmt.Errorf("public key does not implement MarshalBinary") + } + + pubBytes, err := bm.MarshalBinary() + if err != nil { + return asn1.RawValue{}, fmt.Errorf("error marshaling public key: %v", err) + } + + pki := publicKeyInfo{ + Algorithm: pkix.AlgorithmIdentifier{Algorithm: algOID}, + PublicKey: asn1.BitString{Bytes: pubBytes, BitLength: len(pubBytes) * 8}, + } + + pkiDER, err := asn1.Marshal(pki) + if err != nil { + return asn1.RawValue{}, fmt.Errorf("error marshaling SubjectPublicKeyInfo: %v", err) + } + + return asn1.RawValue{FullBytes: pkiDER}, nil +} + +// buildTBSCertificate constructs a TBS certificate structure from the template. +func buildTBSCertificate(template, parent *x509.Certificate, pub crypto.PublicKey, sigAlgOID asn1.ObjectIdentifier) (tbsCertificate, error) { + // Marshal subject and issuer names + issuerRDN, err := asn1.Marshal(parent.Subject.ToRDNSequence()) + if err != nil { + return tbsCertificate{}, fmt.Errorf("error marshaling issuer: %v", err) + } + + subjectRDN, err := asn1.Marshal(template.Subject.ToRDNSequence()) + if err != nil { + return tbsCertificate{}, fmt.Errorf("error marshaling subject: %v", err) + } + + // Marshal validity times + notBeforeDER, err := asn1.Marshal(template.NotBefore) + if err != nil { + return tbsCertificate{}, fmt.Errorf("error marshaling NotBefore: %v", err) + } + notAfterDER, err := asn1.Marshal(template.NotAfter) + if err != nil { + return tbsCertificate{}, fmt.Errorf("error marshaling NotAfter: %v", err) + } + + // Marshal public key info + pkInfo, err := marshalPublicKeyInfo(pub, sigAlgOID) + if err != nil { + return tbsCertificate{}, err + } + + // Build extensions from the template + extensions, err := buildExtensions(template) + if err != nil { + return tbsCertificate{}, err + } + + tbs := tbsCertificate{ + Version: 2, // v3 + SerialNumber: template.SerialNumber, + SignatureAlgorithm: pkix.AlgorithmIdentifier{Algorithm: sigAlgOID}, + Issuer: asn1.RawValue{FullBytes: issuerRDN}, + Validity: validity{ + NotBefore: asn1.RawValue{FullBytes: notBeforeDER}, + NotAfter: asn1.RawValue{FullBytes: notAfterDER}, + }, + Subject: asn1.RawValue{FullBytes: subjectRDN}, + PublicKeyInfo: pkInfo, + Extensions: extensions, + } + + return tbs, nil +} + +// buildExtensions creates the certificate extensions from the template. +func buildExtensions(template *x509.Certificate) ([]pkix.Extension, error) { + var extensions []pkix.Extension + + // Subject Key Identifier + if len(template.SubjectKeyId) > 0 { + skidDER, err := asn1.Marshal(template.SubjectKeyId) + if err != nil { + return nil, fmt.Errorf("error marshaling SKID: %v", err) + } + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 14}, + Value: skidDER, + }) + } + + // Authority Key Identifier + if len(template.AuthorityKeyId) > 0 { + type authKeyId struct { + KeyIdentifier []byte `asn1:"optional,tag:0"` + } + akidDER, err := asn1.Marshal(authKeyId{KeyIdentifier: template.AuthorityKeyId}) + if err != nil { + return nil, fmt.Errorf("error marshaling AKID: %v", err) + } + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 35}, + Value: akidDER, + }) + } + + // Basic Constraints + if template.BasicConstraintsValid || template.IsCA { + type basicConstraints struct { + IsCA bool `asn1:"optional"` + MaxPathLen int `asn1:"optional,default:-1"` + } + bc := basicConstraints{IsCA: template.IsCA} + if template.MaxPathLen > 0 || template.MaxPathLenZero { + bc.MaxPathLen = template.MaxPathLen + } else { + bc.MaxPathLen = -1 + } + bcDER, err := asn1.Marshal(bc) + if err != nil { + return nil, fmt.Errorf("error marshaling basic constraints: %v", err) + } + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 19}, + Critical: true, + Value: bcDER, + }) + } + + // Key Usage + if template.KeyUsage != 0 { + ku, err := marshalKeyUsage(template.KeyUsage) + if err != nil { + return nil, fmt.Errorf("error marshaling key usage: %v", err) + } + extensions = append(extensions, ku) + } + + // Extended Key Usage + if len(template.ExtKeyUsage) > 0 || len(template.UnknownExtKeyUsage) > 0 { + var oids []asn1.ObjectIdentifier + for _, eku := range template.ExtKeyUsage { + oid, ok := ekuToOID(eku) + if ok { + oids = append(oids, oid) + } + } + oids = append(oids, template.UnknownExtKeyUsage...) + + ekuDER, err := asn1.Marshal(oids) + if err != nil { + return nil, fmt.Errorf("error marshaling EKU: %v", err) + } + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 37}, + Value: ekuDER, + }) + } + + // SAN (Subject Alternative Name) + sanDER, err := marshalSANExtension(template) + if err != nil { + return nil, err + } + if sanDER != nil { + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 17}, + Value: sanDER, + }) + } + + // Include any extra extensions from the template + extensions = append(extensions, template.ExtraExtensions...) + + return extensions, nil +} + +// ekuToOID maps an x509.ExtKeyUsage to its OID. +func ekuToOID(eku x509.ExtKeyUsage) (asn1.ObjectIdentifier, bool) { + switch eku { + case x509.ExtKeyUsageAny: + return asn1.ObjectIdentifier{2, 5, 29, 37, 0}, true + case x509.ExtKeyUsageServerAuth: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}, true + case x509.ExtKeyUsageClientAuth: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}, true + case x509.ExtKeyUsageCodeSigning: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}, true + case x509.ExtKeyUsageEmailProtection: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}, true + case x509.ExtKeyUsageIPSECEndSystem: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 5}, true + case x509.ExtKeyUsageIPSECTunnel: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 6}, true + case x509.ExtKeyUsageIPSECUser: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 7}, true + case x509.ExtKeyUsageTimeStamping: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}, true + case x509.ExtKeyUsageOCSPSigning: + return asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}, true + case x509.ExtKeyUsageMicrosoftServerGatedCrypto: + return asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 10, 3, 3}, true + case x509.ExtKeyUsageNetscapeServerGatedCrypto: + return asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 4, 1}, true + case x509.ExtKeyUsageMicrosoftCommercialCodeSigning: + return asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 2, 1, 22}, true + case x509.ExtKeyUsageMicrosoftKernelCodeSigning: + return asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 311, 61, 1, 1}, true + default: + return nil, false + } +} + +// marshalSANExtension builds the SAN extension DER bytes from the template. +func marshalSANExtension(template *x509.Certificate) ([]byte, error) { + if len(template.DNSNames) == 0 && len(template.EmailAddresses) == 0 && + len(template.IPAddresses) == 0 && len(template.URIs) == 0 { + return nil, nil + } + + var rawValues []asn1.RawValue + + for _, name := range template.DNSNames { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 2, // dNSName + Class: asn1.ClassContextSpecific, + Bytes: []byte(name), + }) + } + + for _, email := range template.EmailAddresses { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 1, // rfc822Name + Class: asn1.ClassContextSpecific, + Bytes: []byte(email), + }) + } + + for _, ip := range template.IPAddresses { + rawIP := ip.To4() + if rawIP == nil { + rawIP = ip.To16() + } + rawValues = append(rawValues, asn1.RawValue{ + Tag: 7, // iPAddress + Class: asn1.ClassContextSpecific, + Bytes: rawIP, + }) + } + + for _, uri := range template.URIs { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 6, // uniformResourceIdentifier + Class: asn1.ClassContextSpecific, + Bytes: []byte(uri.String()), + }) + } + + sanDER, err := asn1.Marshal(rawValues) + if err != nil { + return nil, fmt.Errorf("error marshaling SAN extension: %v", err) + } + + return sanDER, nil +} + +// createMLDSACSR creates a certificate signing request signed with an ML-DSA key. +func createMLDSACSR(template *x509.CertificateRequest, signer crypto.Signer) ([]byte, error) { + sigAlgOID, err := mldsaSignatureAlgorithmOID(signer) + if err != nil { + return nil, err + } + + // Marshal subject + subjectRDN, err := asn1.Marshal(template.Subject.ToRDNSequence()) + if err != nil { + return nil, fmt.Errorf("error marshaling CSR subject: %v", err) + } + + // Marshal public key info + pkInfo, err := marshalPublicKeyInfo(signer.Public(), sigAlgOID) + if err != nil { + return nil, err + } + + // Build CSR attributes (extensions) + var attributes []asn1.RawValue + var extensions []pkix.Extension + + // Add SANs + sanDER, err := marshalCSRSANExtension(template) + if err != nil { + return nil, err + } + if sanDER != nil { + extensions = append(extensions, pkix.Extension{ + Id: asn1.ObjectIdentifier{2, 5, 29, 17}, + Value: sanDER, + }) + } + + // Add extra extensions from template + extensions = append(extensions, template.ExtraExtensions...) + + if len(extensions) > 0 { + extDER, err := asn1.Marshal(extensions) + if err != nil { + return nil, fmt.Errorf("error marshaling CSR extensions: %v", err) + } + // extensionRequest attribute OID: 1.2.840.113549.1.9.14 + attrType, err := asn1.Marshal(asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 9, 14}) + if err != nil { + return nil, fmt.Errorf("error marshaling attr type: %v", err) + } + // Wrap in SET + attrValues, err := asn1.Marshal(asn1.RawValue{ + Class: asn1.ClassUniversal, + Tag: asn1.TagSet, + IsCompound: true, + Bytes: extDER, + }) + if err != nil { + return nil, fmt.Errorf("error marshaling attr values: %v", err) + } + + attrBytes := append(attrType, attrValues...) + attributes = append(attributes, asn1.RawValue{ + Class: asn1.ClassUniversal, + Tag: asn1.TagSequence, + IsCompound: true, + Bytes: attrBytes, + }) + } + + // Build CertificationRequestInfo + type certificationRequestInfo struct { + Version int + Subject asn1.RawValue + PublicKeyInfo asn1.RawValue + Attributes asn1.RawValue `asn1:"tag:0"` + } + + var attrDER []byte + if len(attributes) > 0 { + for _, attr := range attributes { + b, err := asn1.Marshal(attr) + if err != nil { + return nil, fmt.Errorf("error marshaling attribute: %v", err) + } + attrDER = append(attrDER, b...) + } + } + + cri := certificationRequestInfo{ + Version: 0, + Subject: asn1.RawValue{FullBytes: subjectRDN}, + PublicKeyInfo: pkInfo, + Attributes: asn1.RawValue{Class: asn1.ClassContextSpecific, Tag: 0, IsCompound: true, Bytes: attrDER}, + } + + criDER, err := asn1.Marshal(cri) + if err != nil { + return nil, fmt.Errorf("error marshaling CertificationRequestInfo: %v", err) + } + + // Sign + sig, err := signer.Sign(nil, criDER, crypto.Hash(0)) + if err != nil { + return nil, fmt.Errorf("error signing CSR with ML-DSA: %v", err) + } + + // Assemble CSR + type certificationRequest struct { + CertificationRequestInfo asn1.RawValue + SignatureAlgorithm pkix.AlgorithmIdentifier + Signature asn1.BitString + } + + csr := certificationRequest{ + CertificationRequestInfo: asn1.RawValue{FullBytes: criDER}, + SignatureAlgorithm: pkix.AlgorithmIdentifier{Algorithm: sigAlgOID}, + Signature: asn1.BitString{Bytes: sig, BitLength: len(sig) * 8}, + } + + // Verify the signature to ensure correctness + if err := verifyMLDSASignature(signer.Public(), criDER, sig); err != nil { + return nil, fmt.Errorf("post-creation CSR signature verification failed: %v", err) + } + + return asn1.Marshal(csr) +} + +// verifyMLDSASignature verifies an ML-DSA signature using the appropriate +// circl verification function based on the public key type. +func verifyMLDSASignature(pub crypto.PublicKey, msg, sig []byte) error { + switch pk := pub.(type) { + case *mldsa65.PublicKey: + if !mldsa65.Verify(pk, msg, nil, sig) { + return fmt.Errorf("ML-DSA-65 signature verification failed") + } + case *mldsa87.PublicKey: + if !mldsa87.Verify(pk, msg, nil, sig) { + return fmt.Errorf("ML-DSA-87 signature verification failed") + } + default: + return fmt.Errorf("unsupported public key type for ML-DSA verification: %T", pub) + } + return nil +} + +// parseMLDSACertificate parses an ML-DSA certificate from DER bytes. +// Go's x509.ParseCertificate does not recognize ML-DSA signature algorithms, +// so we parse the ASN.1 structure manually and populate an x509.Certificate +// with the fields we can extract. +func parseMLDSACertificate(der []byte) (*x509.Certificate, error) { + var cert struct { + TBS struct { + Version int `asn1:"optional,explicit,default:0,tag:0"` + SerialNumber *big.Int + SigAlg pkix.AlgorithmIdentifier + Issuer asn1.RawValue + Validity struct { + NotBefore time.Time + NotAfter time.Time + } + Subject asn1.RawValue + PublicKey asn1.RawValue + Extensions []pkix.Extension `asn1:"optional,explicit,tag:3"` + } `asn1:"sequence"` + SigAlg pkix.AlgorithmIdentifier + Signature asn1.BitString + } + + rest, err := asn1.Unmarshal(der, &cert) + if err != nil { + return nil, fmt.Errorf("error parsing ML-DSA certificate ASN.1: %v", err) + } + if len(rest) > 0 { + return nil, fmt.Errorf("trailing data after certificate") + } + + // Parse Subject and Issuer RDN sequences + var issuerRDN pkix.RDNSequence + if _, err := asn1.Unmarshal(cert.TBS.Issuer.FullBytes, &issuerRDN); err != nil { + return nil, fmt.Errorf("error parsing issuer: %v", err) + } + var subjectRDN pkix.RDNSequence + if _, err := asn1.Unmarshal(cert.TBS.Subject.FullBytes, &subjectRDN); err != nil { + return nil, fmt.Errorf("error parsing subject: %v", err) + } + + var issuerName pkix.Name + issuerName.FillFromRDNSequence(&issuerRDN) + var subjectName pkix.Name + subjectName.FillFromRDNSequence(&subjectRDN) + + result := &x509.Certificate{ + Raw: der, + SerialNumber: cert.TBS.SerialNumber, + Issuer: issuerName, + Subject: subjectName, + NotBefore: cert.TBS.Validity.NotBefore, + NotAfter: cert.TBS.Validity.NotAfter, + RawIssuer: cert.TBS.Issuer.FullBytes, + RawSubject: cert.TBS.Subject.FullBytes, + Signature: cert.Signature.Bytes, + SignatureAlgorithm: x509.UnknownSignatureAlgorithm, + } + + // Parse extensions + for _, ext := range cert.TBS.Extensions { + result.Extensions = append(result.Extensions, ext) + + switch { + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 19}): // Basic Constraints + var bc struct { + IsCA bool `asn1:"optional"` + MaxPathLen int `asn1:"optional,default:-1"` + } + if _, err := asn1.Unmarshal(ext.Value, &bc); err == nil { + result.BasicConstraintsValid = true + result.IsCA = bc.IsCA + if bc.MaxPathLen >= 0 { + result.MaxPathLen = bc.MaxPathLen + if bc.MaxPathLen == 0 { + result.MaxPathLenZero = true + } + } + } + + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 14}): // Subject Key Identifier + var skid []byte + if _, err := asn1.Unmarshal(ext.Value, &skid); err == nil { + result.SubjectKeyId = skid + } + + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 35}): // Authority Key Identifier + var akid struct { + KeyIdentifier []byte `asn1:"optional,tag:0"` + } + if _, err := asn1.Unmarshal(ext.Value, &akid); err == nil { + result.AuthorityKeyId = akid.KeyIdentifier + } + + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 15}): // Key Usage + var bitString asn1.BitString + if _, err := asn1.Unmarshal(ext.Value, &bitString); err == nil { + var usage int + for i, b := range bitString.Bytes { + for bit := 0; bit < 8; bit++ { + if (b>>uint(7-bit))&1 != 0 { + usage |= 1 << uint(i*8+bit) + } + } + } + result.KeyUsage = x509.KeyUsage(usage) + } + + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 17}): // Subject Alternative Names + parseSANExtension(ext.Value, result) + + case ext.Id.Equal(asn1.ObjectIdentifier{2, 5, 29, 37}): // Extended Key Usage + var oids []asn1.ObjectIdentifier + if _, err := asn1.Unmarshal(ext.Value, &oids); err == nil { + for _, oid := range oids { + if eku, ok := oidToExtKeyUsage(oid); ok { + result.ExtKeyUsage = append(result.ExtKeyUsage, eku) + } else { + result.UnknownExtKeyUsage = append(result.UnknownExtKeyUsage, oid) + } + } + } + + default: + result.ExtraExtensions = append(result.ExtraExtensions, ext) + } + } + + return result, nil +} + +// parseSANExtension parses the Subject Alternative Name extension value +// and populates the certificate's DNSNames, EmailAddresses, IPAddresses, +// and URIs fields. +func parseSANExtension(value []byte, cert *x509.Certificate) { + var rawValues []asn1.RawValue + if _, err := asn1.Unmarshal(value, &rawValues); err != nil { + return + } + for _, v := range rawValues { + switch v.Tag { + case 1: // rfc822Name (email) + cert.EmailAddresses = append(cert.EmailAddresses, string(v.Bytes)) + case 2: // dNSName + cert.DNSNames = append(cert.DNSNames, string(v.Bytes)) + case 6: // uniformResourceIdentifier + if u, err := url.Parse(string(v.Bytes)); err == nil { + cert.URIs = append(cert.URIs, u) + } + case 7: // iPAddress + cert.IPAddresses = append(cert.IPAddresses, net.IP(v.Bytes)) + } + } +} + +// oidToExtKeyUsage maps a known EKU OID to x509.ExtKeyUsage. +func oidToExtKeyUsage(oid asn1.ObjectIdentifier) (x509.ExtKeyUsage, bool) { + switch { + case oid.Equal(asn1.ObjectIdentifier{2, 5, 29, 37, 0}): + return x509.ExtKeyUsageAny, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 1}): + return x509.ExtKeyUsageServerAuth, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 2}): + return x509.ExtKeyUsageClientAuth, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 3}): + return x509.ExtKeyUsageCodeSigning, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 4}): + return x509.ExtKeyUsageEmailProtection, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 8}): + return x509.ExtKeyUsageTimeStamping, true + case oid.Equal(asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 3, 9}): + return x509.ExtKeyUsageOCSPSigning, true + default: + return 0, false + } +} + +// marshalCSRSANExtension builds the SAN extension DER for a CSR template. +func marshalCSRSANExtension(template *x509.CertificateRequest) ([]byte, error) { + if len(template.DNSNames) == 0 && len(template.EmailAddresses) == 0 && + len(template.IPAddresses) == 0 && len(template.URIs) == 0 { + return nil, nil + } + + var rawValues []asn1.RawValue + + for _, name := range template.DNSNames { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 2, + Class: asn1.ClassContextSpecific, + Bytes: []byte(name), + }) + } + + for _, email := range template.EmailAddresses { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 1, + Class: asn1.ClassContextSpecific, + Bytes: []byte(email), + }) + } + + for _, ip := range template.IPAddresses { + rawIP := ip.To4() + if rawIP == nil { + rawIP = ip.To16() + } + rawValues = append(rawValues, asn1.RawValue{ + Tag: 7, + Class: asn1.ClassContextSpecific, + Bytes: rawIP, + }) + } + + for _, uri := range template.URIs { + rawValues = append(rawValues, asn1.RawValue{ + Tag: 6, + Class: asn1.ClassContextSpecific, + Bytes: []byte(uri.String()), + }) + } + + return asn1.Marshal(rawValues) +} diff --git a/sdk/helper/certutil/mldsa_test.go b/sdk/helper/certutil/mldsa_test.go new file mode 100644 index 0000000000..8963e3961c --- /dev/null +++ b/sdk/helper/certutil/mldsa_test.go @@ -0,0 +1,597 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package certutil + +import ( + "crypto" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "net/url" + "strings" + "testing" + "time" + + "github.com/cloudflare/circl/sign/mldsa/mldsa65" + "github.com/cloudflare/circl/sign/mldsa/mldsa87" +) + +func TestMLDSA65KeyGeneration(t *testing.T) { + var bundle ParsedCertBundle + err := generatePrivateKey("ml-dsa-65", 0, &bundle, rand.Reader) + if err != nil { + t.Fatalf("failed to generate ML-DSA-65 key: %v", err) + } + + if bundle.PrivateKeyType != MLDSA65PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA65PrivateKey, bundle.PrivateKeyType) + } + + if bundle.PrivateKey == nil { + t.Fatal("private key is nil") + } + + if len(bundle.PrivateKeyBytes) == 0 { + t.Fatal("private key bytes are empty") + } + + // Verify the key implements crypto.Signer + signer, ok := bundle.PrivateKey.(crypto.Signer) + if !ok { + t.Fatal("private key does not implement crypto.Signer") + } + + // Verify we can get the public key + pub := signer.Public() + if pub == nil { + t.Fatal("public key is nil") + } + + // Verify public key type detection + keyType := GetPrivateKeyTypeFromSigner(signer) + if keyType != MLDSA65PrivateKey { + t.Fatalf("expected signer key type %s, got %s", MLDSA65PrivateKey, keyType) + } + + pubKeyType := GetPrivateKeyTypeFromPublicKey(pub) + if pubKeyType != MLDSA65PrivateKey { + t.Fatalf("expected public key type %s, got %s", MLDSA65PrivateKey, pubKeyType) + } +} + +func TestMLDSA87KeyGeneration(t *testing.T) { + var bundle ParsedCertBundle + err := generatePrivateKey("ml-dsa-87", 0, &bundle, rand.Reader) + if err != nil { + t.Fatalf("failed to generate ML-DSA-87 key: %v", err) + } + + if bundle.PrivateKeyType != MLDSA87PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA87PrivateKey, bundle.PrivateKeyType) + } + + if bundle.PrivateKey == nil { + t.Fatal("private key is nil") + } + + if len(bundle.PrivateKeyBytes) == 0 { + t.Fatal("private key bytes are empty") + } + + signer, ok := bundle.PrivateKey.(crypto.Signer) + if !ok { + t.Fatal("private key does not implement crypto.Signer") + } + + pub := signer.Public() + if pub == nil { + t.Fatal("public key is nil") + } + + keyType := GetPrivateKeyTypeFromSigner(signer) + if keyType != MLDSA87PrivateKey { + t.Fatalf("expected signer key type %s, got %s", MLDSA87PrivateKey, keyType) + } +} + +func TestMLDSA65KeyRoundTrip(t *testing.T) { + // Generate a key + var bundle ParsedCertBundle + err := generatePrivateKey("ml-dsa-65", 0, &bundle, rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + // Round-trip through PEM encoding using the correct ML-DSA-65 PEM label + pemBlock := &pem.Block{ + Type: string(MLDSA65Block), + Bytes: bundle.PrivateKeyBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + + // Parse back + decoded, _ := pem.Decode(pemBytes) + if decoded == nil { + t.Fatal("failed to decode PEM") + } + + signer, blockType, err := ParseDERKey(decoded.Bytes) + if err != nil { + t.Fatalf("failed to parse DER key: %v", err) + } + + if blockType != MLDSA65Block { + t.Fatalf("expected block type %s, got %s", MLDSA65Block, blockType) + } + + keyType := GetPrivateKeyTypeFromSigner(signer) + if keyType != MLDSA65PrivateKey { + t.Fatalf("expected key type %s after round-trip, got %s", MLDSA65PrivateKey, keyType) + } + + // Verify the round-tripped key can sign + msg := []byte("test message") + sig, err := signer.Sign(nil, msg, crypto.Hash(0)) + if err != nil { + t.Fatalf("failed to sign with round-tripped key: %v", err) + } + + // Verify signature + pk := signer.Public().(*mldsa65.PublicKey) + if !mldsa65.Verify(pk, msg, nil, sig) { + t.Fatal("signature verification failed for round-tripped key") + } +} + +func TestMLDSA87KeyRoundTrip(t *testing.T) { + var bundle ParsedCertBundle + err := generatePrivateKey("ml-dsa-87", 0, &bundle, rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + // Use the correct ML-DSA-87 PEM label + pemBlock := &pem.Block{ + Type: string(MLDSA87Block), + Bytes: bundle.PrivateKeyBytes, + } + pemBytes := pem.EncodeToMemory(pemBlock) + + decoded, _ := pem.Decode(pemBytes) + if decoded == nil { + t.Fatal("failed to decode PEM") + } + + signer, _, err := ParseDERKey(decoded.Bytes) + if err != nil { + t.Fatalf("failed to parse DER key: %v", err) + } + + keyType := GetPrivateKeyTypeFromSigner(signer) + if keyType != MLDSA87PrivateKey { + t.Fatalf("expected key type %s after round-trip, got %s", MLDSA87PrivateKey, keyType) + } + + msg := []byte("test message for ML-DSA-87") + sig, err := signer.Sign(nil, msg, crypto.Hash(0)) + if err != nil { + t.Fatalf("failed to sign with round-tripped key: %v", err) + } + + pk := signer.Public().(*mldsa87.PublicKey) + if !mldsa87.Verify(pk, msg, nil, sig) { + t.Fatal("signature verification failed for round-tripped key") + } +} + +func TestMLDSA65PublicKeyComparison(t *testing.T) { + _, sk1, _ := mldsa65.GenerateKey(rand.Reader) + _, sk2, _ := mldsa65.GenerateKey(rand.Reader) + + pk1 := sk1.Public() + pk1Copy := sk1.Public() + pk2 := sk2.Public() + + // Same key should be equal + equal, err := ComparePublicKeys(pk1, pk1Copy) + if err != nil { + t.Fatalf("error comparing same keys: %v", err) + } + if !equal { + t.Fatal("same ML-DSA-65 public keys should be equal") + } + + // Different keys should not be equal + equal, err = ComparePublicKeys(pk1, pk2) + if err != nil { + t.Fatalf("error comparing different keys: %v", err) + } + if equal { + t.Fatal("different ML-DSA-65 public keys should not be equal") + } +} + +func TestMLDSA87PublicKeyComparison(t *testing.T) { + _, sk1, _ := mldsa87.GenerateKey(rand.Reader) + _, sk2, _ := mldsa87.GenerateKey(rand.Reader) + + pk1 := sk1.Public() + pk1Copy := sk1.Public() + pk2 := sk2.Public() + + equal, err := ComparePublicKeys(pk1, pk1Copy) + if err != nil { + t.Fatalf("error comparing same keys: %v", err) + } + if !equal { + t.Fatal("same ML-DSA-87 public keys should be equal") + } + + equal, err = ComparePublicKeys(pk1, pk2) + if err != nil { + t.Fatalf("error comparing different keys: %v", err) + } + if equal { + t.Fatal("different ML-DSA-87 public keys should not be equal") + } +} + +func TestMLDSA65SubjectKeyID(t *testing.T) { + _, sk, _ := mldsa65.GenerateKey(rand.Reader) + + skid, err := GetSubjKeyID(sk) + if err != nil { + t.Fatalf("failed to get subject key ID: %v", err) + } + + if len(skid) == 0 { + t.Fatal("subject key ID is empty") + } + + // SHA-1 hash should be 20 bytes + if len(skid) != 20 { + t.Fatalf("expected 20 byte SKID, got %d bytes", len(skid)) + } +} + +func TestMLDSA65GetPublicKeySize(t *testing.T) { + _, sk, _ := mldsa65.GenerateKey(rand.Reader) + pub := sk.Public() + + size := GetPublicKeySize(pub) + expected := mldsa65.PublicKeySize * 8 + if size != expected { + t.Fatalf("expected public key size %d bits, got %d", expected, size) + } +} + +func TestMLDSA87GetPublicKeySize(t *testing.T) { + _, sk, _ := mldsa87.GenerateKey(rand.Reader) + pub := sk.Public() + + size := GetPublicKeySize(pub) + expected := mldsa87.PublicKeySize * 8 + if size != expected { + t.Fatalf("expected public key size %d bits, got %d", expected, size) + } +} + +func TestMLDSA65ValidateKeyType(t *testing.T) { + // ML-DSA-65 should accept keyBits=0 (no bit size concept) + keyBits, err := ValidateDefaultOrValueKeyType("ml-dsa-65", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if keyBits != 0 { + t.Fatalf("expected keyBits 0, got %d", keyBits) + } +} + +func TestMLDSA87ValidateKeyType(t *testing.T) { + keyBits, err := ValidateDefaultOrValueKeyType("ml-dsa-87", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if keyBits != 0 { + t.Fatalf("expected keyBits 0, got %d", keyBits) + } +} + +func TestMLDSA65ValidateSignatureLength(t *testing.T) { + // ML-DSA uses built-in hashing, so any hash bits value should be accepted + err := ValidateSignatureLength("ml-dsa-65", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMLDSA65HashBits(t *testing.T) { + // ML-DSA should return 0 hash bits (uses built-in hashing) + hashBits, err := DefaultOrValueHashBits("ml-dsa-65", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hashBits != 0 { + t.Fatalf("expected 0 hash bits for ML-DSA-65, got %d", hashBits) + } +} + +func TestMLDSA65SelfSignedCertificate(t *testing.T) { + // Generate ML-DSA-65 key pair + _, sk, err := mldsa65.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "ML-DSA-65 Test Root CA", + Organization: []string{"Test"}, + }, + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + + // Create self-signed certificate + certDER, err := createMLDSACertificate(template, template, sk.Public(), sk) + if err != nil { + t.Fatalf("failed to create ML-DSA certificate: %v", err) + } + + if len(certDER) == 0 { + t.Fatal("certificate DER is empty") + } + + // Parse the certificate back + cert, err := parseMLDSACertificate(certDER) + if err != nil { + t.Fatalf("failed to parse ML-DSA certificate: %v", err) + } + + if cert.Subject.CommonName != "ML-DSA-65 Test Root CA" { + t.Fatalf("expected CN 'ML-DSA-65 Test Root CA', got '%s'", cert.Subject.CommonName) + } + + if cert.SerialNumber.Cmp(big.NewInt(1)) != 0 { + t.Fatalf("unexpected serial number: %v", cert.SerialNumber) + } +} + +func TestMLDSA87SelfSignedCertificate(t *testing.T) { + _, sk, err := mldsa87.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "ML-DSA-87 Test Root CA", + Organization: []string{"Test"}, + }, + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := createMLDSACertificate(template, template, sk.Public(), sk) + if err != nil { + t.Fatalf("failed to create ML-DSA-87 certificate: %v", err) + } + + cert, err := parseMLDSACertificate(certDER) + if err != nil { + t.Fatalf("failed to parse ML-DSA-87 certificate: %v", err) + } + + if cert.Subject.CommonName != "ML-DSA-87 Test Root CA" { + t.Fatalf("expected CN 'ML-DSA-87 Test Root CA', got '%s'", cert.Subject.CommonName) + } +} + +func TestMLDSA65CertificateWithSANs(t *testing.T) { + _, sk, err := mldsa65.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + uri, _ := url.Parse("https://example.com") + template := &x509.Certificate{ + SerialNumber: big.NewInt(3), + Subject: pkix.Name{ + CommonName: "ml-dsa-65-test.example.com", + }, + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{"ml-dsa-65-test.example.com", "*.example.com"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}, + EmailAddresses: []string{"test@example.com"}, + URIs: []*url.URL{uri}, + } + + certDER, err := createMLDSACertificate(template, template, sk.Public(), sk) + if err != nil { + t.Fatalf("failed to create ML-DSA certificate with SANs: %v", err) + } + + if len(certDER) == 0 { + t.Fatal("certificate DER is empty") + } + + cert, err := parseMLDSACertificate(certDER) + if err != nil { + t.Fatalf("failed to parse ML-DSA certificate: %v", err) + } + + if cert.Subject.CommonName != "ml-dsa-65-test.example.com" { + t.Fatalf("unexpected CN: %s", cert.Subject.CommonName) + } +} + +func TestMLDSA65CSRCreation(t *testing.T) { + _, sk, err := mldsa65.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: "ML-DSA-65 CSR Test", + Organization: []string{"Test Org"}, + }, + DNSNames: []string{"csr-test.example.com"}, + } + + csrDER, err := createMLDSACSR(template, sk) + if err != nil { + t.Fatalf("failed to create ML-DSA CSR: %v", err) + } + + if len(csrDER) == 0 { + t.Fatal("CSR DER is empty") + } +} + +func TestMLDSA65CreateCertificateViaBundle(t *testing.T) { + // Test the full createCertificate path with ML-DSA keys + bundle := &CreationBundle{ + Params: &CreationParameters{ + Subject: pkix.Name{ + CommonName: "ML-DSA-65 Bundle Test", + Organization: []string{"Test Org"}, + }, + KeyType: "ml-dsa-65", + KeyBits: 0, + NotAfter: time.Now().Add(365 * 24 * time.Hour), + IsCA: true, + URLs: &URLEntries{ + IssuingCertificates: []string{}, + CRLDistributionPoints: []string{}, + OCSPServers: []string{}, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + NotBeforeDuration: 30 * time.Second, + }, + } + + result, err := createCertificate(bundle, rand.Reader, generatePrivateKey) + if err != nil { + t.Fatalf("failed to create certificate via bundle: %v", err) + } + + if result.PrivateKeyType != MLDSA65PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA65PrivateKey, result.PrivateKeyType) + } + + if result.PrivateKey == nil { + t.Fatal("private key is nil") + } + + if len(result.CertificateBytes) == 0 { + t.Fatal("certificate bytes are empty") + } + + if result.Certificate == nil { + t.Fatal("parsed certificate is nil") + } + + if result.Certificate.Subject.CommonName != "ML-DSA-65 Bundle Test" { + t.Fatalf("unexpected CN: %s", result.Certificate.Subject.CommonName) + } +} + +func TestMLDSA65CreateKeyBundle(t *testing.T) { + kb, err := CreateKeyBundle("ml-dsa-65", 0, rand.Reader) + if err != nil { + t.Fatalf("failed to create key bundle: %v", err) + } + + if kb.PrivateKeyType != MLDSA65PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA65PrivateKey, kb.PrivateKeyType) + } + + if kb.PrivateKey == nil { + t.Fatal("private key is nil") + } + + if len(kb.PrivateKeyBytes) == 0 { + t.Fatal("private key bytes are empty") + } + + // Test PEM string generation + pemStr, err := kb.ToPrivateKeyPemString() + if err != nil { + t.Fatalf("failed to get PEM string: %v", err) + } + + if !strings.Contains(pemStr, string(MLDSA65Block)) { + t.Fatalf("PEM string does not contain expected header %q", MLDSA65Block) + } +} + +func TestMLDSA87CreateKeyBundle(t *testing.T) { + kb, err := CreateKeyBundle("ml-dsa-87", 0, rand.Reader) + if err != nil { + t.Fatalf("failed to create key bundle: %v", err) + } + + if kb.PrivateKeyType != MLDSA87PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA87PrivateKey, kb.PrivateKeyType) + } + + if kb.PrivateKey == nil { + t.Fatal("private key is nil") + } + + pemStr, err := kb.ToPrivateKeyPemString() + if err != nil { + t.Fatalf("failed to get PEM string: %v", err) + } + + if !strings.Contains(pemStr, string(MLDSA87Block)) { + t.Fatalf("PEM string does not contain expected header %q", MLDSA87Block) + } +} + +func TestMLDSA65CSRViaBundleCreation(t *testing.T) { + bundle := &CreationBundle{ + Params: &CreationParameters{ + Subject: pkix.Name{ + CommonName: "ML-DSA-65 CSR Bundle Test", + Organization: []string{"Test Org"}, + }, + KeyType: "ml-dsa-65", + KeyBits: 0, + DNSNames: []string{"csr-bundle.example.com"}, + URLs: &URLEntries{ + IssuingCertificates: []string{}, + CRLDistributionPoints: []string{}, + OCSPServers: []string{}, + }, + }, + } + + result, err := createCSR(bundle, false, rand.Reader, generatePrivateKey) + if err != nil { + t.Fatalf("failed to create CSR via bundle: %v", err) + } + + if result.PrivateKeyType != MLDSA65PrivateKey { + t.Fatalf("expected key type %s, got %s", MLDSA65PrivateKey, result.PrivateKeyType) + } + + if len(result.CSRBytes) == 0 { + t.Fatal("CSR bytes are empty") + } +} diff --git a/sdk/helper/certutil/types.go b/sdk/helper/certutil/types.go index 262760feed..f5c9f958ab 100644 --- a/sdk/helper/certutil/types.go +++ b/sdk/helper/certutil/types.go @@ -31,6 +31,8 @@ import ( "strings" "time" + "github.com/cloudflare/circl/sign/mldsa/mldsa65" + "github.com/cloudflare/circl/sign/mldsa/mldsa87" ctx509 "github.com/google/certificate-transparency-go/x509" "github.com/hashicorp/errwrap" "github.com/hashicorp/vault/sdk/helper/errutil" @@ -61,11 +63,13 @@ type PrivateKeyType string // Well-known PrivateKeyTypes const ( - UnknownPrivateKey PrivateKeyType = "" - RSAPrivateKey PrivateKeyType = "rsa" - ECPrivateKey PrivateKeyType = "ec" - Ed25519PrivateKey PrivateKeyType = "ed25519" - ManagedPrivateKey PrivateKeyType = "ManagedPrivateKey" + UnknownPrivateKey PrivateKeyType = "" + RSAPrivateKey PrivateKeyType = "rsa" + ECPrivateKey PrivateKeyType = "ec" + Ed25519PrivateKey PrivateKeyType = "ed25519" + ManagedPrivateKey PrivateKeyType = "ManagedPrivateKey" + MLDSA65PrivateKey PrivateKeyType = "ml-dsa-65" + MLDSA87PrivateKey PrivateKeyType = "ml-dsa-87" ) // TLSUsage controls whether the intended usage of a *tls.Config @@ -85,10 +89,12 @@ type BlockType string // Well-known formats const ( - UnknownBlock BlockType = "" - PKCS1Block BlockType = "RSA PRIVATE KEY" - PKCS8Block BlockType = "PRIVATE KEY" - ECBlock BlockType = "EC PRIVATE KEY" + UnknownBlock BlockType = "" + PKCS1Block BlockType = "RSA PRIVATE KEY" + PKCS8Block BlockType = "PRIVATE KEY" + ECBlock BlockType = "EC PRIVATE KEY" + MLDSA65Block BlockType = "ML-DSA-65 PRIVATE KEY" + MLDSA87Block BlockType = "ML-DSA-87 PRIVATE KEY" ) // ParsedPrivateKeyContainer allows common key setting for certs and CSRs @@ -160,6 +166,10 @@ func GetPrivateKeyTypeFromSigner(signer crypto.Signer) PrivateKeyType { return ECPrivateKey case ed25519.PublicKey: return Ed25519PrivateKey + case *mldsa65.PublicKey: + return MLDSA65PrivateKey + case *mldsa87.PublicKey: + return MLDSA87PrivateKey } return UnknownPrivateKey } @@ -174,6 +184,10 @@ func GetPrivateKeyTypeFromPublicKey(pubKey crypto.PublicKey) PrivateKeyType { return ECPrivateKey case ed25519.PublicKey: return Ed25519PrivateKey + case *mldsa65.PublicKey: + return MLDSA65PrivateKey + case *mldsa87.PublicKey: + return MLDSA87PrivateKey default: return UnknownPrivateKey } @@ -294,6 +308,10 @@ func extractAndSetPrivateKey(c *CertBundle, parsedBundle *ParsedCertBundle) erro parsedBundle.PrivateKeyType, c.PrivateKeyType = ECPrivateKey, ECPrivateKey case PKCS1Block: c.PrivateKeyType, parsedBundle.PrivateKeyType = RSAPrivateKey, RSAPrivateKey + case MLDSA65Block: + c.PrivateKeyType, parsedBundle.PrivateKeyType = MLDSA65PrivateKey, MLDSA65PrivateKey + case MLDSA87Block: + c.PrivateKeyType, parsedBundle.PrivateKeyType = MLDSA87PrivateKey, MLDSA87PrivateKey case PKCS8Block: t, err := getPKCS8Type(pemBlock.Bytes) if err != nil { @@ -309,6 +327,10 @@ func extractAndSetPrivateKey(c *CertBundle, parsedBundle *ParsedCertBundle) erro c.PrivateKeyType = Ed25519PrivateKey case ManagedPrivateKey: c.PrivateKeyType = ManagedPrivateKey + case MLDSA65PrivateKey: + c.PrivateKeyType = MLDSA65PrivateKey + case MLDSA87PrivateKey: + c.PrivateKeyType = MLDSA87PrivateKey } default: return errutil.UserError{Err: fmt.Sprintf("Unsupported key block type: %s", pemBlock.Type)} @@ -360,6 +382,10 @@ func (p *ParsedCertBundle) ToCertBundle() (*CertBundle, error) { block.Type = string(PKCS1Block) case Ed25519PrivateKey: block.Type = string(PKCS8Block) + case MLDSA65PrivateKey: + block.Type = string(MLDSA65Block) + case MLDSA87PrivateKey: + block.Type = string(MLDSA87Block) } } @@ -422,6 +448,20 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) { return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse CA's private RSA key: %s", err)} } + case MLDSA65Block: + var sk mldsa65.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err != nil { + return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse ML-DSA-65 private key: %s", err)} + } + return &sk, nil + + case MLDSA87Block: + var sk mldsa87.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err != nil { + return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse ML-DSA-87 private key: %s", err)} + } + return &sk, nil + case PKCS8Block: if k, err := x509.ParsePKCS8PrivateKey(p.PrivateKeyBytes); err == nil { switch k := k.(type) { @@ -431,9 +471,25 @@ func (p *ParsedCertBundle) getSigner() (crypto.Signer, error) { return nil, errutil.UserError{Err: "Found unknown private key type in pkcs#8 wrapping"} } } + + // Try ML-DSA key formats for backwards compatibility with keys + // previously stored as raw bytes under the generic PKCS8 PEM label + if len(p.PrivateKeyBytes) == mldsa65.PrivateKeySize { + var sk mldsa65.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err == nil { + return &sk, nil + } + } + if len(p.PrivateKeyBytes) == mldsa87.PrivateKeySize { + var sk mldsa87.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err == nil { + return &sk, nil + } + } + return nil, errutil.UserError{Err: fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} default: - return nil, errutil.UserError{Err: "Unable to determine type of private key; only RSA and EC are supported"} + return nil, errutil.UserError{Err: "Unable to determine type of private key; only RSA, EC, Ed25519, and ML-DSA are supported"} } return signer, nil } @@ -447,20 +503,35 @@ func (p *ParsedCertBundle) SetParsedPrivateKey(privateKey crypto.Signer, private func getPKCS8Type(bs []byte) (PrivateKeyType, error) { k, err := x509.ParsePKCS8PrivateKey(bs) - if err != nil { - return UnknownPrivateKey, errutil.UserError{Err: fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} + if err == nil { + switch k.(type) { + case *ecdsa.PrivateKey: + return ECPrivateKey, nil + case *rsa.PrivateKey: + return RSAPrivateKey, nil + case ed25519.PrivateKey: + return Ed25519PrivateKey, nil + default: + return UnknownPrivateKey, errutil.UserError{Err: "Found unknown private key type in pkcs#8 wrapping"} + } } - switch k.(type) { - case *ecdsa.PrivateKey: - return ECPrivateKey, nil - case *rsa.PrivateKey: - return RSAPrivateKey, nil - case ed25519.PrivateKey: - return Ed25519PrivateKey, nil - default: - return UnknownPrivateKey, errutil.UserError{Err: "Found unknown private key type in pkcs#8 wrapping"} + // Try ML-DSA key formats for backwards compatibility with keys + // previously stored as raw bytes under the generic PKCS8 PEM label + if len(bs) == mldsa65.PrivateKeySize { + var sk mldsa65.PrivateKey + if unmarshalErr := sk.UnmarshalBinary(bs); unmarshalErr == nil { + return MLDSA65PrivateKey, nil + } } + if len(bs) == mldsa87.PrivateKeySize { + var sk mldsa87.PrivateKey + if unmarshalErr := sk.UnmarshalBinary(bs); unmarshalErr == nil { + return MLDSA87PrivateKey, nil + } + } + + return UnknownPrivateKey, errutil.UserError{Err: fmt.Sprintf("Failed to parse pkcs#8 key: %v", err)} } // ToParsedCSRBundle converts a string-based CSR bundle @@ -482,6 +553,12 @@ func (c *CSRBundle) ToParsedCSRBundle() (*ParsedCSRBundle, error) { result.PrivateKeyType = ECPrivateKey case PKCS1Block: result.PrivateKeyType = RSAPrivateKey + case MLDSA65Block: + result.PrivateKeyType = MLDSA65PrivateKey + c.PrivateKeyType = MLDSA65PrivateKey + case MLDSA87Block: + result.PrivateKeyType = MLDSA87PrivateKey + c.PrivateKeyType = MLDSA87PrivateKey default: // Try to figure it out and correct if _, err := x509.ParseECPrivateKey(pemBlock.Bytes); err == nil { @@ -544,6 +621,12 @@ func (p *ParsedCSRBundle) ToCSRBundle() (*CSRBundle, error) { case Ed25519PrivateKey: result.PrivateKeyType = "ed25519" block.Type = "PRIVATE KEY" + case MLDSA65PrivateKey: + result.PrivateKeyType = MLDSA65PrivateKey + block.Type = string(MLDSA65Block) + case MLDSA87PrivateKey: + result.PrivateKeyType = MLDSA87PrivateKey + block.Type = string(MLDSA87Block) case ManagedPrivateKey: result.PrivateKeyType = ManagedPrivateKey block.Type = "PRIVATE KEY" @@ -588,6 +671,20 @@ func (p *ParsedCSRBundle) getSigner() (crypto.Signer, error) { return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse CA's private Ed25519 key: %s", err)} } + case MLDSA65PrivateKey: + var sk mldsa65.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err != nil { + return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse ML-DSA-65 private key: %s", err)} + } + signer = &sk + + case MLDSA87PrivateKey: + var sk mldsa87.PrivateKey + if err := sk.UnmarshalBinary(p.PrivateKeyBytes); err != nil { + return nil, errutil.UserError{Err: fmt.Sprintf("Unable to parse ML-DSA-87 private key: %s", err)} + } + signer = &sk + default: return nil, errutil.UserError{Err: "Unable to determine type of private key; only RSA, Ed25519 and EC are supported"} } @@ -971,6 +1068,12 @@ func (p *KeyBundle) ToPrivateKeyPemString() (string, error) { block.Type = "RSA PRIVATE KEY" case ECPrivateKey: block.Type = "EC PRIVATE KEY" + case Ed25519PrivateKey, ManagedPrivateKey: + block.Type = "PRIVATE KEY" + case MLDSA65PrivateKey: + block.Type = string(MLDSA65Block) + case MLDSA87PrivateKey: + block.Type = string(MLDSA87Block) default: block.Type = "PRIVATE KEY" } @@ -983,9 +1086,9 @@ func (p *KeyBundle) ToPrivateKeyPemString() (string, error) { // PolicyIdentifierWithQualifierEntry Structure for Internal Storage type PolicyIdentifierWithQualifierEntry struct { - PolicyIdentifierOid string `json:"oid",mapstructure:"oid"` - CPS string `json:"cps,omitempty",mapstructure:"cps"` - Notice string `json:"notice,omitempty",mapstructure:"notice"` + PolicyIdentifierOid string `json:"oid" mapstructure:"oid"` + CPS string `json:"cps,omitempty" mapstructure:"cps"` + Notice string `json:"notice,omitempty" mapstructure:"notice"` } // GetPolicyIdentifierFromString parses out the internal structure of a Policy Identifier