diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index c7db7a4b6..e96898bd8 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -13,6 +13,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "runtime" "strings" "sync/atomic" @@ -407,16 +408,19 @@ func TestUpstreamHeadersUpdate(t *testing.T) { replacer := httpserver.NewReplacer(r, nil, "") headerKey := "Merge-Me" - values, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey) - } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { - t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values) + got := actualHeaders[headerKey] + expect := []string{"Initial", "Merge-Value"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Request sent to upstream backend does not contain expected %v header: expect %v, but got %v", + headerKey, expect, got) } headerKey = "Add-Me" - if _, ok := actualHeaders[headerKey]; !ok { - t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey) + got = actualHeaders[headerKey] + expect = []string{"Add-Value"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Request sent to upstream backend does not contain expected %v header: expect %v, but got %v", + headerKey, expect, got) } headerKey = "Remove-Me" @@ -425,12 +429,11 @@ func TestUpstreamHeadersUpdate(t *testing.T) { } headerKey = "Replace-Me" - headerValue := replacer.Replace("{hostname}") - value, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Request sent to upstream backend should not remove %v header", headerKey) - } else if len(value) > 0 && headerValue != value[0] { - t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value) + got = actualHeaders[headerKey] + expect = []string{replacer.Replace("{hostname}")} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Request sent to upstream backend does not contain expected %v header: expect %v, but got %v", + headerKey, expect, got) } if actualHost != expectHost { @@ -447,6 +450,8 @@ func TestDownstreamHeadersUpdate(t *testing.T) { w.Header().Add("Merge-Me", "Initial") w.Header().Add("Remove-Me", "Remove-Value") w.Header().Add("Replace-Me", "Replace-Value") + w.Header().Add("Content-Type", "text/html") + w.Header().Add("Overwrite-Me", "Overwrite-Value") w.Write([]byte("Hello, client")) })) defer backend.Close() @@ -470,6 +475,10 @@ func TestDownstreamHeadersUpdate(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } w := httptest.NewRecorder() + // set a predefined skip header + w.Header().Set("Content-Type", "text/css") + // set a predefined overwritten header + w.Header().Set("Overwrite-Me", "Initial") p.ServeHTTP(w, r) @@ -477,16 +486,19 @@ func TestDownstreamHeadersUpdate(t *testing.T) { actualHeaders := w.Header() headerKey := "Merge-Me" - values, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey) - } else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) { - t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values) + got := actualHeaders[headerKey] + expect := []string{"Initial", "Merge-Value"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", + headerKey, expect, got) } headerKey = "Add-Me" - if _, ok := actualHeaders[headerKey]; !ok { - t.Errorf("Downstream response does not contain expected %v header", headerKey) + got = actualHeaders[headerKey] + expect = []string{"Add-Value"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", + headerKey, expect, got) } headerKey = "Remove-Me" @@ -495,14 +507,28 @@ func TestDownstreamHeadersUpdate(t *testing.T) { } headerKey = "Replace-Me" - headerValue := replacer.Replace("{hostname}") - value, ok := actualHeaders[headerKey] - if !ok { - t.Errorf("Downstream response should contain %v header and not remove it", headerKey) - } else if len(value) > 0 && headerValue != value[0] { - t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value) + got = actualHeaders[headerKey] + expect = []string{replacer.Replace("{hostname}")} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", + headerKey, expect, got) } + headerKey = "Content-Type" + got = actualHeaders[headerKey] + expect = []string{"text/css"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", + headerKey, expect, got) + } + + headerKey = "Overwrite-Me" + got = actualHeaders[headerKey] + expect = []string{"Overwrite-Value"} + if !reflect.DeepEqual(got, expect) { + t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v", + headerKey, expect, got) + } } var ( diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 5038afcb1..b355f5c69 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -252,8 +252,28 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { io.CopyBuffer(dst, src, buf.([]byte)) } +// skip these headers if they already exist. +// see https://github.com/mholt/caddy/pull/1112#discussion_r80092582 +var skipHeaders = map[string]struct{}{ + "Content-Type": {}, + "Content-Disposition": {}, + "Accept-Ranges": {}, + "Set-Cookie": {}, + "Cache-Control": {}, + "Expires": {}, +} + func copyHeader(dst, src http.Header) { for k, vv := range src { + if _, ok := dst[k]; ok { + // skip some predefined headers + // see https://github.com/mholt/caddy/issues/1086 + if _, shouldSkip := skipHeaders[k]; shouldSkip { + continue + } + // otherwise, overwrite + dst.Del(k) + } for _, v := range vv { dst.Add(k, v) }