2017-09-23 13:56:58 +08:00
|
|
|
// Copyright 2015 Light Code Labs, LLC
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
// This file is adapted from code in the net/http/httputil
|
|
|
|
// package of the Go standard library, which is by the
|
|
|
|
// Go Authors, and bears this copyright and license info:
|
|
|
|
//
|
|
|
|
// Copyright 2011 The Go Authors. All rights reserved.
|
|
|
|
// Use of this source code is governed by a BSD-style
|
|
|
|
// license that can be found in the LICENSE file.
|
|
|
|
//
|
|
|
|
// This file has been modified from the standard lib to
|
|
|
|
// meet the needs of the application.
|
|
|
|
|
|
|
|
package proxy
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/tls"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
"sync"
|
|
|
|
"time"
|
2016-10-11 10:15:33 +08:00
|
|
|
|
2016-12-22 03:44:07 +08:00
|
|
|
"golang.org/x/net/http2"
|
|
|
|
|
2017-09-12 09:49:02 +08:00
|
|
|
"github.com/lucas-clemente/quic-go"
|
|
|
|
"github.com/lucas-clemente/quic-go/h2quic"
|
2016-10-11 10:15:33 +08:00
|
|
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
2016-06-06 11:51:56 +08:00
|
|
|
)
|
|
|
|
|
2016-12-29 00:17:52 +08:00
|
|
|
var (
|
|
|
|
defaultDialer = &net.Dialer{
|
|
|
|
Timeout: 30 * time.Second,
|
|
|
|
KeepAlive: 30 * time.Second,
|
|
|
|
}
|
2016-12-27 03:52:36 +08:00
|
|
|
|
2016-12-29 00:17:52 +08:00
|
|
|
bufferPool = sync.Pool{New: createBuffer}
|
2017-09-12 09:49:02 +08:00
|
|
|
|
|
|
|
defaultCryptoHandshakeTimeout = 10 * time.Second
|
2016-12-29 00:17:52 +08:00
|
|
|
)
|
2016-07-21 09:06:14 +08:00
|
|
|
|
|
|
|
// createBuffer is the constructor bufferPool uses when it is empty. It
// returns a zero-length []byte backed by a 32 KiB array, which callers
// reslice to whatever length they need.
func createBuffer() interface{} {
	const bufferSize = 32 << 10 // 32 KiB, the same size io.Copy allocates
	buf := make([]byte, 0, bufferSize)
	return buf
}
|
|
|
|
|
|
|
|
func pooledIoCopy(dst io.Writer, src io.Reader) {
|
|
|
|
buf := bufferPool.Get().([]byte)
|
|
|
|
defer bufferPool.Put(buf)
|
|
|
|
|
|
|
|
// CopyBuffer only uses buf up to its length and panics if it's 0.
|
|
|
|
// Due to that we extend buf's length to its capacity here and
|
|
|
|
// ensure it's always non-zero.
|
|
|
|
bufCap := cap(buf)
|
|
|
|
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
|
2016-07-21 09:06:14 +08:00
|
|
|
}
|
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
var (
	// onExitFlushLoop, when non-nil, is invoked by maxLatencyWriter.flushLoop
	// right before that goroutine returns, letting tests observe when the
	// flushLoop() goroutine has exited. Nothing in this file sets it.
	onExitFlushLoop func()
)
|
|
|
|
|
|
|
|
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
//
// NOTE(review): ServeHTTP in this file takes an extra respUpdateFn argument
// and returns an error, so despite the wording above a *ReverseProxy does
// not satisfy http.Handler by itself.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. Its response is then copied
	// back to the original client unmodified.
	//
	// ServeHTTP calls Director on the outgoing request before each
	// roundtrip. "Unmodified" is approximate: ServeHTTP strips hop-by-hop
	// response headers and may pass the response to a respUpdateFn
	// before copying it.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	// Websocket upgrade requests instead go through a per-request
	// transport derived from this one (see newConnHijackerTransport),
	// even when it is nil.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	// Periodic flushing also requires the client's ResponseWriter to
	// implement http.Flusher; otherwise the body is copied without it.
	FlushInterval time.Duration
}
|
|
|
|
|
|
|
|
// socketDial returns a dial function that ignores the network and address
// it is given and always connects to the unix domain socket named by
// hostName (the string form of a "unix:" target URL).
//
// Though the relevant directive prefix is just "unix:", url.URL.String may
// render the target with additional slashes, as if "unix" was a request
// protocol: "unix:/var/run/www.socket" comes back as
// "unix:///var/run/www.socket" on older Go releases, but unchanged on
// releases that honor url.URL.OmitHost. Only the path is needed, so the
// "unix:" scheme and an optional "//" are stripped. Slicing off a fixed
// len("unix://") would chop real path characters in the second case and
// panic on inputs shorter than that prefix.
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
	socketPath := strings.TrimPrefix(strings.TrimPrefix(hostName, "unix:"), "//")
	return func(network, addr string) (conn net.Conn, err error) {
		return net.Dial("unix", socketPath)
	}
}
|
|
|
|
|
2017-01-26 12:23:50 +08:00
|
|
|
// singleJoiningSlash concatenates a and b with a single "/" at the seam:
// if a ends with "/" and b starts with "/", one of them is dropped; if
// neither does and b is non-empty, one is inserted. Otherwise the two are
// simply concatenated.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if trailing || leading || b == "" {
		return a + b
	}
	return a + "/" + b
}
|
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
|
|
|
|
// URLs to the scheme, host, and base path provided in target. If the
|
|
|
|
// target's path is "/base" and the incoming request was for "/dir",
|
|
|
|
// the target request will be for /base/dir.
|
|
|
|
// Without logic: target's path is "/", incoming is "/api/messages",
|
|
|
|
// without is "/api", then the target request will be for /messages.
|
2016-08-02 06:47:31 +08:00
|
|
|
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
|
2016-06-06 11:51:56 +08:00
|
|
|
targetQuery := target.RawQuery
|
|
|
|
director := func(req *http.Request) {
|
|
|
|
if target.Scheme == "unix" {
|
|
|
|
// to make Dial work with unix URL,
|
|
|
|
// scheme and host have to be faked
|
|
|
|
req.URL.Scheme = "http"
|
|
|
|
req.URL.Host = "socket"
|
|
|
|
} else {
|
|
|
|
req.URL.Scheme = target.Scheme
|
|
|
|
req.URL.Host = target.Host
|
|
|
|
}
|
2016-10-02 15:04:57 +08:00
|
|
|
|
httpserver/all: Clean up and standardize request URL handling (#1633)
* httpserver/all: Clean up and standardize request URL handling
The HTTP server now always creates a context value on the request which
is a copy of the request's URL struct. It should not be modified by
middlewares, but it is safe to get the value out of the request and make
changes to it locally-scoped. Thus, the value in the context always
stores the original request URL information as it was received. Any
rewrites that happen will be to the request's URL field directly.
The HTTP server no longer cleans /sanitizes the request URL. It made too
many strong assumptions and ended up making a lot of middleware more
complicated, including upstream proxying (and fastcgi). To alleviate
this complexity, we no longer change the request URL. Middlewares are
responsible to access the disk safely by using http.Dir or, if not
actually opening files, they can use httpserver.SafePath().
I'm hoping this will address issues with #1624, #1584, #1582, and others.
* staticfiles: Fix test on Windows
@abiosoft: I still can't figure out exactly what this is for. 😅
* Use (potentially) changed URL for browse redirects, as before
* Use filepath.ToSlash, clean up a couple proxy test cases
* Oops, fix variable name
2017-05-02 13:11:10 +08:00
|
|
|
// remove the `without` prefix
|
2016-10-02 15:04:57 +08:00
|
|
|
if without != "" {
|
|
|
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
|
|
|
|
if req.URL.Opaque != "" {
|
|
|
|
req.URL.Opaque = strings.TrimPrefix(req.URL.Opaque, without)
|
|
|
|
}
|
|
|
|
if req.URL.RawPath != "" {
|
|
|
|
req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, without)
|
|
|
|
}
|
2016-06-06 11:51:56 +08:00
|
|
|
}
|
2016-10-02 15:04:57 +08:00
|
|
|
|
2017-01-26 12:23:50 +08:00
|
|
|
// prefer returns val if it isn't empty, otherwise def
|
|
|
|
prefer := func(val, def string) string {
|
|
|
|
if val != "" {
|
|
|
|
return val
|
|
|
|
}
|
|
|
|
return def
|
|
|
|
}
|
httpserver/all: Clean up and standardize request URL handling (#1633)
* httpserver/all: Clean up and standardize request URL handling
The HTTP server now always creates a context value on the request which
is a copy of the request's URL struct. It should not be modified by
middlewares, but it is safe to get the value out of the request and make
changes to it locally-scoped. Thus, the value in the context always
stores the original request URL information as it was received. Any
rewrites that happen will be to the request's URL field directly.
The HTTP server no longer cleans /sanitizes the request URL. It made too
many strong assumptions and ended up making a lot of middleware more
complicated, including upstream proxying (and fastcgi). To alleviate
this complexity, we no longer change the request URL. Middlewares are
responsible to access the disk safely by using http.Dir or, if not
actually opening files, they can use httpserver.SafePath().
I'm hoping this will address issues with #1624, #1584, #1582, and others.
* staticfiles: Fix test on Windows
@abiosoft: I still can't figure out exactly what this is for. 😅
* Use (potentially) changed URL for browse redirects, as before
* Use filepath.ToSlash, clean up a couple proxy test cases
* Oops, fix variable name
2017-05-02 13:11:10 +08:00
|
|
|
|
2017-01-26 12:23:50 +08:00
|
|
|
// Make up the final URL by concatenating the request and target URL.
|
|
|
|
//
|
|
|
|
// If there is encoded part in request or target URL,
|
|
|
|
// the final URL should also be in encoded format.
|
|
|
|
// Here, we concatenate their encoded parts which are stored
|
|
|
|
// in URL.Opaque and URL.RawPath, if it is empty use
|
|
|
|
// URL.Path instead.
|
|
|
|
if req.URL.Opaque != "" || target.Opaque != "" {
|
|
|
|
req.URL.Opaque = singleJoiningSlash(
|
|
|
|
prefer(target.Opaque, target.Path),
|
|
|
|
prefer(req.URL.Opaque, req.URL.Path))
|
2016-10-11 17:06:49 +08:00
|
|
|
}
|
2017-01-26 12:23:50 +08:00
|
|
|
if req.URL.RawPath != "" || target.RawPath != "" {
|
|
|
|
req.URL.RawPath = singleJoiningSlash(
|
|
|
|
prefer(target.RawPath, target.Path),
|
|
|
|
prefer(req.URL.RawPath, req.URL.Path))
|
|
|
|
}
|
2017-04-22 03:02:15 +08:00
|
|
|
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
|
2016-10-11 17:06:49 +08:00
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
// Trims the path of the socket from the URL path.
|
|
|
|
// This is done because req.URL passed to your proxied service
|
|
|
|
// will have the full path of the socket file prefixed to it.
|
|
|
|
// Calling /test on a server that proxies requests to
|
|
|
|
// unix:/var/run/www.socket will thus set the requested path
|
|
|
|
// to /var/run/www.socket/test, rendering paths useless.
|
|
|
|
if target.Scheme == "unix" {
|
|
|
|
// See comment on socketDial for the trim
|
|
|
|
socketPrefix := target.String()[len("unix://"):]
|
|
|
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
|
2017-01-26 12:23:50 +08:00
|
|
|
if req.URL.Opaque != "" {
|
|
|
|
req.URL.Opaque = strings.TrimPrefix(req.URL.Opaque, socketPrefix)
|
|
|
|
}
|
|
|
|
if req.URL.RawPath != "" {
|
|
|
|
req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, socketPrefix)
|
|
|
|
}
|
2016-06-06 11:51:56 +08:00
|
|
|
}
|
2016-10-02 15:04:57 +08:00
|
|
|
|
|
|
|
if targetQuery == "" || req.URL.RawQuery == "" {
|
|
|
|
req.URL.RawQuery = targetQuery + req.URL.RawQuery
|
|
|
|
} else {
|
|
|
|
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
|
2016-06-06 11:51:56 +08:00
|
|
|
}
|
|
|
|
}
|
2017-09-12 09:49:02 +08:00
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
|
|
|
|
if target.Scheme == "unix" {
|
|
|
|
rp.Transport = &http.Transport{
|
|
|
|
Dial: socketDial(target.String()),
|
|
|
|
}
|
2017-09-12 09:49:02 +08:00
|
|
|
} else if target.Scheme == "quic" {
|
|
|
|
rp.Transport = &h2quic.RoundTripper{
|
|
|
|
QuicConfig: &quic.Config{
|
|
|
|
HandshakeTimeout: defaultCryptoHandshakeTimeout,
|
|
|
|
},
|
|
|
|
}
|
2016-08-06 06:41:32 +08:00
|
|
|
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
|
|
|
|
// if keepalive is equal to the default,
|
|
|
|
// just use default transport, to avoid creating
|
|
|
|
// a brand new transport
|
2016-12-22 03:44:07 +08:00
|
|
|
transport := &http.Transport{
|
2016-12-27 03:52:36 +08:00
|
|
|
Proxy: http.ProxyFromEnvironment,
|
|
|
|
Dial: defaultDialer.Dial,
|
2017-09-12 09:49:02 +08:00
|
|
|
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
2016-08-02 06:47:31 +08:00
|
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
|
|
}
|
2016-08-06 06:41:32 +08:00
|
|
|
if keepalive == 0 {
|
2016-12-22 03:44:07 +08:00
|
|
|
transport.DisableKeepAlives = true
|
2016-08-02 06:47:31 +08:00
|
|
|
} else {
|
2016-12-22 03:44:07 +08:00
|
|
|
transport.MaxIdleConnsPerHost = keepalive
|
2016-08-02 06:47:31 +08:00
|
|
|
}
|
2016-12-27 03:40:44 +08:00
|
|
|
if httpserver.HTTP2 {
|
|
|
|
http2.ConfigureTransport(transport)
|
|
|
|
}
|
2016-12-22 03:44:07 +08:00
|
|
|
rp.Transport = transport
|
2016-06-06 11:51:56 +08:00
|
|
|
}
|
|
|
|
return rp
|
|
|
|
}
|
|
|
|
|
2016-08-07 04:46:52 +08:00
|
|
|
// UseInsecureTransport is used to facilitate HTTPS proxying
|
2016-08-02 06:47:31 +08:00
|
|
|
// when it is OK for upstream to be using a bad certificate,
|
|
|
|
// since this transport skips verification.
|
|
|
|
func (rp *ReverseProxy) UseInsecureTransport() {
|
|
|
|
if rp.Transport == nil {
|
2016-12-22 03:44:07 +08:00
|
|
|
transport := &http.Transport{
|
2016-12-27 03:52:36 +08:00
|
|
|
Proxy: http.ProxyFromEnvironment,
|
|
|
|
Dial: defaultDialer.Dial,
|
2017-09-12 09:49:02 +08:00
|
|
|
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
2016-08-02 06:47:31 +08:00
|
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
|
|
}
|
2016-12-27 03:40:44 +08:00
|
|
|
if httpserver.HTTP2 {
|
|
|
|
http2.ConfigureTransport(transport)
|
|
|
|
}
|
2016-12-22 03:44:07 +08:00
|
|
|
rp.Transport = transport
|
2016-08-02 06:47:31 +08:00
|
|
|
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
2017-03-11 01:41:37 +08:00
|
|
|
if transport.TLSClientConfig == nil {
|
|
|
|
transport.TLSClientConfig = &tls.Config{}
|
|
|
|
}
|
|
|
|
transport.TLSClientConfig.InsecureSkipVerify = true
|
2016-12-27 03:40:44 +08:00
|
|
|
// No http2.ConfigureTransport() here.
|
|
|
|
// For now this is only added in places where
|
|
|
|
// an http.Transport is actually created.
|
2017-09-12 09:49:02 +08:00
|
|
|
} else if transport, ok := rp.Transport.(*h2quic.RoundTripper); ok {
|
|
|
|
if transport.TLSClientConfig == nil {
|
|
|
|
transport.TLSClientConfig = &tls.Config{}
|
|
|
|
}
|
|
|
|
transport.TLSClientConfig.InsecureSkipVerify = true
|
2016-08-02 06:47:31 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-08-07 04:46:52 +08:00
|
|
|
// ServeHTTP serves the proxied request to the upstream by performing a roundtrip.
|
|
|
|
// It is designed to handle websocket connection upgrades as well.
|
|
|
|
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
|
|
|
transport := rp.Transport
|
|
|
|
if requestIsWebsocket(outreq) {
|
|
|
|
transport = newConnHijackerTransport(transport)
|
|
|
|
} else if transport == nil {
|
|
|
|
transport = http.DefaultTransport
|
|
|
|
}
|
|
|
|
|
|
|
|
rp.Director(outreq)
|
|
|
|
|
2017-09-12 09:49:02 +08:00
|
|
|
if outreq.URL.Scheme == "quic" {
|
|
|
|
outreq.URL.Scheme = "https" // Change scheme back to https for QUIC RoundTripper
|
|
|
|
}
|
|
|
|
|
2016-08-07 04:46:52 +08:00
|
|
|
res, err := transport.RoundTrip(outreq)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2016-12-31 01:13:14 +08:00
|
|
|
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
|
|
|
|
|
|
|
|
// Remove hop-by-hop headers listed in the
|
|
|
|
// "Connection" header of the response.
|
|
|
|
if c := res.Header.Get("Connection"); c != "" {
|
|
|
|
for _, f := range strings.Split(c, ",") {
|
|
|
|
if f = strings.TrimSpace(f); f != "" {
|
|
|
|
res.Header.Del(f)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, h := range hopHeaders {
|
|
|
|
res.Header.Del(h)
|
|
|
|
}
|
|
|
|
|
2016-08-07 04:46:52 +08:00
|
|
|
if respUpdateFn != nil {
|
|
|
|
respUpdateFn(res)
|
|
|
|
}
|
2016-12-31 01:13:14 +08:00
|
|
|
|
|
|
|
if isWebsocket {
|
2017-07-25 15:12:38 +08:00
|
|
|
defer res.Body.Close()
|
2016-08-07 04:46:52 +08:00
|
|
|
hj, ok := rw.(http.Hijacker)
|
|
|
|
if !ok {
|
2016-10-11 10:15:33 +08:00
|
|
|
panic(httpserver.NonHijackerError{Underlying: rw})
|
2016-08-07 04:46:52 +08:00
|
|
|
}
|
|
|
|
|
2017-01-17 22:55:11 +08:00
|
|
|
conn, brw, err := hj.Hijack()
|
2016-08-07 04:46:52 +08:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer conn.Close()
|
|
|
|
|
|
|
|
var backendConn net.Conn
|
|
|
|
if hj, ok := transport.(*connHijackerTransport); ok {
|
|
|
|
backendConn = hj.Conn
|
|
|
|
if _, err := conn.Write(hj.Replay); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
bufferPool.Put(hj.Replay)
|
|
|
|
} else {
|
|
|
|
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
outreq.Write(backendConn)
|
|
|
|
}
|
|
|
|
defer backendConn.Close()
|
|
|
|
|
2017-09-23 08:10:48 +08:00
|
|
|
proxyDone := make(chan struct{}, 2)
|
|
|
|
|
2017-01-17 22:55:11 +08:00
|
|
|
// Proxy backend -> frontend.
|
2017-09-23 08:10:48 +08:00
|
|
|
go func() {
|
|
|
|
pooledIoCopy(conn, backendConn)
|
|
|
|
proxyDone <- struct{}{}
|
|
|
|
}()
|
2017-01-17 22:55:11 +08:00
|
|
|
|
|
|
|
// Proxy frontend -> backend.
|
|
|
|
//
|
|
|
|
// NOTE: Hijack() sometimes returns buffered up bytes in brw which
|
|
|
|
// would be lost if we didn't read them out manually below.
|
|
|
|
if brw != nil {
|
|
|
|
if n := brw.Reader.Buffered(); n > 0 {
|
|
|
|
rbuf, err := brw.Reader.Peek(n)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
backendConn.Write(rbuf)
|
|
|
|
}
|
|
|
|
}
|
2017-09-23 08:10:48 +08:00
|
|
|
go func() {
|
|
|
|
pooledIoCopy(backendConn, conn)
|
|
|
|
proxyDone <- struct{}{}
|
|
|
|
}()
|
|
|
|
|
|
|
|
// If one side is done, we are done.
|
|
|
|
<-proxyDone
|
2016-08-07 04:46:52 +08:00
|
|
|
} else {
|
2017-05-14 00:08:33 +08:00
|
|
|
// NOTE:
|
|
|
|
// Closing the Body involves acquiring a mutex, which is a
|
|
|
|
// unnecessarily heavy operation, considering that this defer will
|
|
|
|
// pretty much never be executed with the Body still unclosed.
|
|
|
|
bodyOpen := true
|
|
|
|
closeBody := func() {
|
|
|
|
if bodyOpen {
|
|
|
|
res.Body.Close()
|
|
|
|
bodyOpen = false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
defer closeBody()
|
|
|
|
|
|
|
|
// Copy all headers over.
|
|
|
|
// res.Header does not include the "Trailer" header,
|
|
|
|
// which means we will have to do that manually below.
|
2016-08-07 04:46:52 +08:00
|
|
|
copyHeader(rw.Header(), res.Header)
|
2016-12-31 01:13:14 +08:00
|
|
|
|
2017-05-14 00:08:33 +08:00
|
|
|
// The "Trailer" header isn't included in res' Header map, which
|
|
|
|
// is why we have to build one ourselves from res.Trailer.
|
|
|
|
//
|
|
|
|
// But res.Trailer does not necessarily contain all trailer keys at this
|
|
|
|
// point yet. The HTTP spec allows one to send "unannounced trailers"
|
|
|
|
// after a request and certain systems like gRPC make use of that.
|
|
|
|
announcedTrailerKeyCount := len(res.Trailer)
|
|
|
|
if announcedTrailerKeyCount > 0 {
|
|
|
|
vv := make([]string, 0, announcedTrailerKeyCount)
|
2016-12-31 01:13:14 +08:00
|
|
|
for k := range res.Trailer {
|
2017-05-14 00:08:33 +08:00
|
|
|
vv = append(vv, k)
|
2016-12-31 01:13:14 +08:00
|
|
|
}
|
2017-05-14 00:08:33 +08:00
|
|
|
rw.Header()["Trailer"] = vv
|
2016-12-31 01:13:14 +08:00
|
|
|
}
|
|
|
|
|
2017-05-14 00:08:33 +08:00
|
|
|
// Now copy over the status code as well as the response body.
|
2016-08-07 04:46:52 +08:00
|
|
|
rw.WriteHeader(res.StatusCode)
|
2017-05-14 00:08:33 +08:00
|
|
|
if announcedTrailerKeyCount > 0 {
|
2016-12-31 01:13:14 +08:00
|
|
|
// Force chunking if we saw a response trailer.
|
2017-05-14 00:08:33 +08:00
|
|
|
// This prevents net/http from calculating the length
|
|
|
|
// for short bodies and adding a Content-Length.
|
2016-12-31 01:13:14 +08:00
|
|
|
if fl, ok := rw.(http.Flusher); ok {
|
|
|
|
fl.Flush()
|
|
|
|
}
|
|
|
|
}
|
2016-08-07 04:46:52 +08:00
|
|
|
rp.copyResponse(rw, res.Body)
|
2017-05-14 00:08:33 +08:00
|
|
|
|
|
|
|
// Now close the body to fully populate res.Trailer.
|
|
|
|
closeBody()
|
|
|
|
|
|
|
|
// Since Go does not remove keys from res.Trailer we
|
|
|
|
// can safely do a length comparison to check wether
|
|
|
|
// we received further, unannounced trailers.
|
|
|
|
//
|
|
|
|
// Most of the time forceSetTrailers should be false.
|
|
|
|
forceSetTrailers := len(res.Trailer) != announcedTrailerKeyCount
|
|
|
|
shallowCopyTrailers(rw.Header(), res.Trailer, forceSetTrailers)
|
2016-08-07 04:46:52 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
|
|
|
if rp.FlushInterval != 0 {
|
|
|
|
if wf, ok := dst.(writeFlusher); ok {
|
|
|
|
mlw := &maxLatencyWriter{
|
|
|
|
dst: wf,
|
|
|
|
latency: rp.FlushInterval,
|
|
|
|
done: make(chan bool),
|
|
|
|
}
|
|
|
|
go mlw.flushLoop()
|
|
|
|
defer mlw.stop()
|
|
|
|
dst = mlw
|
|
|
|
}
|
|
|
|
}
|
2016-12-29 00:17:52 +08:00
|
|
|
pooledIoCopy(dst, src)
|
2016-08-07 04:46:52 +08:00
|
|
|
}
|
|
|
|
|
2016-09-17 07:42:42 +08:00
|
|
|
// skipHeaders lists headers that copyHeader never overwrites: if the
// destination already carries one of them, the source value is dropped.
// see https://github.com/mholt/caddy/pull/1112#discussion_r80092582
var skipHeaders = map[string]struct{}{
	"Accept-Ranges":       {},
	"Cache-Control":       {},
	"Content-Disposition": {},
	"Content-Type":        {},
	"Expires":             {},
	"Set-Cookie":          {},
}

// copyHeader copies every header from src into dst.
//
// For a key that dst already has:
//   - keys in skipHeaders keep dst's values and src's are ignored
//     (see https://github.com/mholt/caddy/issues/1086);
//   - "Server" keeps dst's values and src's are appended, so the reality
//     of the proxying stays visible;
//   - any other key has dst's values replaced by src's, avoiding
//     duplicated fields that can be problematic (issue #1086).
//
// Keys not yet present in dst are simply added.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		_, present := dst[key]
		_, protected := skipHeaders[key]
		switch {
		case present && protected:
			continue
		case present && key != "Server":
			dst.Del(key)
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}
|
|
|
|
|
2017-05-14 00:08:33 +08:00
|
|
|
// shallowCopyTrailers stores every entry of srcTrailer in dstHeader.
//
// With forceSetTrailers, each key is prefixed with http.TrailerPrefix so the
// Go stdlib sends it as a trailer even though it was not announced in the
// Trailer header; without it, the stdlib ignores trailer keys that weren't
// listed in the Trailer map before the Response was submitted.
//
// WARNING: Only a shallow copy will be created! dstHeader ends up sharing
// the value slices of srcTrailer.
func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) {
	prefix := ""
	if forceSetTrailers {
		prefix = http.TrailerPrefix
	}
	for name, values := range srcTrailer {
		dstHeader[prefix+name] = values
	}
}
|
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
// Hop-by-hop headers. These are removed when sent to the backend; in this
// file, ServeHTTP also strips them from upstream responses before they reach
// the client.
// See RFC 7230 section 6.1 (originally
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html).
// "Alt-Svc" and "Alternate-Protocol" are not hop-by-hop per the RFC;
// presumably they are listed so the upstream's protocol advertisements are
// not passed through.
var hopHeaders = []string{
	"Alt-Svc", "Alternate-Protocol",
	"Connection", "Keep-Alive",
	"Proxy-Authenticate", "Proxy-Authorization",
	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
	"Te",               // canonicalized version of "TE"
	"Trailer",          // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding", "Upgrade",
}
|
|
|
|
|
|
|
|
// respUpdateFn is an optional hook passed to ReverseProxy.ServeHTTP. It is
// called with the upstream response after hop-by-hop headers have been
// removed and before the status, headers or body are written to the client,
// so changes it makes to resp.Header are reflected in what the client gets.
type respUpdateFn func(resp *http.Response)
|
|
|
|
|
2016-08-02 10:11:31 +08:00
|
|
|
type hijackedConn struct {
|
|
|
|
net.Conn
|
|
|
|
hj *connHijackerTransport
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *hijackedConn) Read(b []byte) (n int, err error) {
|
|
|
|
n, err = c.Conn.Read(b)
|
|
|
|
c.hj.Replay = append(c.hj.Replay, b[:n]...)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *hijackedConn) Close() error {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// connHijackerTransport is the per-request transport ServeHTTP uses for
// websocket upgrade requests (built by newConnHijackerTransport).
type connHijackerTransport struct {
	// The embedded transport performs the roundtrip; its Dial and DialTLS
	// are replaced so dialed connections are wrapped in hijackedConn.
	*http.Transport
	// Conn is the raw connection recorded by the wrapped dial functions;
	// ServeHTTP uses it as the backend side of the websocket tunnel.
	Conn net.Conn
	// Replay accumulates every byte read through hijackedConn. Its initial
	// buffer comes from bufferPool; ServeHTTP writes it to the client before
	// tunneling and then returns it to the pool.
	Replay []byte
}
|
|
|
|
|
|
|
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
2016-12-27 03:52:36 +08:00
|
|
|
t := &http.Transport{
|
2016-08-23 08:58:43 +08:00
|
|
|
MaxIdleConnsPerHost: -1,
|
2016-08-02 10:11:31 +08:00
|
|
|
}
|
2016-12-27 03:52:36 +08:00
|
|
|
if b, _ := base.(*http.Transport); b != nil {
|
2016-12-29 00:20:31 +08:00
|
|
|
tlsClientConfig := b.TLSClientConfig
|
2017-03-11 01:41:37 +08:00
|
|
|
if tlsClientConfig != nil && tlsClientConfig.NextProtos != nil {
|
2017-02-19 06:26:23 +08:00
|
|
|
tlsClientConfig = tlsClientConfig.Clone()
|
2016-12-29 00:20:31 +08:00
|
|
|
tlsClientConfig.NextProtos = nil
|
|
|
|
}
|
|
|
|
|
2016-12-27 03:52:36 +08:00
|
|
|
t.Proxy = b.Proxy
|
2016-12-29 00:20:31 +08:00
|
|
|
t.TLSClientConfig = tlsClientConfig
|
2016-12-27 03:52:36 +08:00
|
|
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
|
|
|
t.Dial = b.Dial
|
|
|
|
t.DialTLS = b.DialTLS
|
|
|
|
} else {
|
|
|
|
t.Proxy = http.ProxyFromEnvironment
|
|
|
|
t.TLSHandshakeTimeout = 10 * time.Second
|
|
|
|
}
|
|
|
|
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
|
|
|
|
|
|
|
dial := getTransportDial(t)
|
|
|
|
dialTLS := getTransportDialTLS(t)
|
|
|
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
|
|
|
c, err := dial(network, addr)
|
|
|
|
hj.Conn = c
|
|
|
|
return &hijackedConn{c, hj}, err
|
|
|
|
}
|
2016-12-29 00:20:31 +08:00
|
|
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
|
|
|
c, err := dialTLS(network, addr)
|
|
|
|
hj.Conn = c
|
|
|
|
return &hijackedConn{c, hj}, err
|
2016-08-02 10:11:31 +08:00
|
|
|
}
|
2016-12-27 03:52:36 +08:00
|
|
|
|
|
|
|
return hj
|
|
|
|
}
|
|
|
|
|
|
|
|
// getTransportDial always returns a plain Dialer
|
|
|
|
// and defaults to the existing t.Dial.
|
|
|
|
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
|
|
|
if t.Dial != nil {
|
|
|
|
return t.Dial
|
|
|
|
}
|
|
|
|
return defaultDialer.Dial
|
|
|
|
}
|
|
|
|
|
2016-12-29 00:20:31 +08:00
|
|
|
// getTransportDial always returns a TLS Dialer
|
2016-12-27 03:52:36 +08:00
|
|
|
// and defaults to the existing t.DialTLS.
|
|
|
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
|
|
|
if t.DialTLS != nil {
|
|
|
|
return t.DialTLS
|
2016-08-02 10:11:31 +08:00
|
|
|
}
|
2016-12-27 03:52:36 +08:00
|
|
|
|
|
|
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
|
|
|
// => Create a backup reference.
|
|
|
|
plainDial := getTransportDial(t)
|
|
|
|
|
2016-12-29 00:20:31 +08:00
|
|
|
// The following DialTLS implementation stems from the Go stdlib and
|
|
|
|
// is identical to what happens if DialTLS is not provided.
|
|
|
|
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
2016-12-27 03:52:36 +08:00
|
|
|
return func(network, addr string) (net.Conn, error) {
|
|
|
|
plainConn, err := plainDial(network, addr)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2016-08-02 10:11:31 +08:00
|
|
|
}
|
2016-12-27 03:52:36 +08:00
|
|
|
|
2016-12-29 00:20:31 +08:00
|
|
|
tlsClientConfig := t.TLSClientConfig
|
|
|
|
if tlsClientConfig == nil {
|
|
|
|
tlsClientConfig = &tls.Config{}
|
|
|
|
}
|
|
|
|
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
|
|
|
tlsClientConfig.ServerName = stripPort(addr)
|
|
|
|
}
|
|
|
|
|
|
|
|
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
2016-12-27 03:52:36 +08:00
|
|
|
errc := make(chan error, 2)
|
|
|
|
var timer *time.Timer
|
|
|
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
|
|
|
timer = time.AfterFunc(d, func() {
|
|
|
|
errc <- tlsHandshakeTimeoutError{}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
go func() {
|
|
|
|
err := tlsConn.Handshake()
|
|
|
|
if timer != nil {
|
|
|
|
timer.Stop()
|
|
|
|
}
|
|
|
|
errc <- err
|
|
|
|
}()
|
|
|
|
if err := <-errc; err != nil {
|
|
|
|
plainConn.Close()
|
|
|
|
return nil, err
|
|
|
|
}
|
2016-12-29 00:20:31 +08:00
|
|
|
if !tlsClientConfig.InsecureSkipVerify {
|
|
|
|
hostname := tlsClientConfig.ServerName
|
|
|
|
if hostname == "" {
|
|
|
|
hostname = stripPort(addr)
|
2016-12-27 03:52:36 +08:00
|
|
|
}
|
2016-12-29 00:20:31 +08:00
|
|
|
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
2016-12-27 03:52:36 +08:00
|
|
|
plainConn.Close()
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return tlsConn, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-12-29 00:20:31 +08:00
|
|
|
// stripPort returns address with any trailing ":port" removed. It accepts
// hostnames and IP addresses in host:port form; an IPv6 literal must be
// wrapped in square brackets (as Go's stdlib requires) for its port to be
// recognized, and the brackets are kept in the result.
func stripPort(address string) string {
	// An IPv6 literal has colons of its own, so a colon only introduces a
	// port when it comes after the closing bracket (if there is one).
	colon := strings.LastIndexByte(address, ':')
	if colon <= strings.LastIndexByte(address, ']') {
		return address
	}
	return address[:colon]
}
|
|
|
|
|
2016-12-27 03:52:36 +08:00
|
|
|
// tlsHandshakeTimeoutError is reported by the TLS dialer from
// getTransportDialTLS when the handshake exceeds TLSHandshakeTimeout. It
// uses the same message as net/http and implements net.Error, reporting
// itself as both a timeout and temporary.
type tlsHandshakeTimeoutError struct{}

// Timeout always reports true.
func (e tlsHandshakeTimeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (e tlsHandshakeTimeoutError) Temporary() bool {
	return true
}

// Error returns the fixed net/http-style timeout message.
func (e tlsHandshakeTimeoutError) Error() string {
	return "net/http: TLS handshake timeout"
}
|
|
|
|
|
2016-08-02 10:11:31 +08:00
|
|
|
// requestIsWebsocket reports whether req asks for a websocket upgrade: the
// lower-cased Upgrade header must equal "websocket" and the lower-cased
// Connection header must contain "upgrade" somewhere.
func requestIsWebsocket(req *http.Request) bool {
	upgrade := strings.ToLower(req.Header.Get("Upgrade"))
	if upgrade != "websocket" {
		return false
	}
	connection := strings.ToLower(req.Header.Get("Connection"))
	return strings.Contains(connection, "upgrade")
}
|
|
|
|
|
2016-06-06 11:51:56 +08:00
|
|
|
// writeFlusher is satisfied by response writers that can both receive body
// bytes and push buffered data to the client on demand; copyResponse only
// enables periodic flushing for destinations implementing it.
type writeFlusher interface {
	http.Flusher
	io.Writer
}
|
|
|
|
|
|
|
|
type maxLatencyWriter struct {
|
|
|
|
dst writeFlusher
|
|
|
|
latency time.Duration
|
|
|
|
|
|
|
|
lk sync.Mutex // protects Write + Flush
|
|
|
|
done chan bool
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
|
|
|
|
m.lk.Lock()
|
|
|
|
defer m.lk.Unlock()
|
|
|
|
return m.dst.Write(p)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *maxLatencyWriter) flushLoop() {
|
|
|
|
t := time.NewTicker(m.latency)
|
|
|
|
defer t.Stop()
|
|
|
|
for {
|
|
|
|
select {
|
|
|
|
case <-m.done:
|
|
|
|
if onExitFlushLoop != nil {
|
|
|
|
onExitFlushLoop()
|
|
|
|
}
|
|
|
|
return
|
|
|
|
case <-t.C:
|
|
|
|
m.lk.Lock()
|
|
|
|
m.dst.Flush()
|
|
|
|
m.lk.Unlock()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *maxLatencyWriter) stop() { m.done <- true }
|