mirror of
https://github.com/caddyserver/caddy.git
synced 2024-11-25 17:56:34 +08:00
proxy: Fixed support for TLS verification of WebSocket connections
This commit is contained in:
parent
153d4a5ac6
commit
b857265f9c
|
@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
|||
MaxIdleConnsPerHost: -1,
|
||||
}
|
||||
if b, _ := base.(*http.Transport); b != nil {
|
||||
tlsClientConfig := b.TLSClientConfig
|
||||
if tlsClientConfig.NextProtos != nil {
|
||||
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||
tlsClientConfig.NextProtos = nil
|
||||
}
|
||||
|
||||
t.Proxy = b.Proxy
|
||||
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
|
||||
t.TLSClientConfig.NextProtos = nil
|
||||
t.TLSClientConfig = tlsClientConfig
|
||||
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||
t.Dial = b.Dial
|
||||
t.DialTLS = b.DialTLS
|
||||
|
@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
|||
|
||||
dial := getTransportDial(t)
|
||||
dialTLS := getTransportDialTLS(t)
|
||||
|
||||
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dial(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
|
||||
if dialTLS != nil {
|
||||
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dialTLS(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||
c, err := dialTLS(network, addr)
|
||||
hj.Conn = c
|
||||
return &hijackedConn{c, hj}, err
|
||||
}
|
||||
|
||||
return hj
|
||||
|
@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
|
|||
return defaultDialer.Dial
|
||||
}
|
||||
|
||||
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
|
||||
// getTransportDial always returns a TLS Dialer
|
||||
// and defaults to the existing t.DialTLS.
|
||||
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||
if t.DialTLS != nil {
|
||||
return t.DialTLS
|
||||
}
|
||||
if t.TLSClientConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||
// => Create a backup reference.
|
||||
plainDial := getTransportDial(t)
|
||||
|
||||
// The following DialTLS implementation stems from the Go stdlib and
|
||||
// is identical to what happens if DialTLS is not provided.
|
||||
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||
return func(network, addr string) (net.Conn, error) {
|
||||
plainConn, err := plainDial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
|
||||
tlsClientConfig := t.TLSClientConfig
|
||||
if tlsClientConfig == nil {
|
||||
tlsClientConfig = &tls.Config{}
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||
tlsClientConfig.ServerName = stripPort(addr)
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||
errc := make(chan error, 2)
|
||||
var timer *time.Timer
|
||||
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||
|
@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
|
|||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !t.TLSClientConfig.InsecureSkipVerify {
|
||||
serverName := t.TLSClientConfig.ServerName
|
||||
if serverName == "" {
|
||||
serverName = addr
|
||||
idx := strings.LastIndex(serverName, ":")
|
||||
if idx != -1 {
|
||||
serverName = serverName[:idx]
|
||||
}
|
||||
if !tlsClientConfig.InsecureSkipVerify {
|
||||
hostname := tlsClientConfig.ServerName
|
||||
if hostname == "" {
|
||||
hostname = stripPort(addr)
|
||||
}
|
||||
if err := tlsConn.VerifyHostname(serverName); err != nil {
|
||||
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
|
|||
}
|
||||
}
|
||||
|
||||
// stripPort returns address without its port if it has one and
|
||||
// works with IP addresses as well as hostnames formatted as host:port.
|
||||
//
|
||||
// IPv6 addresses (excluding the port) must be enclosed in
|
||||
// square brackets similar to the requirements of Go's stdlib.
|
||||
func stripPort(address string) string {
|
||||
// Keep in mind that the address might be a IPv6 address
|
||||
// and thus contain a colon, but not have a port.
|
||||
portIdx := strings.LastIndex(address, ":")
|
||||
ipv6Idx := strings.LastIndex(address, "]")
|
||||
if portIdx > ipv6Idx {
|
||||
address = address[:portIdx]
|
||||
}
|
||||
return address
|
||||
}
|
||||
|
||||
type tlsHandshakeTimeoutError struct{}
|
||||
|
||||
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||
|
|
Loading…
Reference in New Issue
Block a user