From 6490ff6224a949ac4c63ef184d2c10377889c3b1 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Thu, 21 Jul 2016 18:18:01 -0600 Subject: [PATCH] Adjust proxy headers properly (fixes #916) --- caddyhttp/proxy/proxy.go | 72 ++++++++++----------------------- caddyhttp/proxy/reverseproxy.go | 5 ++- 2 files changed, 25 insertions(+), 52 deletions(-) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index c4cd19606..5a356d6f6 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -84,7 +84,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { } // this replacer is used to fill in header field values - var replacer httpserver.Replacer + replacer := httpserver.NewReplacer(r, nil, "") // outreq is the request that makes a roundtrip to the backend outreq := createUpstreamRequest(r) @@ -119,16 +119,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // set headers for request going upstream if host.UpstreamHeaders != nil { - if replacer == nil { - replacer = httpserver.NewReplacer(r, nil, "") - } - if v, ok := host.UpstreamHeaders["Host"]; ok { - outreq.Host = replacer.Replace(v[len(v)-1]) - } // modify headers for request that will be sent to the upstream host - upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer) - for k, v := range upHeaders { - outreq.Header[k] = v + mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer) + if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 { + outreq.Host = hostHeaders[len(hostHeaders)-1] } } @@ -136,9 +130,6 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // headers coming back downstream var downHeaderUpdateFn respUpdateFn if host.DownstreamHeaders != nil { - if replacer == nil { - replacer = httpserver.NewReplacer(r, nil, "") - } downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) } @@ -185,6 +176,8 @@ func (p Proxy) match(r *http.Request) Upstream { // createUpstremRequest shallow-copies r into a new request // that can be sent upstream. +// +// Derived from reverseproxy.go in the standard Go httputil package. func createUpstreamRequest(r *http.Request) *http.Request { outreq := new(http.Request) *outreq = *r // includes shallow copies of maps, but okay @@ -199,10 +192,14 @@ func createUpstreamRequest(r *http.Request) *http.Request { // connection, regardless of what the client sent to us. This // is modifying the same underlying map from r (shallow // copied above) so we only copy it if necessary. + var copiedHeaders bool for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, r.Header) + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, r.Header) + copiedHeaders = true + } outreq.Header.Del(h) } } @@ -222,45 +219,20 @@ func createUpstreamRequest(r *http.Request) *http.Request { func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn { return func(resp *http.Response) { - newHeaders := createHeadersByRules(rules, resp.Header, replacer) - for h, v := range newHeaders { - resp.Header[h] = v - } + mutateHeadersByRules(resp.Header, rules, replacer) } } -func createHeadersByRules(rules http.Header, base http.Header, repl httpserver.Replacer) http.Header { - newHeaders := make(http.Header) - for header, values := range rules { - if strings.HasPrefix(header, "+") { - header = strings.TrimLeft(header, "+") - add(newHeaders, header, base[header]) - applyEach(values, repl.Replace) - add(newHeaders, header, values) - } else if strings.HasPrefix(header, "-") { - base.Del(strings.TrimLeft(header, "-")) - } else if _, ok := base[header]; ok { - applyEach(values, repl.Replace) - for _, v := range values { - newHeaders.Set(header, v) +func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) { + for ruleField, ruleValues := range rules { + if strings.HasPrefix(ruleField, "+") { + for _, ruleValue := range ruleValues { + headers.Add(strings.TrimPrefix(ruleField, "+"), repl.Replace(ruleValue)) } - } else { - applyEach(values, repl.Replace) - add(newHeaders, header, values) - add(newHeaders, header, base[header]) + } else if strings.HasPrefix(ruleField, "-") { + headers.Del(strings.TrimPrefix(ruleField, "-")) + } else if len(ruleValues) > 0 { + headers.Set(ruleField, repl.Replace(ruleValues[len(ruleValues)-1])) } } - return newHeaders -} - -func applyEach(values []string, mapFn func(string) string) { - for i, v := range values { - values[i] = mapFn(v) - } -} - -func add(base http.Header, header string, values []string) { - for _, v := range values { - base.Add(header, v) - } } diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index c92746dc8..e6f759dd5 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -177,10 +177,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r res, err := transport.RoundTrip(outreq) if err != nil { return err - } else if respUpdateFn != nil { - respUpdateFn(res) } + if respUpdateFn != nil { + respUpdateFn(res) + } if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { res.Body.Close() hj, ok := rw.(http.Hijacker)