mirror of
https://github.com/caddyserver/caddy.git
synced 2025-02-07 22:02:00 +08:00
Implement custom cert selection policies; optimize matching for SNI
This commit is contained in:
parent
5a4a1421de
commit
210d0cf7f1
|
@ -2,8 +2,11 @@ package caddytls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"bitbucket.org/lightcodelabs/caddy2"
|
"bitbucket.org/lightcodelabs/caddy2"
|
||||||
"github.com/go-acme/lego/challenge/tlsalpn01"
|
"github.com/go-acme/lego/challenge/tlsalpn01"
|
||||||
|
@ -26,7 +29,7 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy2.Context) (*tls.Config, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading handshake matcher module '%s': %s", modName, err)
|
return nil, fmt.Errorf("loading handshake matcher module '%s': %s", modName, err)
|
||||||
}
|
}
|
||||||
cp[i].Matchers = append(cp[i].Matchers, val.(ConnectionMatcher))
|
cp[i].matchers = append(cp[i].matchers, val.(ConnectionMatcher))
|
||||||
}
|
}
|
||||||
cp[i].MatchersRaw = nil // allow GC to deallocate - TODO: Does this help?
|
cp[i].MatchersRaw = nil // allow GC to deallocate - TODO: Does this help?
|
||||||
}
|
}
|
||||||
|
@ -39,11 +42,34 @@ func (cp ConnectionPolicies) TLSConfig(ctx caddy2.Context) (*tls.Config, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// using ServerName to match policies is extremely common, especially in configs
|
||||||
|
// with lots and lots of different policies; we can fast-track those by indexing
|
||||||
|
// them by SNI, so we don't have to iterate potentially thousands of policies
|
||||||
|
indexedBySNI := make(map[string]ConnectionPolicies)
|
||||||
|
if len(cp) > 30 {
|
||||||
|
for _, p := range cp {
|
||||||
|
for _, m := range p.matchers {
|
||||||
|
if sni, ok := m.(MatchServerName); ok {
|
||||||
|
for _, sniName := range sni {
|
||||||
|
indexedBySNI[sniName] = append(indexedBySNI[sniName], p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &tls.Config{
|
return &tls.Config{
|
||||||
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
// filter policies by SNI first, if possible, to speed things up
|
||||||
|
// when there may be lots of policies
|
||||||
|
possiblePolicies := cp
|
||||||
|
if indexedPolicies, ok := indexedBySNI[hello.ServerName]; ok {
|
||||||
|
possiblePolicies = indexedPolicies
|
||||||
|
}
|
||||||
|
|
||||||
policyLoop:
|
policyLoop:
|
||||||
for _, pol := range cp {
|
for _, pol := range possiblePolicies {
|
||||||
for _, matcher := range pol.Matchers {
|
for _, matcher := range pol.matchers {
|
||||||
if !matcher.Match(hello) {
|
if !matcher.Match(hello) {
|
||||||
continue policyLoop
|
continue policyLoop
|
||||||
}
|
}
|
||||||
|
@ -65,16 +91,18 @@ type ConnectionPolicy struct {
|
||||||
ProtocolMin string `json:"protocol_min,omitempty"`
|
ProtocolMin string `json:"protocol_min,omitempty"`
|
||||||
ProtocolMax string `json:"protocol_max,omitempty"`
|
ProtocolMax string `json:"protocol_max,omitempty"`
|
||||||
|
|
||||||
|
CertSelection *CertSelectionPolicy `json:"certificate_selection,omitempty"`
|
||||||
|
|
||||||
// TODO: Client auth
|
// TODO: Client auth
|
||||||
|
|
||||||
// TODO: see if starlark could be useful here - enterprise only
|
// TODO: see if starlark could be useful here - enterprise only
|
||||||
StarlarkHandshake string `json:"starlark_handshake,omitempty"`
|
StarlarkHandshake string `json:"starlark_handshake,omitempty"`
|
||||||
|
|
||||||
Matchers []ConnectionMatcher
|
matchers []ConnectionMatcher
|
||||||
stdTLSConfig *tls.Config
|
stdTLSConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
func (p *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
tlsAppIface, err := ctx.App("tls")
|
tlsAppIface, err := ctx.App("tls")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("getting tls app: %v", err)
|
return fmt.Errorf("getting tls app: %v", err)
|
||||||
|
@ -82,17 +110,17 @@ func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
tlsApp := tlsAppIface.(*TLS)
|
tlsApp := tlsAppIface.(*TLS)
|
||||||
|
|
||||||
cfg := &tls.Config{
|
cfg := &tls.Config{
|
||||||
NextProtos: cp.ALPN,
|
NextProtos: p.ALPN,
|
||||||
PreferServerCipherSuites: true,
|
PreferServerCipherSuites: true,
|
||||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
// TODO: Must fix https://github.com/mholt/caddy/issues/2588
|
|
||||||
// (allow customizing the selection of a very specific certificate
|
|
||||||
// based on the ClientHelloInfo)
|
|
||||||
cfgTpl, err := tlsApp.getConfigForName(hello.ServerName)
|
cfgTpl, err := tlsApp.getConfigForName(hello.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting config for name %s: %v", hello.ServerName, err)
|
return nil, fmt.Errorf("getting config for name %s: %v", hello.ServerName, err)
|
||||||
}
|
}
|
||||||
newCfg := certmagic.New(tlsApp.certCache, cfgTpl)
|
newCfg := certmagic.New(tlsApp.certCache, cfgTpl)
|
||||||
|
if p.CertSelection != nil {
|
||||||
|
newCfg.CertSelector = makeCertSelector(p)
|
||||||
|
}
|
||||||
return newCfg.GetCertificate(hello)
|
return newCfg.GetCertificate(hello)
|
||||||
},
|
},
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
|
@ -102,7 +130,7 @@ func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
|
|
||||||
// add all the cipher suites in order, without duplicates
|
// add all the cipher suites in order, without duplicates
|
||||||
cipherSuitesAdded := make(map[uint16]struct{})
|
cipherSuitesAdded := make(map[uint16]struct{})
|
||||||
for _, csName := range cp.CipherSuites {
|
for _, csName := range p.CipherSuites {
|
||||||
csID := supportedCipherSuites[csName]
|
csID := supportedCipherSuites[csName]
|
||||||
if _, ok := cipherSuitesAdded[csID]; !ok {
|
if _, ok := cipherSuitesAdded[csID]; !ok {
|
||||||
cipherSuitesAdded[csID] = struct{}{}
|
cipherSuitesAdded[csID] = struct{}{}
|
||||||
|
@ -112,7 +140,7 @@ func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
|
|
||||||
// add all the curve preferences in order, without duplicates
|
// add all the curve preferences in order, without duplicates
|
||||||
curvesAdded := make(map[tls.CurveID]struct{})
|
curvesAdded := make(map[tls.CurveID]struct{})
|
||||||
for _, curveName := range cp.Curves {
|
for _, curveName := range p.Curves {
|
||||||
curveID := supportedCurves[curveName]
|
curveID := supportedCurves[curveName]
|
||||||
if _, ok := curvesAdded[curveID]; !ok {
|
if _, ok := curvesAdded[curveID]; !ok {
|
||||||
curvesAdded[curveID] = struct{}{}
|
curvesAdded[curveID] = struct{}{}
|
||||||
|
@ -122,7 +150,7 @@ func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
|
|
||||||
// ensure ALPN includes the ACME TLS-ALPN protocol
|
// ensure ALPN includes the ACME TLS-ALPN protocol
|
||||||
var alpnFound bool
|
var alpnFound bool
|
||||||
for _, a := range cp.ALPN {
|
for _, a := range p.ALPN {
|
||||||
if a == tlsalpn01.ACMETLS1Protocol {
|
if a == tlsalpn01.ACMETLS1Protocol {
|
||||||
alpnFound = true
|
alpnFound = true
|
||||||
break
|
break
|
||||||
|
@ -133,23 +161,76 @@ func (cp *ConnectionPolicy) buildStandardTLSConfig(ctx caddy2.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// min and max protocol versions
|
// min and max protocol versions
|
||||||
if cp.ProtocolMin != "" {
|
if p.ProtocolMin != "" {
|
||||||
cfg.MinVersion = supportedProtocols[cp.ProtocolMin]
|
cfg.MinVersion = supportedProtocols[p.ProtocolMin]
|
||||||
}
|
}
|
||||||
if cp.ProtocolMax != "" {
|
if p.ProtocolMax != "" {
|
||||||
cfg.MaxVersion = supportedProtocols[cp.ProtocolMax]
|
cfg.MaxVersion = supportedProtocols[p.ProtocolMax]
|
||||||
}
|
}
|
||||||
if cp.ProtocolMin > cp.ProtocolMax {
|
if p.ProtocolMin > p.ProtocolMax {
|
||||||
return fmt.Errorf("protocol min (%x) cannot be greater than protocol max (%x)", cp.ProtocolMin, cp.ProtocolMax)
|
return fmt.Errorf("protocol min (%x) cannot be greater than protocol max (%x)", p.ProtocolMin, p.ProtocolMax)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: client auth, and other fields
|
// TODO: client auth, and other fields
|
||||||
|
|
||||||
cp.stdTLSConfig = cfg
|
p.stdTLSConfig = cfg
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CertSelectionPolicy represents a policy for selecting the certificate
|
||||||
|
// used to complete a handshake when there may be multiple options. All
|
||||||
|
// fields specified must match the candidate certificate for it to be chosen.
|
||||||
|
// This was needed to solve https://github.com/mholt/caddy/issues/2588.
|
||||||
|
type CertSelectionPolicy struct {
|
||||||
|
SerialNumber *big.Int `json:"serial_number,omitempty"`
|
||||||
|
SubjectOrganization string `json:"subject.organization,omitempty"`
|
||||||
|
PublicKeyAlgorithm pkAlgorithm `json:"public_key_algorithm,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeCertSelector(p *ConnectionPolicy) func(*tls.ClientHelloInfo, []certmagic.Certificate) (certmagic.Certificate, error) {
|
||||||
|
return func(hello *tls.ClientHelloInfo, choices []certmagic.Certificate) (certmagic.Certificate, error) {
|
||||||
|
for _, cert := range choices {
|
||||||
|
var matchOrg bool
|
||||||
|
if p.CertSelection.SubjectOrganization != "" {
|
||||||
|
for _, org := range cert.Subject.Organization {
|
||||||
|
if p.CertSelection.SubjectOrganization == org {
|
||||||
|
matchOrg = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matchOrg {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.CertSelection.PublicKeyAlgorithm != pkAlgorithm(x509.UnknownPublicKeyAlgorithm) &&
|
||||||
|
pkAlgorithm(cert.PublicKeyAlgorithm) != p.CertSelection.PublicKeyAlgorithm {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.CertSelection.SerialNumber != nil &&
|
||||||
|
cert.SerialNumber.Cmp(p.CertSelection.SerialNumber) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
return certmagic.Certificate{}, fmt.Errorf("no certificates matched custom selection policy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type pkAlgorithm x509.PublicKeyAlgorithm
|
||||||
|
|
||||||
|
// UnmarshalJSON satisfies json.Unmarshaler.
|
||||||
|
func (a *pkAlgorithm) UnmarshalJSON(b []byte) error {
|
||||||
|
algoStr := strings.ToLower(strings.Trim(string(b), `"`))
|
||||||
|
algo, ok := publicKeyAlgorithms[algoStr]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unrecognized public key algorithm: %s (expected one of %v)",
|
||||||
|
algoStr, publicKeyAlgorithms)
|
||||||
|
}
|
||||||
|
a = &algo
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ConnectionMatcher is a type which matches TLS handshakes.
|
// ConnectionMatcher is a type which matches TLS handshakes.
|
||||||
type ConnectionMatcher interface {
|
type ConnectionMatcher interface {
|
||||||
Match(*tls.ClientHelloInfo) bool
|
Match(*tls.ClientHelloInfo) bool
|
||||||
|
|
|
@ -11,7 +11,7 @@ type MatchServerName []string
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
caddy2.RegisterModule(caddy2.Module{
|
caddy2.RegisterModule(caddy2.Module{
|
||||||
Name: "tls.handshake_match.host",
|
Name: "tls.handshake_match.sni",
|
||||||
New: func() interface{} { return MatchServerName{} },
|
New: func() interface{} { return MatchServerName{} },
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package caddytls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -316,4 +317,11 @@ var supportedProtocols = map[string]uint16{
|
||||||
"tls1.3": tls.VersionTLS13,
|
"tls1.3": tls.VersionTLS13,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// publicKeyAlgorithms is the map of supported public key algorithms.
|
||||||
|
var publicKeyAlgorithms = map[string]pkAlgorithm{
|
||||||
|
"rsa": pkAlgorithm(x509.RSA),
|
||||||
|
"dsa": pkAlgorithm(x509.DSA),
|
||||||
|
"ecdsa": pkAlgorithm(x509.ECDSA),
|
||||||
|
}
|
||||||
|
|
||||||
const automateKey = "automate"
|
const automateKey = "automate"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user