mirror of
https://github.com/caddyserver/caddy.git
synced 2024-11-26 02:09:47 +08:00
Merge pull request #757 from mholt/extend-tls-client-auth
Extend tls client auth
This commit is contained in:
commit
ddf4b1fd3b
|
@ -83,10 +83,30 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
|
|||
c.TLS.Ciphers = append(c.TLS.Ciphers, value)
|
||||
}
|
||||
case "clients":
|
||||
c.TLS.ClientCerts = c.RemainingArgs()
|
||||
if len(c.TLS.ClientCerts) == 0 {
|
||||
clientCertList := c.RemainingArgs()
|
||||
if len(clientCertList) == 0 {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
|
||||
listStart, mustProvideCA := 1, true
|
||||
switch clientCertList[0] {
|
||||
case "request":
|
||||
c.TLS.ClientAuth = tls.RequestClientCert
|
||||
mustProvideCA = false
|
||||
case "require":
|
||||
c.TLS.ClientAuth = tls.RequireAnyClientCert
|
||||
mustProvideCA = false
|
||||
case "verify_if_given":
|
||||
c.TLS.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
default:
|
||||
c.TLS.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
listStart = 0
|
||||
}
|
||||
if mustProvideCA && len(clientCertList) <= listStart {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
|
||||
c.TLS.ClientCerts = clientCertList[listStart:]
|
||||
case "load":
|
||||
c.Args(&loadDir)
|
||||
c.TLS.Manual = true
|
||||
|
|
|
@ -189,34 +189,69 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSetupParseWithClientAuth(t *testing.T) {
|
||||
// Test missing client cert file
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients client_ca.crt client2_ca.crt
|
||||
clients
|
||||
}`
|
||||
c := setup.NewTestController(params)
|
||||
_, err := Setup(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
if count := len(c.TLS.ClientCerts); count != 2 {
|
||||
t.Fatalf("Expected two client certs, had %d", count)
|
||||
}
|
||||
if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
|
||||
t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
|
||||
}
|
||||
if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
|
||||
t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
|
||||
}
|
||||
|
||||
// Test missing client cert file
|
||||
params = `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients
|
||||
}`
|
||||
c = setup.NewTestController(params)
|
||||
_, err = Setup(c)
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, but no error returned")
|
||||
}
|
||||
|
||||
noCAs, twoCAs := []string{}, []string{"client_ca.crt", "client2_ca.crt"}
|
||||
for caseNumber, caseData := range []struct {
|
||||
params string
|
||||
clientAuthType tls.ClientAuthType
|
||||
expectedErr bool
|
||||
expectedCAs []string
|
||||
}{
|
||||
{"", tls.NoClientCert, false, noCAs},
|
||||
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients client_ca.crt client2_ca.crt
|
||||
}`, tls.RequireAndVerifyClientCert, false, twoCAs},
|
||||
// now come modifier
|
||||
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients request
|
||||
}`, tls.RequestClientCert, false, noCAs},
|
||||
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients require
|
||||
}`, tls.RequireAnyClientCert, false, noCAs},
|
||||
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients verify_if_given client_ca.crt client2_ca.crt
|
||||
}`, tls.VerifyClientCertIfGiven, false, twoCAs},
|
||||
{`tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients verify_if_given
|
||||
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
||||
} {
|
||||
c := setup.NewTestController(caseData.params)
|
||||
_, err := Setup(c)
|
||||
if caseData.expectedErr {
|
||||
if err == nil {
|
||||
t.Errorf("In case %d: Expected an error, got: %v", caseNumber, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err)
|
||||
}
|
||||
|
||||
if caseData.clientAuthType != c.TLS.ClientAuth {
|
||||
t.Errorf("In case %d: Expected TLS client auth type %v, got: %v",
|
||||
caseNumber, caseData.clientAuthType, c.TLS.ClientAuth)
|
||||
}
|
||||
|
||||
if count := len(c.TLS.ClientCerts); count < len(caseData.expectedCAs) {
|
||||
t.Fatalf("In case %d: Expected %d client certs, had %d", caseNumber, len(caseData.expectedCAs), count)
|
||||
}
|
||||
|
||||
for idx, expected := range caseData.expectedCAs {
|
||||
if actual := c.TLS.ClientCerts[idx]; actual != expected {
|
||||
t.Errorf("In case %d: Expected %dth client cert file to be '%s', but was '%s'",
|
||||
caseNumber, idx, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupParseWithKeyType(t *testing.T) {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/mholt/caddy/middleware"
|
||||
|
@ -75,4 +76,5 @@ type TLSConfig struct {
|
|||
ProtocolMaxVersion uint16
|
||||
PreferServerCipherSuites bool
|
||||
ClientCerts []string
|
||||
ClientAuth tls.ClientAuthType
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -332,6 +333,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// Use URL.RawPath If you need the original, "raw" URL.Path in your middleware.
|
||||
// Collapse any ./ ../ /// madness here instead of doing that in every plugin.
|
||||
if r.URL.Path != "/" {
|
||||
path := filepath.Clean(r.URL.Path)
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
r.URL.Path = path
|
||||
}
|
||||
|
||||
// Execute the optional request callback if it exists and it's not disabled
|
||||
if s.ReqCallback != nil && !s.vhosts[host].config.TLS.Manual && s.ReqCallback(w, r) {
|
||||
return
|
||||
|
@ -368,17 +379,19 @@ func DefaultErrorFunc(w http.ResponseWriter, r *http.Request, status int) {
|
|||
// setupClientAuth sets up TLS client authentication only if
|
||||
// any of the TLS configs specified at least one cert file.
|
||||
func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
|
||||
var clientAuth bool
|
||||
whatClientAuth := tls.NoClientCert
|
||||
for _, cfg := range tlsConfigs {
|
||||
if len(cfg.ClientCerts) > 0 {
|
||||
clientAuth = true
|
||||
break
|
||||
if whatClientAuth < cfg.ClientAuth { // Use the most restrictive.
|
||||
whatClientAuth = cfg.ClientAuth
|
||||
}
|
||||
}
|
||||
|
||||
if clientAuth {
|
||||
if whatClientAuth != tls.NoClientCert {
|
||||
pool := x509.NewCertPool()
|
||||
for _, cfg := range tlsConfigs {
|
||||
if len(cfg.ClientCerts) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, caFile := range cfg.ClientCerts {
|
||||
caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect
|
||||
if err != nil {
|
||||
|
@ -390,7 +403,7 @@ func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
|
|||
}
|
||||
}
|
||||
config.ClientCAs = pool
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
config.ClientAuth = whatClientAuth
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue
Block a user