mirror of
https://github.com/caddyserver/caddy.git
synced 2024-11-26 02:09:47 +08:00
9f9ad21aaa
This issue was caused by connHijackerTransport trying to record HTTP response headers by "hijacking" the Read() method of the plain net.Conn. This does not simply work over TLS though since this will record the TLS handshake and encrypted data instead of the actual content. This commit fixes the problem by providing an alternative transport.DialTLS which correctly hijacks the overlying tls.Conn instead.
525 lines
14 KiB
Go
525 lines
14 KiB
Go
// This file is adapted from code in the net/http/httputil
|
|
// package of the Go standard library, which is by the
|
|
// Go Authors, and bears this copyright and license info:
|
|
//
|
|
// Copyright 2011 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
//
|
|
// This file has been modified from the standard lib to
|
|
// meet the needs of the application.
|
|
|
|
package proxy
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
|
)
|
|
|
|
var (
	// defaultDialer is the TCP dialer used by the transports created in
	// this file: connection attempts time out after 30s and established
	// connections use a 30s TCP keep-alive setting.
	defaultDialer = &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	// bufferPool hands out the 32 KiB scratch buffers used to copy
	// response bodies and to record websocket handshakes.
	bufferPool = sync.Pool{New: createBuffer}

	// onExitFlushLoop is a callback set by tests to detect the state of the
	// flushLoop() goroutine.
	onExitFlushLoop func()
)

// createBuffer is bufferPool's constructor; it allocates a zeroed 32 KiB slice.
func createBuffer() interface{} {
	const bufferSize = 32 << 10 // 32 KiB
	return make([]byte, bufferSize)
}
|
|
|
|
// ReverseProxy forwards an incoming request to another server and relays
// the upstream response back to the client. Note that its ServeHTTP method
// takes an extra response hook and returns an error, so it is not an
// http.Handler itself.
type ReverseProxy struct {
	// Director rewrites the incoming request in place into the request
	// that is sent upstream through Transport. The upstream response is
	// then relayed back to the original client.
	Director func(*http.Request)

	// Transport performs the proxied round trips. When nil,
	// http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval is how often buffered response data is flushed to the
	// client while the response body is being copied. Zero disables
	// periodic flushing.
	FlushInterval time.Duration
}
|
|
|
|
// socketDial returns a dial function that ignores the requested network and
// address and always connects to the unix socket named by hostName.
//
// Though the relevant directive prefix is just "unix:", the string form of
// the parsed URL looks as if "unix" were a request protocol with a host
// part: a "unix:/var/run/www.socket" directive is passed in as
// "unix:///var/run/www.socket". Only the path is wanted, so the length of
// "unix://" is cut off the front, leaving "/var/run/www.socket".
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
	const schemePrefix = "unix://"
	dialUnix := func(_, _ string) (net.Conn, error) {
		socketPath := hostName[len(schemePrefix):]
		return net.Dial("unix", socketPath)
	}
	return dialUnix
}
|
|
|
|
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
|
// URLs to the scheme, host, and base path provided in target. If the
|
|
// target's path is "/base" and the incoming request was for "/dir",
|
|
// the target request will be for /base/dir.
|
|
// Without logic: target's path is "/", incoming is "/api/messages",
|
|
// without is "/api", then the target request will be for /messages.
|
|
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
|
|
targetQuery := target.RawQuery
|
|
director := func(req *http.Request) {
|
|
if target.Scheme == "unix" {
|
|
// to make Dial work with unix URL,
|
|
// scheme and host have to be faked
|
|
req.URL.Scheme = "http"
|
|
req.URL.Host = "socket"
|
|
} else {
|
|
req.URL.Scheme = target.Scheme
|
|
req.URL.Host = target.Host
|
|
}
|
|
|
|
// We should remove the `without` prefix at first.
|
|
if without != "" {
|
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
|
|
if req.URL.Opaque != "" {
|
|
req.URL.Opaque = strings.TrimPrefix(req.URL.Opaque, without)
|
|
}
|
|
if req.URL.RawPath != "" {
|
|
req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, without)
|
|
}
|
|
}
|
|
|
|
hadTrailingSlash := strings.HasSuffix(req.URL.Path, "/")
|
|
req.URL.Path = path.Join(target.Path, req.URL.Path)
|
|
// path.Join will strip off the last /, so put it back if it was there.
|
|
if hadTrailingSlash && !strings.HasSuffix(req.URL.Path, "/") {
|
|
req.URL.Path = req.URL.Path + "/"
|
|
}
|
|
|
|
// Trims the path of the socket from the URL path.
|
|
// This is done because req.URL passed to your proxied service
|
|
// will have the full path of the socket file prefixed to it.
|
|
// Calling /test on a server that proxies requests to
|
|
// unix:/var/run/www.socket will thus set the requested path
|
|
// to /var/run/www.socket/test, rendering paths useless.
|
|
if target.Scheme == "unix" {
|
|
// See comment on socketDial for the trim
|
|
socketPrefix := target.String()[len("unix://"):]
|
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
|
|
}
|
|
|
|
if targetQuery == "" || req.URL.RawQuery == "" {
|
|
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
|
} else {
|
|
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
|
}
|
|
}
|
|
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
|
|
if target.Scheme == "unix" {
|
|
rp.Transport = &http.Transport{
|
|
Dial: socketDial(target.String()),
|
|
}
|
|
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
|
|
// if keepalive is equal to the default,
|
|
// just use default transport, to avoid creating
|
|
// a brand new transport
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
Dial: defaultDialer.Dial,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
if keepalive == 0 {
|
|
transport.DisableKeepAlives = true
|
|
} else {
|
|
transport.MaxIdleConnsPerHost = keepalive
|
|
}
|
|
if httpserver.HTTP2 {
|
|
http2.ConfigureTransport(transport)
|
|
}
|
|
rp.Transport = transport
|
|
}
|
|
return rp
|
|
}
|
|
|
|
// UseInsecureTransport is used to facilitate HTTPS proxying
|
|
// when it is OK for upstream to be using a bad certificate,
|
|
// since this transport skips verification.
|
|
func (rp *ReverseProxy) UseInsecureTransport() {
|
|
if rp.Transport == nil {
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
Dial: defaultDialer.Dial,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
if httpserver.HTTP2 {
|
|
http2.ConfigureTransport(transport)
|
|
}
|
|
rp.Transport = transport
|
|
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
|
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
// No http2.ConfigureTransport() here.
|
|
// For now this is only added in places where
|
|
// an http.Transport is actually created.
|
|
}
|
|
}
|
|
|
|
// ServeHTTP serves the proxied request to the upstream by performing a roundtrip.
|
|
// It is designed to handle websocket connection upgrades as well.
|
|
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
|
transport := rp.Transport
|
|
if requestIsWebsocket(outreq) {
|
|
transport = newConnHijackerTransport(transport)
|
|
} else if transport == nil {
|
|
transport = http.DefaultTransport
|
|
}
|
|
|
|
rp.Director(outreq)
|
|
outreq.Proto = "HTTP/1.1"
|
|
outreq.ProtoMajor = 1
|
|
outreq.ProtoMinor = 1
|
|
outreq.Close = false
|
|
|
|
res, err := transport.RoundTrip(outreq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if respUpdateFn != nil {
|
|
respUpdateFn(res)
|
|
}
|
|
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
|
res.Body.Close()
|
|
hj, ok := rw.(http.Hijacker)
|
|
if !ok {
|
|
panic(httpserver.NonHijackerError{Underlying: rw})
|
|
}
|
|
|
|
conn, _, err := hj.Hijack()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer conn.Close()
|
|
|
|
var backendConn net.Conn
|
|
if hj, ok := transport.(*connHijackerTransport); ok {
|
|
backendConn = hj.Conn
|
|
if _, err := conn.Write(hj.Replay); err != nil {
|
|
return err
|
|
}
|
|
bufferPool.Put(hj.Replay)
|
|
} else {
|
|
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
outreq.Write(backendConn)
|
|
}
|
|
defer backendConn.Close()
|
|
|
|
go func() {
|
|
io.Copy(backendConn, conn) // write tcp stream to backend.
|
|
}()
|
|
io.Copy(conn, backendConn) // read tcp stream from backend.
|
|
} else {
|
|
defer res.Body.Close()
|
|
for _, h := range hopHeaders {
|
|
res.Header.Del(h)
|
|
}
|
|
copyHeader(rw.Header(), res.Header)
|
|
rw.WriteHeader(res.StatusCode)
|
|
rp.copyResponse(rw, res.Body)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
|
buf := bufferPool.Get().([]byte)
|
|
defer bufferPool.Put(buf)
|
|
|
|
if rp.FlushInterval != 0 {
|
|
if wf, ok := dst.(writeFlusher); ok {
|
|
mlw := &maxLatencyWriter{
|
|
dst: wf,
|
|
latency: rp.FlushInterval,
|
|
done: make(chan bool),
|
|
}
|
|
go mlw.flushLoop()
|
|
defer mlw.stop()
|
|
dst = mlw
|
|
}
|
|
}
|
|
|
|
// `CopyBuffer` only uses `buf` up to it's length and
|
|
// panics if it's 0 => Extend it's length up to it's capacity.
|
|
io.CopyBuffer(dst, src, buf[:cap(buf)])
|
|
}
|
|
|
|
// skipHeaders lists headers that copyHeader leaves untouched when the
// destination already carries a value for them, so those earlier values win
// over the upstream's.
// See https://github.com/mholt/caddy/pull/1112#discussion_r80092582
var skipHeaders = map[string]struct{}{
	"Accept-Ranges":       {},
	"Cache-Control":       {},
	"Content-Disposition": {},
	"Content-Type":        {},
	"Expires":             {},
	"Set-Cookie":          {},
}

// copyHeader merges the headers of src into dst. For each key in src:
//   - present in dst and listed in skipHeaders: dst keeps its values;
//   - present in dst otherwise: dst's values are replaced by src's;
//   - absent from dst: src's values are added.
//
// See https://github.com/mholt/caddy/issues/1086.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		_, exists := dst[name]
		_, protected := skipHeaders[name]
		if exists && protected {
			continue
		}
		if exists {
			dst.Del(name)
		}
		for _, value := range values {
			dst.Add(name, value)
		}
	}
}
|
|
|
|
// hopHeaders lists headers that describe a single transport hop rather than
// the end-to-end message, so the proxy must not pass them on. In this file
// they are stripped from the upstream response before it is copied to the
// client.
//
// The list follows RFC 7230 section 6.1 and RFC 2616 section 13.5.1. RFC
// 2616 misspells "Trailer" as "Trailers" (erratum 4522); with the misspelled
// name the real Trailer header was never removed. Alternate-Protocol and
// Alt-Svc are stripped as well, presumably so the upstream's
// alternative-service advertisements are not passed on to clients.
var hopHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Proxy-Connection", // non-standard, but still sent by some clients (e.g. libcurl)
	"Te",               // canonicalized version of "TE"
	"Trailer",          // not "Trailers"; see RFC 2616 erratum 4522
	"Transfer-Encoding",
	"Upgrade",
	"Alternate-Protocol",
	"Alt-Svc",
}
|
|
|
|
// respUpdateFn is an optional hook that ServeHTTP invokes on the upstream
// response right after the round trip, before the response is relayed to
// the client or a websocket upgrade is handled.
type respUpdateFn func(*http.Response)
|
|
|
|
// hijackedConn wraps a backend connection dialed by connHijackerTransport.
// Every byte read through it is appended to hj.Replay, and its Close does
// nothing, so closing the wrapper (e.g. from within net/http) does not tear
// down the underlying connection.
type hijackedConn struct {
	net.Conn
	hj *connHijackerTransport
}

// Read reads from the wrapped connection and records whatever was read.
func (hc *hijackedConn) Read(p []byte) (int, error) {
	n, err := hc.Conn.Read(p)
	hc.hj.Replay = append(hc.hj.Replay, p[:n]...)
	return n, err
}

// Close intentionally leaves the wrapped connection open and reports
// success; whoever holds connHijackerTransport.Conn must close it.
func (hc *hijackedConn) Close() error { return nil }

// connHijackerTransport is an http.Transport that remembers the connection
// it dialed (Conn) and the bytes read from that connection (Replay).
// The field order matters: it is built with a positional composite literal.
type connHijackerTransport struct {
	*http.Transport
	Conn   net.Conn
	Replay []byte
}
|
|
|
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
|
t := &http.Transport{
|
|
MaxIdleConnsPerHost: -1,
|
|
}
|
|
if b, _ := base.(*http.Transport); b != nil {
|
|
t.Proxy = b.Proxy
|
|
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
|
|
t.TLSClientConfig.NextProtos = nil
|
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
|
t.Dial = b.Dial
|
|
t.DialTLS = b.DialTLS
|
|
} else {
|
|
t.Proxy = http.ProxyFromEnvironment
|
|
t.TLSHandshakeTimeout = 10 * time.Second
|
|
}
|
|
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
|
|
|
dial := getTransportDial(t)
|
|
dialTLS := getTransportDialTLS(t)
|
|
|
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
|
c, err := dial(network, addr)
|
|
hj.Conn = c
|
|
return &hijackedConn{c, hj}, err
|
|
}
|
|
|
|
if dialTLS != nil {
|
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
|
c, err := dialTLS(network, addr)
|
|
hj.Conn = c
|
|
return &hijackedConn{c, hj}, err
|
|
}
|
|
}
|
|
|
|
return hj
|
|
}
|
|
|
|
// getTransportDial always returns a plain Dialer
|
|
// and defaults to the existing t.Dial.
|
|
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
|
if t.Dial != nil {
|
|
return t.Dial
|
|
}
|
|
return defaultDialer.Dial
|
|
}
|
|
|
|
// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
|
|
// and defaults to the existing t.DialTLS.
|
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
|
if t.DialTLS != nil {
|
|
return t.DialTLS
|
|
}
|
|
if t.TLSClientConfig == nil {
|
|
return nil
|
|
}
|
|
|
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
|
// => Create a backup reference.
|
|
plainDial := getTransportDial(t)
|
|
|
|
return func(network, addr string) (net.Conn, error) {
|
|
plainConn, err := plainDial(network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tlsConn := tls.Client(plainConn, t.TLSClientConfig)
|
|
errc := make(chan error, 2)
|
|
var timer *time.Timer
|
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
|
timer = time.AfterFunc(d, func() {
|
|
errc <- tlsHandshakeTimeoutError{}
|
|
})
|
|
}
|
|
go func() {
|
|
err := tlsConn.Handshake()
|
|
if timer != nil {
|
|
timer.Stop()
|
|
}
|
|
errc <- err
|
|
}()
|
|
if err := <-errc; err != nil {
|
|
plainConn.Close()
|
|
return nil, err
|
|
}
|
|
if !t.TLSClientConfig.InsecureSkipVerify {
|
|
serverName := t.TLSClientConfig.ServerName
|
|
if serverName == "" {
|
|
serverName = addr
|
|
idx := strings.LastIndex(serverName, ":")
|
|
if idx != -1 {
|
|
serverName = serverName[:idx]
|
|
}
|
|
}
|
|
if err := tlsConn.VerifyHostname(serverName); err != nil {
|
|
plainConn.Close()
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return tlsConn, nil
|
|
}
|
|
}
|
|
|
|
// tlsHandshakeTimeoutError is reported by the TLS dialer when the handshake
// does not finish within the transport's TLSHandshakeTimeout. It implements
// net.Error and reports itself as both a timeout and temporary.
type tlsHandshakeTimeoutError struct{}

// Error returns the fixed error message.
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }

// Timeout always reports true.
func (tlsHandshakeTimeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
|
|
|
// cloneTLSClientConfig returns a new *tls.Config that copies the listed
// client-relevant fields of cfg. A nil cfg yields an empty, non-nil config.
// The server-side fields SessionTicketsDisabled and SessionTicketKey are not
// copied, which makes it safe to call on a config in active use by a server.
// Slices and pointers (Certificates, NextProtos, RootCAs, ...) are shared,
// not deep-copied.
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
	clone := new(tls.Config)
	if cfg == nil {
		return clone
	}

	// Randomness, clock and certificates.
	clone.Rand = cfg.Rand
	clone.Time = cfg.Time
	clone.Certificates = cfg.Certificates
	clone.NameToCertificate = cfg.NameToCertificate
	clone.GetCertificate = cfg.GetCertificate

	// Peer verification and protocol negotiation.
	clone.RootCAs = cfg.RootCAs
	clone.NextProtos = cfg.NextProtos
	clone.ServerName = cfg.ServerName
	clone.ClientAuth = cfg.ClientAuth
	clone.ClientCAs = cfg.ClientCAs
	clone.InsecureSkipVerify = cfg.InsecureSkipVerify

	// Cipher, version and session settings.
	clone.CipherSuites = cfg.CipherSuites
	clone.PreferServerCipherSuites = cfg.PreferServerCipherSuites
	clone.ClientSessionCache = cfg.ClientSessionCache
	clone.MinVersion = cfg.MinVersion
	clone.MaxVersion = cfg.MaxVersion
	clone.CurvePreferences = cfg.CurvePreferences
	clone.DynamicRecordSizingDisabled = cfg.DynamicRecordSizingDisabled
	clone.Renegotiation = cfg.Renegotiation

	return clone
}
|
|
|
|
// requestIsWebsocket reports whether req asks for a websocket upgrade. The
// lower-cased Upgrade header must equal "websocket", and the lower-cased
// Connection header must contain "upgrade".
func requestIsWebsocket(req *http.Request) bool {
	if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" {
		return false
	}
	connection := strings.ToLower(req.Header.Get("Connection"))
	return strings.Contains(connection, "upgrade")
}
|
|
|
|
// writeFlusher is satisfied by writers that can also flush buffered data to
// their destination, e.g. an http.ResponseWriter that implements http.Flusher.
type writeFlusher interface {
	http.Flusher
	io.Writer
}
|
|
|
|
type maxLatencyWriter struct {
|
|
dst writeFlusher
|
|
latency time.Duration
|
|
|
|
lk sync.Mutex // protects Write + Flush
|
|
done chan bool
|
|
}
|
|
|
|
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
|
m.lk.Lock()
|
|
defer m.lk.Unlock()
|
|
return m.dst.Write(p)
|
|
}
|
|
|
|
func (m *maxLatencyWriter) flushLoop() {
|
|
t := time.NewTicker(m.latency)
|
|
defer t.Stop()
|
|
for {
|
|
select {
|
|
case <-m.done:
|
|
if onExitFlushLoop != nil {
|
|
onExitFlushLoop()
|
|
}
|
|
return
|
|
case <-t.C:
|
|
m.lk.Lock()
|
|
m.dst.Flush()
|
|
m.lk.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *maxLatencyWriter) stop() { m.done <- true }
|