// caddy/caddyhttp/proxy/reverseproxy.go
// Leonard Hecker 9f9ad21aaa Fixed #1292: Failure to proxy WebSockets over HTTPS
// This issue was caused by connHijackerTransport trying to record HTTP
// response headers by "hijacking" the Read() method of the plain net.Conn.
// That does not work over TLS, because it records the TLS handshake and
// the encrypted data instead of the actual content.
// This commit fixes the problem by providing an alternative transport.DialTLS
// which correctly hijacks the overlying tls.Conn instead.
// 2016-12-26 20:52:36 +01:00
//
// 525 lines
// 14 KiB
// Go

// This file is adapted from code in the net/http/httputil
// package of the Go standard library, which is by the
// Go Authors, and bears this copyright and license info:
//
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file has been modified from the standard lib to
// meet the needs of the application.
package proxy
import (
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
"path"
"strings"
"sync"
"time"
"golang.org/x/net/http2"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// defaultDialer opens upstream TCP connections whenever a transport does
// not supply a Dial function of its own.
var defaultDialer = &net.Dialer{
	KeepAlive: 30 * time.Second,
	Timeout:   30 * time.Second,
}

// bufferPool recycles 32 KiB byte slices. They are used as copy buffers for
// response bodies and as the replay buffer of websocket upgrades.
var bufferPool = sync.Pool{New: createBuffer}

// createBuffer is bufferPool's constructor; it allocates one 32 KiB slice.
func createBuffer() interface{} {
	const size = 32 << 10
	return make([]byte, size)
}

// onExitFlushLoop is a hook that tests may install in order to observe the
// moment a flushLoop goroutine terminates.
var onExitFlushLoop func()
// ReverseProxy is an HTTP handler that forwards an incoming request to
// another server and relays that server's response back to the client.
type ReverseProxy struct {
	// Director rewrites the outgoing request in place before it is sent
	// through Transport.
	Director func(*http.Request)

	// Transport performs the proxied round trips. When it is nil,
	// http.DefaultTransport is used instead.
	Transport http.RoundTripper

	// FlushInterval is how often buffered response data is pushed to the
	// client while the response body is being copied. A zero value turns
	// periodic flushing off.
	FlushInterval time.Duration
}
// socketDial returns a dial function that ignores the network and address
// chosen by the HTTP transport and always connects to the unix socket
// named by hostName.
//
// The proxy directive prefix is just "unix:", but url.Parse treats "unix"
// like any other scheme, so a directive such as "unix:/var/run/www.socket"
// stringifies as "unix:///var/run/www.socket". Only the filesystem path is
// wanted, so the first len("unix://") bytes are cut off on every dial.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
	const urlPrefix = "unix://"
	dial := func(_, _ string) (net.Conn, error) {
		socketPath := hostName[len(urlPrefix):]
		return net.Dial("unix", socketPath)
	}
	return dial
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
// Without logic: target's path is "/", incoming is "/api/messages",
// without is "/api", then the target request will be for /messages.
//
// A non-empty target query string is prepended to the request's own
// query string, joined with "&" when both are present.
//
// For a "unix" target the returned proxy gets a transport that dials the
// socket path taken from target, and keepalive is ignored. For any other
// target, keepalive selects the transport:
//   - a value equal to http.DefaultMaxIdleConnsPerHost leaves Transport
//     nil, so ServeHTTP falls back to http.DefaultTransport;
//   - 0 creates a new transport with keep-alives disabled;
//   - any other value becomes MaxIdleConnsPerHost of a new transport.
// A newly created non-unix transport is also configured for HTTP/2 when
// httpserver.HTTP2 is set. The proxy flushes to the client every 250ms.
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
targetQuery := target.RawQuery
// director rewrites each outgoing request so that it points at target.
director := func(req *http.Request) {
if target.Scheme == "unix" {
// to make Dial work with unix URL,
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
}
// We should remove the `without` prefix at first.
// Path, Opaque and RawPath are each trimmed independently.
if without != "" {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
if req.URL.Opaque != "" {
req.URL.Opaque = strings.TrimPrefix(req.URL.Opaque, without)
}
if req.URL.RawPath != "" {
req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, without)
}
}
// NOTE(review): only Path is joined with target.Path below; a non-empty
// RawPath keeps its un-joined value. Confirm net/url then re-encodes
// from Path as intended.
hadTrailingSlash := strings.HasSuffix(req.URL.Path, "/")
req.URL.Path = path.Join(target.Path, req.URL.Path)
// path.Join will strip off the last /, so put it back if it was there.
if hadTrailingSlash && !strings.HasSuffix(req.URL.Path, "/") {
req.URL.Path = req.URL.Path + "/"
}
// Trims the path of the socket from the URL path.
// This is done because req.URL passed to your proxied service
// will have the full path of the socket file prefixed to it.
// Calling /test on a server that proxies requests to
// unix:/var/run/www.socket will thus set the requested path
// to /var/run/www.socket/test, rendering paths useless.
if target.Scheme == "unix" {
// See comment on socketDial for the trim
socketPrefix := target.String()[len("unix://"):]
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
}
// Merge target's query string with the request's.
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
if target.Scheme == "unix" {
// Every connection goes to the socket, regardless of the faked host.
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
// if keepalive is equal to the default,
// just use default transport, to avoid creating
// a brand new transport
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
if keepalive == 0 {
transport.DisableKeepAlives = true
} else {
transport.MaxIdleConnsPerHost = keepalive
}
// NOTE(review): the error returned by http2.ConfigureTransport is ignored.
if httpserver.HTTP2 {
http2.ConfigureTransport(transport)
}
rp.Transport = transport
}
return rp
}
// UseInsecureTransport is used to facilitate HTTPS proxying
// when it is OK for upstream to be using a bad certificate,
// since this transport skips verification.
//
// If rp has no transport yet, a new one is created with verification
// disabled. It is configured for HTTP/2 when httpserver.HTTP2 is set.
//
// If rp already uses an *http.Transport, verification is disabled on that
// transport's existing TLS client config rather than by replacing the
// config. Other settings already present on it are kept, such as the ALPN
// protocols added by http2.ConfigureTransport. Replacing the config would
// drop them and silently downgrade upstream connections to HTTP/1.1.
//
// Any other RoundTripper is left untouched.
func (rp *ReverseProxy) UseInsecureTransport() {
	if rp.Transport == nil {
		transport := &http.Transport{
			Proxy:               http.ProxyFromEnvironment,
			Dial:                defaultDialer.Dial,
			TLSHandshakeTimeout: 10 * time.Second,
			TLSClientConfig:     &tls.Config{InsecureSkipVerify: true},
		}
		if httpserver.HTTP2 {
			http2.ConfigureTransport(transport)
		}
		rp.Transport = transport
		return
	}

	if transport, ok := rp.Transport.(*http.Transport); ok {
		if transport.TLSClientConfig == nil {
			transport.TLSClientConfig = &tls.Config{}
		}
		// Mutate in place: this is setup-time configuration, before the
		// transport has been used for any request.
		transport.TLSClientConfig.InsecureSkipVerify = true
	}
}
// ServeHTTP serves the proxied request to the upstream by performing a roundtrip.
// It is designed to handle websocket connection upgrades as well.
//
// Before sending, outreq is rewritten by rp.Director and forced to HTTP/1.1.
// If respUpdateFn is non-nil, it is called with the upstream response
// before anything is written to rw.
//
// What happens next depends on the upstream response:
//   - A "101 Switching Protocols" websocket response hijacks rw. Bytes are
//     then relayed in both directions until either side stops.
//   - Any other response has its hop-by-hop headers removed and is copied
//     to rw.
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
	transport := rp.Transport
	if requestIsWebsocket(outreq) {
		transport = newConnHijackerTransport(transport)
	} else if transport == nil {
		transport = http.DefaultTransport
	}

	rp.Director(outreq)
	outreq.Proto = "HTTP/1.1"
	outreq.ProtoMajor = 1
	outreq.ProtoMinor = 1
	outreq.Close = false

	res, err := transport.RoundTrip(outreq)
	if err != nil {
		return err
	}

	if respUpdateFn != nil {
		respUpdateFn(res)
	}

	// hjt is non-nil when the upstream connection was recorded for a
	// possible websocket upgrade. Its hijackedConn wrapper ignores Close,
	// so the transport never closes that connection; we must do it here.
	hjt, _ := transport.(*connHijackerTransport)

	isWebsocketUpgrade := res.StatusCode == http.StatusSwitchingProtocols &&
		strings.ToLower(res.Header.Get("Upgrade")) == "websocket"

	if !isWebsocketUpgrade {
		if hjt != nil && hjt.Conn != nil {
			// Deferred before res.Body.Close so that it runs after it.
			defer hjt.Conn.Close()
		}
		defer res.Body.Close()
		for _, h := range hopHeaders {
			res.Header.Del(h)
		}
		copyHeader(rw.Header(), res.Header)
		rw.WriteHeader(res.StatusCode)
		rp.copyResponse(rw, res.Body)
		return nil
	}

	res.Body.Close()
	if hjt != nil && hjt.Conn != nil {
		// Covers every exit below, including a failed Hijack or replay.
		defer hjt.Conn.Close()
	}

	hj, ok := rw.(http.Hijacker)
	if !ok {
		panic(httpserver.NonHijackerError{Underlying: rw})
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		return err
	}
	defer conn.Close()

	var backendConn net.Conn
	if hjt != nil {
		backendConn = hjt.Conn
		// Forward the upstream's raw response bytes (everything read from
		// the connection so far) to the client.
		if _, err := conn.Write(hjt.Replay); err != nil {
			return err
		}
		bufferPool.Put(hjt.Replay)
	} else {
		backendConn, err = net.Dial("tcp", outreq.URL.Host)
		if err != nil {
			return err
		}
		defer backendConn.Close()
		if err := outreq.Write(backendConn); err != nil {
			return err
		}
	}

	// Relay in both directions and stop as soon as either side finishes.
	// The deferred Closes then unblock the other copier. Waiting only on
	// the upstream would pin this goroutine for as long as the upstream
	// kept open a connection that the client had already left.
	// errc has room for both results, so neither copier blocks after return.
	errc := make(chan error, 2)
	go func() {
		_, err := io.Copy(backendConn, conn) // write tcp stream to backend.
		errc <- err
	}()
	go func() {
		_, err := io.Copy(conn, backendConn) // read tcp stream from backend.
		errc <- err
	}()
	<-errc
	return nil
}
// copyResponse streams src into dst using a buffer borrowed from bufferPool.
// When rp.FlushInterval is non-zero and dst can flush, writes go through a
// maxLatencyWriter whose flush loop runs for the duration of the copy. The
// loop is stopped before the buffer is returned to the pool.
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
	pooled := bufferPool.Get().([]byte)
	defer bufferPool.Put(pooled)

	out := dst
	if wf, flushable := dst.(writeFlusher); flushable && rp.FlushInterval != 0 {
		latencyWriter := &maxLatencyWriter{
			dst:     wf,
			latency: rp.FlushInterval,
			done:    make(chan bool),
		}
		go latencyWriter.flushLoop()
		defer latencyWriter.stop()
		out = latencyWriter
	}

	// Pooled buffers may come back re-sliced to a shorter length (the
	// websocket replay buffer starts at length 0). io.CopyBuffer uses only
	// len(buf) and panics on an empty buffer, so copy with the full capacity.
	io.CopyBuffer(out, src, pooled[:cap(pooled)])
}
// skipHeaders names response headers that copyHeader does not overwrite when
// the destination already carries them, so values already present in the
// destination take precedence over the upstream's.
// See https://github.com/mholt/caddy/pull/1112#discussion_r80092582
var skipHeaders = map[string]struct{}{
	"Content-Type":        {},
	"Content-Disposition": {},
	"Accept-Ranges":       {},
	"Set-Cookie":          {},
	"Cache-Control":       {},
	"Expires":             {},
}

// copyHeader copies every header of src into dst.
//
// A key that dst already has is handled in one of two ways. If the key is
// listed in skipHeaders, dst's values are kept and src's are ignored.
// Otherwise dst's values are replaced by src's.
// See https://github.com/mholt/caddy/issues/1086
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		_, present := dst[key]
		_, protected := skipHeaders[key]
		switch {
		case present && protected:
			continue
		case present:
			dst.Del(key)
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}
// Hop-by-hop headers. These are removed when sent to the backend, and
// ServeHTTP also strips them from upstream responses before relaying them.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
//
// RFC 2616 section 13.5.1 lists "Trailers", but the actual header is named
// "Trailer" (RFC 2616 errata 4522, RFC 7230 section 4.4). Both are listed
// so that existing behavior is preserved.
//
// "Alternate-Protocol" and "Alt-Svc" are not hop-by-hop headers. They are
// presumably listed so that the upstream's alternative-service
// advertisements do not leak through the proxy.
var hopHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",       // canonicalized version of "TE"
	"Trailer",  // correct name of the header; see errata note above
	"Trailers", // as spelled in RFC 2616, kept for compatibility
	"Transfer-Encoding",
	"Upgrade",
	"Alternate-Protocol",
	"Alt-Svc",
}
// respUpdateFn is a callback that receives the upstream response so it can
// be inspected or adjusted before it is relayed to the client.
type respUpdateFn func(resp *http.Response)

// hijackedConn wraps a connection dialed by a connHijackerTransport.
// Every byte read through it is appended to the transport's Replay buffer.
// Close is deliberately a no-op, so the http.Transport cannot tear the
// connection down. Its owner is expected to close the underlying
// connection through connHijackerTransport.Conn instead.
type hijackedConn struct {
	net.Conn
	hj *connHijackerTransport
}

// Read reads from the wrapped connection and records what was read. This
// includes partial data returned together with an error.
func (c *hijackedConn) Read(p []byte) (int, error) {
	read, err := c.Conn.Read(p)
	recorded := p[:read]
	c.hj.Replay = append(c.hj.Replay, recorded...)
	return read, err
}

// Close does nothing and reports success; see hijackedConn.
func (*hijackedConn) Close() error { return nil }
// connHijackerTransport is an http.Transport used for a single websocket
// upgrade attempt. Conn holds the upstream connection it dialed, and Replay
// accumulates every byte read from that connection. After a successful
// upgrade, the raw upstream response can be replayed to the client and the
// connection taken over.
type connHijackerTransport struct {
	*http.Transport
	Conn   net.Conn
	Replay []byte
}

// newConnHijackerTransport creates a connHijackerTransport.
//
// When base is an *http.Transport, the new transport inherits its Proxy,
// TLSHandshakeTimeout, Dial and DialTLS. It also gets a clone of base's
// TLS client config with NextProtos cleared, so no ALPN protocol such as h2
// is offered. Otherwise http.ProxyFromEnvironment and a 10s handshake
// timeout are used.
//
// Dial, and DialTLS when a TLS dialer is available, are wrapped so that a
// successfully dialed connection is stored in Conn. Reads from it then go
// through a hijackedConn. Replay starts empty, backed by a bufferPool buffer.
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
	t := &http.Transport{
		// A negative limit keeps net/http from holding on to the
		// connection for reuse; it must stay exclusive to this request.
		MaxIdleConnsPerHost: -1,
	}
	if b, _ := base.(*http.Transport); b != nil {
		t.Proxy = b.Proxy
		t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
		t.TLSClientConfig.NextProtos = nil
		t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
		t.Dial = b.Dial
		t.DialTLS = b.DialTLS
	} else {
		t.Proxy = http.ProxyFromEnvironment
		t.TLSHandshakeTimeout = 10 * time.Second
	}

	hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}

	// record stores a successfully dialed connection and wraps it. On
	// failure it returns a nil conn (per the Dial contract) and leaves
	// hj.Conn untouched.
	record := func(c net.Conn, err error) (net.Conn, error) {
		if err != nil {
			return nil, err
		}
		hj.Conn = c
		return &hijackedConn{c, hj}, nil
	}

	// Resolve the underlying dialers before t.Dial/t.DialTLS are replaced.
	dial := getTransportDial(t)
	dialTLS := getTransportDialTLS(t)
	t.Dial = func(network, addr string) (net.Conn, error) {
		return record(dial(network, addr))
	}
	if dialTLS != nil {
		t.DialTLS = func(network, addr string) (net.Conn, error) {
			return record(dialTLS(network, addr))
		}
	}
	return hj
}
// getTransportDial returns t.Dial when it is set, and defaultDialer.Dial
// otherwise. Either way the result dials plain, unencrypted connections.
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
	if t.Dial != nil {
		return t.Dial
	}
	return defaultDialer.Dial
}

// getTransportDialTLS returns t.DialTLS when it is set. Otherwise it
// returns a dial function that works in two steps. First it opens a plain
// connection with the dialer getTransportDial(t) yields at call time. Then
// it performs the TLS handshake itself, following what net/http's
// Transport does internally when DialTLS is nil. Handing out the *tls.Conn
// lets wrappers such as hijackedConn see decrypted bytes instead of TLS
// records.
//
// A nil t.TLSClientConfig is treated as an empty config, so a TLS dialer is
// always returned. The config is cloned for each connection. When it names
// no server, the host part of addr becomes ServerName. crypto/tls refuses
// client configs that set neither ServerName nor InsecureSkipVerify, and
// ServerName is also the name the certificate is verified against.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
	if t.DialTLS != nil {
		return t.DialTLS
	}

	// newConnHijackerTransport replaces t.Dial after calling this function,
	// so resolve the plain dialer now rather than inside the closure.
	plainDial := getTransportDial(t)

	return func(network, addr string) (net.Conn, error) {
		plainConn, err := plainDial(network, addr)
		if err != nil {
			return nil, err
		}

		cfg := cloneTLSClientConfig(t.TLSClientConfig)
		if cfg.ServerName == "" {
			host, _, splitErr := net.SplitHostPort(addr)
			if splitErr != nil {
				host = addr // addr carries no port
			}
			cfg.ServerName = host
		}

		tlsConn := tls.Client(plainConn, cfg)
		// Room for both the handshake goroutine and the timer, so that
		// whichever loses the race never blocks.
		errc := make(chan error, 2)
		var timer *time.Timer
		if d := t.TLSHandshakeTimeout; d != 0 {
			timer = time.AfterFunc(d, func() {
				errc <- tlsHandshakeTimeoutError{}
			})
		}
		go func() {
			err := tlsConn.Handshake()
			if timer != nil {
				timer.Stop()
			}
			errc <- err
		}()
		if err := <-errc; err != nil {
			plainConn.Close()
			return nil, err
		}

		if !cfg.InsecureSkipVerify {
			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
				plainConn.Close()
				return nil, err
			}
		}
		return tlsConn, nil
	}
}

// tlsHandshakeTimeoutError is returned when a TLS handshake does not finish
// within the transport's TLSHandshakeTimeout. It reports itself as both a
// timeout and a temporary error.
type tlsHandshakeTimeoutError struct{}

func (tlsHandshakeTimeoutError) Timeout() bool   { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string   { return "net/http: TLS handshake timeout" }

// cloneTLSClientConfig returns a copy of cfg that contains exactly the
// fields listed below, or an empty config when cfg is nil. It deliberately
// omits the fields SessionTicketsDisabled and SessionTicketKey, which makes
// it safe to call on a config in active use by a server.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
	if cfg == nil {
		return &tls.Config{}
	}
	return &tls.Config{
		Rand:                        cfg.Rand,
		Time:                        cfg.Time,
		Certificates:                cfg.Certificates,
		NameToCertificate:           cfg.NameToCertificate,
		GetCertificate:              cfg.GetCertificate,
		RootCAs:                     cfg.RootCAs,
		NextProtos:                  cfg.NextProtos,
		ServerName:                  cfg.ServerName,
		ClientAuth:                  cfg.ClientAuth,
		ClientCAs:                   cfg.ClientCAs,
		InsecureSkipVerify:          cfg.InsecureSkipVerify,
		CipherSuites:                cfg.CipherSuites,
		PreferServerCipherSuites:    cfg.PreferServerCipherSuites,
		ClientSessionCache:          cfg.ClientSessionCache,
		MinVersion:                  cfg.MinVersion,
		MaxVersion:                  cfg.MaxVersion,
		CurvePreferences:            cfg.CurvePreferences,
		DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
		Renegotiation:               cfg.Renegotiation,
	}
}
// requestIsWebsocket reports whether req asks for a websocket upgrade. Its
// Upgrade header must equal "websocket", and its Connection header must
// contain "upgrade". Both checks are case-insensitive.
func requestIsWebsocket(req *http.Request) bool {
	upgrade := strings.ToLower(req.Header.Get("Upgrade"))
	if upgrade != "websocket" {
		return false
	}
	connection := strings.ToLower(req.Header.Get("Connection"))
	return strings.Contains(connection, "upgrade")
}
// writeFlusher is implemented by writers that can also push buffered
// output out on demand.
type writeFlusher interface {
	io.Writer
	http.Flusher
}

// maxLatencyWriter serializes writes to dst. While its flushLoop goroutine
// is running, it also flushes dst once every latency period.
type maxLatencyWriter struct {
	dst     writeFlusher
	latency time.Duration

	lk   sync.Mutex // held during every Write and Flush on dst
	done chan bool  // carries the stop signal to flushLoop
}

// Write forwards p to dst while holding lk.
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
	m.lk.Lock()
	defer m.lk.Unlock()
	n, err = m.dst.Write(p)
	return n, err
}

// flushLoop flushes dst on every tick of a latency-period ticker until a
// value arrives on done. On exit it runs the onExitFlushLoop hook, if one
// is installed.
func (m *maxLatencyWriter) flushLoop() {
	ticker := time.NewTicker(m.latency)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			m.lk.Lock()
			m.dst.Flush()
			m.lk.Unlock()
		case <-m.done:
			if hook := onExitFlushLoop; hook != nil {
				hook()
			}
			return
		}
	}
}

// stop sends the stop signal to flushLoop. If done is unbuffered, this
// blocks until flushLoop receives the signal.
func (m *maxLatencyWriter) stop() {
	m.done <- true
}