From ba1d2a8124505c832bc541362ed7912b6d1e340b Mon Sep 17 00:00:00 2001
From: Mohammed Al Sahaf <msaa1990@gmail.com>
Date: Wed, 15 Nov 2023 22:47:46 +0300
Subject: [PATCH] allow more customizable options in CSRs

---
 modules/caddypki/ca.go       |  43 ++++++--
 modules/caddypki/csr.go      |  40 +++++++-
 modules/caddypki/csr_test.go | 184 +++++++++++++++++++++++++----------
 3 files changed, 204 insertions(+), 63 deletions(-)

diff --git a/modules/caddypki/ca.go b/modules/caddypki/ca.go
index 6d25b8f76..e57f0d5fc 100644
--- a/modules/caddypki/ca.go
+++ b/modules/caddypki/ca.go
@@ -16,7 +16,9 @@ package caddypki
 
 import (
 	"crypto"
+	"crypto/rand"
 	"crypto/x509"
+	"crypto/x509/pkix"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -432,40 +434,61 @@ func (ca CA) generateCSR(csrReq csrRequest) (csr *x509.CertificateRequest, err e
 	csrKeyPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyCSRKey(csrReq.ID))
 	if err != nil {
 		if !errors.Is(err, fs.ErrNotExist) {
-			return nil, fmt.Errorf("loading csr key '%s': %v", csrReq.ID, err)
+			return csr, fmt.Errorf("loading csr key '%s': %v", csrReq.ID, err)
 		}
 		if csrReq.Key == nil {
 			signer, err = keyutil.GenerateDefaultSigner()
 			if err != nil {
-				return nil, err
+				return csr, err
 			}
 		} else {
 			signer, err = keyutil.GenerateSigner(csrReq.Key.Type.String(), csrReq.Key.Curve.String(), csrReq.Key.Size)
 			if err != nil {
-				return nil, err
+				return csr, err
 			}
 		}
 
 		csrKeyPEM, err = certmagic.PEMEncodePrivateKey(signer)
 		if err != nil {
-			return nil, fmt.Errorf("encoding csr key: %v", err)
+			return csr, fmt.Errorf("encoding csr key: %v", err)
 		}
 		if err := ca.storage.Store(ca.ctx, ca.storageKeyCSRKey(csrReq.ID), csrKeyPEM); err != nil {
-			return nil, fmt.Errorf("saving csr key: %v", err)
+			return csr, fmt.Errorf("saving csr key: %v", err)
 		}
 	}
 	if signer == nil {
 		signer, err = certmagic.PEMDecodePrivateKey(csrKeyPEM)
 		if err != nil {
-			return nil, fmt.Errorf("decoding csr key: %v", err)
+			return csr, fmt.Errorf("decoding csr key: %v", err)
 		}
 	}
 
-	csr, err = x509util.CreateCertificateRequest("", csrReq.SANs, signer)
-	if err != nil {
-		return nil, err
+	var subject pkix.Name
+	if csrReq.Request != nil && csrReq.Request.Subject != nil {
+		subject = pkix.Name{
+			Country:            csrReq.Request.Subject.Country,
+			Organization:       csrReq.Request.Subject.Organization,
+			OrganizationalUnit: csrReq.Request.Subject.OrganizationalUnit,
+			Locality:           csrReq.Request.Subject.Locality,
+			Province:           csrReq.Request.Subject.Province,
+			StreetAddress:      csrReq.Request.Subject.StreetAddress,
+			PostalCode:         csrReq.Request.Subject.PostalCode,
+			CommonName:         csrReq.Request.Subject.CommonName,
+		}
 	}
-	return csr, nil
+	dnsNames, ips, emails, uris := x509util.SplitSANs(csrReq.Request.SANs)
+
+	csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{
+		Subject:        subject,
+		DNSNames:       dnsNames,
+		IPAddresses:    ips,
+		EmailAddresses: emails,
+		URIs:           uris,
+	}, signer)
+	if err != nil {
+		return csr, err
+	}
+	return x509.ParseCertificateRequest(csrBytes)
 }
 
 // AuthorityConfig is used to help a CA configure
diff --git a/modules/caddypki/csr.go b/modules/caddypki/csr.go
index 2f379ef9d..d00c001ed 100644
--- a/modules/caddypki/csr.go
+++ b/modules/caddypki/csr.go
@@ -2,7 +2,9 @@ package caddypki
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
+	"strings"
 )
 
 // The key type to be used for signing the CSR. The possible types are:
@@ -150,10 +152,44 @@ type csrRequest struct {
 	// The values are case-sensitive.
 	Key *keyParameters `json:"key,omitempty"`
 
-	// SANs is a list of subject alternative names for the certificate.
-	SANs []string `json:"sans"`
+	Request *requestParameters `json:"request,omitempty"`
 }
 
 func (c csrRequest) validate() error {
+	if !c.Request.valid() {
+		return errors.New("the 'request' field is not valid")
+	}
 	return c.Key.validate()
 }
+
+type requestParameters struct {
+	Subject *subject `json:"subject,omitempty"`
+
+	// SANs is a list of subject alternative names for the certificate.
+	SANs []string `json:"sans,omitempty"`
+}
+
+type subject struct {
+	CommonName         string   `json:"cn,omitempty"`
+	Country            []string `json:"c,omitempty"`
+	Organization       []string `json:"o,omitempty"`
+	OrganizationalUnit []string `json:"ou,omitempty"`
+	Locality           []string `json:"l,omitempty"`
+	Province           []string `json:"s,omitempty"`
+	StreetAddress      []string `json:"street_address,omitempty"`
+	PostalCode         []string `json:"postal_code,omitempty"`
+}
+
+func (rp *requestParameters) valid() bool {
+	if rp == nil || (len(rp.SANs) == 0 && rp.Subject == nil) {
+		return false
+	}
+	if len(rp.SANs) > 0 {
+		for _, san := range rp.SANs {
+			if strings.TrimSpace(san) == "" {
+				return false
+			}
+		}
+	}
+	return rp.Subject == nil || (rp.Subject != nil && len(strings.TrimSpace(rp.Subject.CommonName)) > 0)
+}
diff --git a/modules/caddypki/csr_test.go b/modules/caddypki/csr_test.go
index 19a56d65c..09870e658 100644
--- a/modules/caddypki/csr_test.go
+++ b/modules/caddypki/csr_test.go
@@ -2,7 +2,6 @@ package caddypki
 
 import (
 	"encoding/json"
-	"reflect"
 	"testing"
 )
 
@@ -19,12 +18,12 @@ func TestParseKeyType(t *testing.T) {
 			expected: keyTypeEC,
 		},
 		{
-			name:  "lowercase EC is recognized",
+			name:  "lowercase EC is rejected",
 			input: `"ec"`,
 			err:   "unknown key type: ec",
 		},
 		{
-			name:  "mixed case EC is recognized",
+			name:  "mixed case EC is rejected",
 			input: `"eC"`,
 			err:   "unknown key type: eC",
 		},
@@ -34,12 +33,12 @@ func TestParseKeyType(t *testing.T) {
 			expected: keyTypeRSA,
 		},
 		{
-			name:  "lowercase rsa is not accepted",
+			name:  "lowercase rsa is rejected",
 			input: `"rsa"`,
 			err:   "unknown key type: rsa",
 		},
 		{
-			name:  "mixed case RSA is not accepted",
+			name:  "mixed case RSA is rejected",
 			input: `"RsA"`,
 			err:   "unknown key type: RsA",
 		},
@@ -49,17 +48,17 @@ func TestParseKeyType(t *testing.T) {
 			expected: keyTypeOKP,
 		},
 		{
-			name:  "lowercase OKP is not accepted",
+			name:  "lowercase OKP is rejected",
 			input: `"okp"`,
 			err:   "unknown key type: okp",
 		},
 		{
-			name:  "mixed case OKP is not accepted",
+			name:  "mixed case OKP is rejected",
 			input: `"OkP"`,
 			err:   "unknown key type: OkP",
 		},
 		{
-			name:  "unknown key type is an error",
+			name:  "unknown key type is rejected",
 			input: `"foo"`,
 			err:   "unknown key type: foo",
 		},
@@ -89,7 +88,7 @@ func TestParseKeyType(t *testing.T) {
 	}
 }
 
-func TestCSRRequestValidate(t *testing.T) {
+func TestCSRKeyParameterValidate(t *testing.T) {
 	tests := []struct {
 		name    string
 		key     *keyParameters
@@ -221,71 +220,154 @@ func TestCSRRequestValidate(t *testing.T) {
 			wantErr: true,
 		},
 	}
-
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			c := csrRequest{
-				Key: tt.key,
-			}
-			if err := c.validate(); (err != nil) != tt.wantErr {
-				t.Errorf("csrRequest.validate() error = %v, wantErr %v", err, tt.wantErr)
+			if err := tt.key.validate(); (err != nil) != tt.wantErr {
+				t.Errorf("keyParameter.validate() error = %v, wantErr %v", err, tt.wantErr)
 			}
 		})
 	}
 }
 
-func TestCSRRequestUnmarshalJSON(t *testing.T) {
+func TestParseCurve(t *testing.T) {
 	tests := []struct {
-		name    string
-		request string
-		want    csrRequest
-		err     string
+		name     string
+		input    string
+		expected curve
+		err      string
 	}{
 		{
-			name:    "empty request is valid",
-			request: "{}",
-			want: csrRequest{
-				Key: nil,
-			},
+			name:     "Ed25519 is recognized",
+			input:    `"Ed25519"`,
+			expected: curveEd25519,
 		},
 		{
-			name:    "RSA with size 2048 is valid",
-			request: `{"key":{"type":"RSA","size":2048}}`,
-			want: csrRequest{
-				Key: &keyParameters{
-					Type: keyTypeRSA,
-					Size: 2048,
-				},
-			},
+			name:  "ed25519 is rejected",
+			input: `"ed25519"`,
+			err:   "unknown curve: ed25519",
 		},
 		{
-			name:    "EC key with curve P-256 is valid",
-			request: `{"key":{"type":"EC","curve":"P-256"}}`,
-			want: csrRequest{
-				Key: &keyParameters{
-					Type:  keyTypeEC,
-					Curve: "P-256",
-				},
-			},
+			name:  "eD25519 is rejected",
+			input: `"eD25519"`,
+			err:   "unknown curve: eD25519",
+		},
+		{
+			name:     "X25519 is recognized",
+			input:    `"X25519"`,
+			expected: curveX25519,
+		},
+		{
+			name:  "x25519 is rejected",
+			input: `"x25519"`,
+			err:   "unknown curve: x25519",
+		},
+		{
+			name:     "P-256 is recognized",
+			input:    `"P-256"`,
+			expected: curveP256,
+		},
+		{
+			name:  "p-256 is rejected",
+			input: `"p-256"`,
+			err:   "unknown curve: p-256",
+		},
+
+		{
+			name:     "P-384 is recognized",
+			input:    `"P-384"`,
+			expected: curveP384,
+		},
+		{
+			name:  "p-384 is rejected",
+			input: `"p-384"`,
+			err:   "unknown curve: p-384",
+		},
+
+		{
+			name:     "P-521 is recognized",
+			input:    `"P-521"`,
+			expected: curveP521,
+		},
+		{
+			name:  "p-521 is rejected",
+			input: `"p-521"`,
+			err:   "unknown curve: p-521",
 		},
 	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			var c csrRequest
-			err := json.Unmarshal([]byte(tt.request), &c)
-			if tt.err != "" {
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var kt curve
+
+			err := json.Unmarshal([]byte(test.input), &kt)
+			if test.err != "" {
 				if err == nil {
-					t.Errorf("expected error %q, but got nil", tt.err)
+					t.Errorf("expected error %q, but got nil", test.err)
 				}
-				if err.Error() != tt.err {
-					t.Errorf("expected error %q, but got %q", tt.err, err.Error())
+				if err.Error() != test.err {
+					t.Errorf("expected error %q, but got %q", test.err, err.Error())
 				}
+				return
 			}
 			if err != nil {
 				t.Errorf("expected no error, but got %q", err.Error())
+				return
 			}
-			if !reflect.DeepEqual(c, tt.want) {
-				t.Errorf("csrRequest.unmarshalJSON() = %v, want %v", c, tt.want)
+			if kt != test.expected {
+				t.Errorf("expected %v, but got %v", test.expected, kt)
+			}
+		})
+	}
+}
+
+func TestRequestParametersValidation(t *testing.T) {
+	tests := []struct {
+		name string
+		req  *requestParameters
+		want bool
+	}{
+		{
+			name: "nil request is invalid",
+			req:  nil,
+			want: false,
+		},
+		{
+			name: "empty request is invalid",
+			req:  &requestParameters{},
+			want: false,
+		},
+		{
+			name: "request containing empty SAN value is invalid",
+			req: &requestParameters{
+				SANs: []string{"example.com", "", "foo.com"},
+			},
+			want: false,
+		},
+		{
+			name: "request with SANs is valid",
+			req: &requestParameters{
+				SANs: []string{"example.com"},
+			},
+			want: true,
+		},
+		{
+			name: "request with non-empty CommonName is valid",
+			req: &requestParameters{
+				Subject: &subject{CommonName: "example.com"},
+			},
+			want: true,
+		},
+		{
+			name: "request with empty-space CommonName is invalid",
+			req: &requestParameters{
+				Subject: &subject{CommonName: " "},
+			},
+			want: false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.req.valid(); got != tt.want {
+				t.Errorf("requestParameters.valid() = %v, want %v", got, tt.want)
 			}
 		})
 	}