diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index e68cf3a22..6ce51192c 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -43,7 +43,8 @@ type UpstreamHost struct { Fails int32 FailTimeout time.Duration Unhealthy bool - ExtraHeaders http.Header + UpstreamHeaders http.Header + DownstreamHeaders http.Header CheckDown UpstreamHostDownFunc WithoutPathPrefix string MaxConns int64 @@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { } outreq.Host = host.Name - if host.ExtraHeaders != nil { - extraHeaders := make(http.Header) + if host.UpstreamHeaders != nil { if replacer == nil { rHost := r.Host replacer = middleware.NewReplacer(r, nil, "") outreq.Host = rHost } - for header, values := range host.ExtraHeaders { - for _, value := range values { - extraHeaders.Add(header, replacer.Replace(value)) - if header == "Host" { - outreq.Host = replacer.Replace(value) - } - } + if v, ok := host.UpstreamHeaders["Host"]; ok { + r.Host = replacer.Replace(v[len(v)-1]) } - for k, v := range extraHeaders { + // Modify headers for request that will be sent to the upstream host + upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer) + for k, v := range upHeaders { outreq.Header[k] = v } } + var downHeaderUpdateFn respUpdateFn + if host.DownstreamHeaders != nil { + if replacer == nil { + rHost := r.Host + replacer = middleware.NewReplacer(r, nil, "") + outreq.Host = rHost + } + //Creates a function that is used to update headers the response received by the reverse proxy + downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) + } + proxy := host.ReverseProxy if baseURL, err := url.Parse(host.Name); err == nil { r.Host = baseURL.Host @@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { } atomic.AddInt64(&host.Conns, 1) - backendErr := proxy.ServeHTTP(w, outreq) + backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) atomic.AddInt64(&host.Conns, -1) if backendErr == nil { return 0, nil @@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request { return outreq } + +func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn { + return func(resp *http.Response) { + newHeaders := createHeadersByRules(rules, resp.Header, replacer) + for h, v := range newHeaders { + resp.Header[h] = v + } + } +} + +func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header { + newHeaders := make(http.Header) + for header, values := range rules { + if strings.HasPrefix(header, "+") { + header = strings.TrimLeft(header, "+") + add(newHeaders, header, base[header]) + applyEach(values, repl.Replace) + add(newHeaders, header, values) + } else if strings.HasPrefix(header, "-") { + base.Del(strings.TrimLeft(header, "-")) + } else if _, ok := base[header]; ok { + applyEach(values, repl.Replace) + for _, v := range values { + newHeaders.Set(header, v) + } + } else { + applyEach(values, repl.Replace) + add(newHeaders, header, values) + add(newHeaders, header, base[header]) + } + } + return newHeaders +} + +func applyEach(values []string, mapFn func(string) string) { + for i, v := range values { + values[i] = mapFn(v) + } +} + +func add(base http.Header, header string, values []string) { + for _, v := range values { + base.Add(header, v) + } +} diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 6e9c3c426..b07ff7df9 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) { } } +func TestUpstreamHeadersUpdate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var actualHeaders http.Header + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, client")) + actualHeaders = r.Header + })) + defer backend.Close() + + upstream := newFakeUpstream(backend.URL, false) + upstream.host.UpstreamHeaders = http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}, + "+Merge-Me": {"Merge-Value"}, + "+Add-Me": {"Add-Value"}, + "-Remove-Me": {""}, + "Replace-Me": {"{hostname}"}, + } + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{upstream}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + //add initial headers + r.Header.Add("Merge-Me", "Initial") + r.Header.Add("Remove-Me", "Remove-Value") + r.Header.Add("Replace-Me", "Replace-Value") + + p.ServeHTTP(w, r) + + replacer := middleware.NewReplacer(r, nil, "") + + headerKey := "Merge-Me" + values, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey) + } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { + t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values) + } + + headerKey = "Add-Me" + if _, ok := actualHeaders[headerKey]; !ok { + t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey) + } + + headerKey = "Remove-Me" + if _, ok := actualHeaders[headerKey]; ok { + t.Errorf("Request sent to upstream backend should not contain %v header", headerKey) + } + + headerKey = "Replace-Me" + headerValue := replacer.Replace("{hostname}") + value, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Request sent to upstream backend should not remove %v header", headerKey) + } else if len(value) > 0 && headerValue != value[0] { + t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value) + } + +} + +func TestDownstreamHeadersUpdate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Merge-Me", "Initial") + w.Header().Add("Remove-Me", "Remove-Value") + w.Header().Add("Replace-Me", "Replace-Value") + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + upstream := newFakeUpstream(backend.URL, false) + upstream.host.DownstreamHeaders = http.Header{ + "+Merge-Me": {"Merge-Value"}, + "+Add-Me": {"Add-Value"}, + "-Remove-Me": {""}, + "Replace-Me": {"{hostname}"}, + } + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{upstream}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + replacer := middleware.NewReplacer(r, nil, "") + actualHeaders := w.Header() + + headerKey := "Merge-Me" + values, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey) + } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { + t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values) + } + + headerKey = "Add-Me" + if _, ok := actualHeaders[headerKey]; !ok { + t.Errorf("Downstream response does not contain expected %v header", headerKey) + } + + headerKey = "Remove-Me" + if _, ok := actualHeaders[headerKey]; ok { + t.Errorf("Downstream response should not contain %v header received from upstream", headerKey) + } + + headerKey = "Replace-Me" + headerValue := replacer.Replace("{hostname}") + value, ok := actualHeaders[headerKey] + if !ok { + t.Errorf("Downstream response should contain %v header and not remove it", headerKey) + } else if len(value) > 0 && headerValue != value[0] { + t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value) + } + +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ @@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost { return &UpstreamHost{ Name: u.name, ReverseProxy: NewSingleHostReverseProxy(uri, u.without), - ExtraHeaders: http.Header{ + UpstreamHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index fd630876c..5a14aea79 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error { +type respUpdateFn func(resp *http.Response) + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { transport := p.Transport if transport == nil { transport = http.DefaultTransport @@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e res, err := transport.RoundTrip(outreq) if err != nil { return err + } else if respUpdateFn != nil { + respUpdateFn(res) } if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index f98f1482e..e28db6437 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -20,7 +20,8 @@ var ( type staticUpstream struct { from string - proxyHeaders http.Header + upstreamHeaders http.Header + downstreamHeaders http.Header Hosts HostPool Policy Policy insecureSkipVerify bool @@ -42,13 +43,14 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { var upstreams []Upstream for c.Next() { upstream := &staticUpstream{ - from: "", - proxyHeaders: make(http.Header), - Hosts: nil, - Policy: &Random{}, - FailTimeout: 10 * time.Second, - MaxFails: 1, - MaxConns: 0, + from: "", + upstreamHeaders: make(http.Header), + downstreamHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + MaxConns: 0, } if !c.Args(&upstream.from) { @@ -97,12 +99,13 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { host = "http://" + host } uh := &UpstreamHost{ - Name: host, - Conns: 0, - Fails: 0, - FailTimeout: u.FailTimeout, - Unhealthy: false, - ExtraHeaders: u.proxyHeaders, + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: u.FailTimeout, + Unhealthy: false, + UpstreamHeaders: u.upstreamHeaders, + DownstreamHeaders: u.downstreamHeaders, CheckDown: func(u *staticUpstream) UpstreamHostDownFunc { return func(uh *UpstreamHost) bool { if uh.Unhealthy { @@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error { } u.HealthCheck.Interval = dur } + case "header_upstream": + fallthrough case "proxy_header": var header, value string if !c.Args(&header, &value) { return c.ArgErr() } - u.proxyHeaders.Add(header, value) + u.upstreamHeaders.Add(header, value) + case "header_downstream": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.downstreamHeaders.Add(header, value) case "websocket": - u.proxyHeaders.Add("Connection", "{>Connection}") - u.proxyHeaders.Add("Upgrade", "{>Upgrade}") + u.upstreamHeaders.Add("Connection", "{>Connection}") + u.upstreamHeaders.Add("Upgrade", "{>Upgrade}") case "without": if !c.NextArg() { return c.ArgErr()