mirror of
https://github.com/caddyserver/caddy.git
synced 2024-11-26 02:09:47 +08:00
* Overwrite proxy headers based on directive Headers of the request sent by the proxy upstream can now be modified in the following way: Prefix header with `+`: Header will be added if it doesn't exist otherwise, the values will be merge Prefix header with `-': Header will be removed No prefix: Header will be replaced with given value * Add missing formating directive reported by go vet * Overwrite up/down stream proxy headers Add Up/DownStreamHeaders to UpstreamHost Split `proxy_header` option in `proxy` directive into `header_upstream` and `header_downstream`. By splitting into two, it makes it clear in what direction the given headers must be applied. `proxy_header` can still be used (to maintain backward compatability) but its assumed to be `header_upstream` Response headers received by the reverse proxy from the upstream host are updated according the `header_downstream` rules. The update occurs through a func given to the reverse proxy, which is applied once a response is received. Headers (for upstream and downstream) can now be modified in the following way: Prefix header with `+`: Header will be added if it doesn't exist otherwise, the values will be merge Prefix header with `-': Header will be removed No prefix: Header will be replaced with given value Updated branch with changes from master * minor refactor to make intent clearer * Make Up/Down stream headers naming consistent * Fix error descriptions to be more clear * Fix lint issue
This commit is contained in:
parent
96425f0f40
commit
e2234497b7
|
@ -43,7 +43,8 @@ type UpstreamHost struct {
|
|||
Fails int32
|
||||
FailTimeout time.Duration
|
||||
Unhealthy bool
|
||||
ExtraHeaders http.Header
|
||||
UpstreamHeaders http.Header
|
||||
DownstreamHeaders http.Header
|
||||
CheckDown UpstreamHostDownFunc
|
||||
WithoutPathPrefix string
|
||||
MaxConns int64
|
||||
|
@ -99,26 +100,33 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
}
|
||||
|
||||
outreq.Host = host.Name
|
||||
if host.ExtraHeaders != nil {
|
||||
extraHeaders := make(http.Header)
|
||||
if host.UpstreamHeaders != nil {
|
||||
if replacer == nil {
|
||||
rHost := r.Host
|
||||
replacer = middleware.NewReplacer(r, nil, "")
|
||||
outreq.Host = rHost
|
||||
}
|
||||
for header, values := range host.ExtraHeaders {
|
||||
for _, value := range values {
|
||||
extraHeaders.Add(header, replacer.Replace(value))
|
||||
if header == "Host" {
|
||||
outreq.Host = replacer.Replace(value)
|
||||
if v, ok := host.UpstreamHeaders["Host"]; ok {
|
||||
r.Host = replacer.Replace(v[len(v)-1])
|
||||
}
|
||||
}
|
||||
}
|
||||
for k, v := range extraHeaders {
|
||||
// Modify headers for request that will be sent to the upstream host
|
||||
upHeaders := createHeadersByRules(host.UpstreamHeaders, r.Header, replacer)
|
||||
for k, v := range upHeaders {
|
||||
outreq.Header[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
var downHeaderUpdateFn respUpdateFn
|
||||
if host.DownstreamHeaders != nil {
|
||||
if replacer == nil {
|
||||
rHost := r.Host
|
||||
replacer = middleware.NewReplacer(r, nil, "")
|
||||
outreq.Host = rHost
|
||||
}
|
||||
//Creates a function that is used to update headers the response received by the reverse proxy
|
||||
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
||||
}
|
||||
|
||||
proxy := host.ReverseProxy
|
||||
if baseURL, err := url.Parse(host.Name); err == nil {
|
||||
r.Host = baseURL.Host
|
||||
|
@ -130,7 +138,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
}
|
||||
|
||||
atomic.AddInt64(&host.Conns, 1)
|
||||
backendErr := proxy.ServeHTTP(w, outreq)
|
||||
backendErr := proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
||||
atomic.AddInt64(&host.Conns, -1)
|
||||
if backendErr == nil {
|
||||
return 0, nil
|
||||
|
@ -182,3 +190,48 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
|||
|
||||
return outreq
|
||||
}
|
||||
|
||||
func createRespHeaderUpdateFn(rules http.Header, replacer middleware.Replacer) respUpdateFn {
|
||||
return func(resp *http.Response) {
|
||||
newHeaders := createHeadersByRules(rules, resp.Header, replacer)
|
||||
for h, v := range newHeaders {
|
||||
resp.Header[h] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createHeadersByRules(rules http.Header, base http.Header, repl middleware.Replacer) http.Header {
|
||||
newHeaders := make(http.Header)
|
||||
for header, values := range rules {
|
||||
if strings.HasPrefix(header, "+") {
|
||||
header = strings.TrimLeft(header, "+")
|
||||
add(newHeaders, header, base[header])
|
||||
applyEach(values, repl.Replace)
|
||||
add(newHeaders, header, values)
|
||||
} else if strings.HasPrefix(header, "-") {
|
||||
base.Del(strings.TrimLeft(header, "-"))
|
||||
} else if _, ok := base[header]; ok {
|
||||
applyEach(values, repl.Replace)
|
||||
for _, v := range values {
|
||||
newHeaders.Set(header, v)
|
||||
}
|
||||
} else {
|
||||
applyEach(values, repl.Replace)
|
||||
add(newHeaders, header, values)
|
||||
add(newHeaders, header, base[header])
|
||||
}
|
||||
}
|
||||
return newHeaders
|
||||
}
|
||||
|
||||
func applyEach(values []string, mapFn func(string) string) {
|
||||
for i, v := range values {
|
||||
values[i] = mapFn(v)
|
||||
}
|
||||
}
|
||||
|
||||
func add(base http.Header, header string, values []string) {
|
||||
for _, v := range values {
|
||||
base.Add(header, v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -348,6 +348,141 @@ func TestUnixSocketProxyPaths(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUpstreamHeadersUpdate(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var actualHeaders http.Header
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello, client"))
|
||||
actualHeaders = r.Header
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"},
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
"-Remove-Me": {""},
|
||||
"Replace-Me": {"{hostname}"},
|
||||
}
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{upstream},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
//add initial headers
|
||||
r.Header.Add("Merge-Me", "Initial")
|
||||
r.Header.Add("Remove-Me", "Remove-Value")
|
||||
r.Header.Add("Replace-Me", "Replace-Value")
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
replacer := middleware.NewReplacer(r, nil, "")
|
||||
|
||||
headerKey := "Merge-Me"
|
||||
values, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Request sent to upstream backend does not contain expected %v header. Expected header to be added", headerKey)
|
||||
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
|
||||
t.Errorf("Values for proxy header `+Merge-Me` should be merged. Got %v", values)
|
||||
}
|
||||
|
||||
headerKey = "Add-Me"
|
||||
if _, ok := actualHeaders[headerKey]; !ok {
|
||||
t.Errorf("Request sent to upstream backend does not contain expected %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Remove-Me"
|
||||
if _, ok := actualHeaders[headerKey]; ok {
|
||||
t.Errorf("Request sent to upstream backend should not contain %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Replace-Me"
|
||||
headerValue := replacer.Replace("{hostname}")
|
||||
value, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Request sent to upstream backend should not remove %v header", headerKey)
|
||||
} else if len(value) > 0 && headerValue != value[0] {
|
||||
t.Errorf("Request sent to upstream backend should replace value of %v header with %v. Instead value was %v", headerKey, headerValue, value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestDownstreamHeadersUpdate(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Merge-Me", "Initial")
|
||||
w.Header().Add("Remove-Me", "Remove-Value")
|
||||
w.Header().Add("Replace-Me", "Replace-Value")
|
||||
w.Write([]byte("Hello, client"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream.host.DownstreamHeaders = http.Header{
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
"-Remove-Me": {""},
|
||||
"Replace-Me": {"{hostname}"},
|
||||
}
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Upstreams: []Upstream{upstream},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
replacer := middleware.NewReplacer(r, nil, "")
|
||||
actualHeaders := w.Header()
|
||||
|
||||
headerKey := "Merge-Me"
|
||||
values, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Downstream response does not contain expected %v header. Expected header should be added", headerKey)
|
||||
} else if len(values) < 2 && (values[0] != "Initial" || values[1] != replacer.Replace("{hostname}")) {
|
||||
t.Errorf("Values for header `+Merge-Me` should be merged. Got %v", values)
|
||||
}
|
||||
|
||||
headerKey = "Add-Me"
|
||||
if _, ok := actualHeaders[headerKey]; !ok {
|
||||
t.Errorf("Downstream response does not contain expected %v header", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Remove-Me"
|
||||
if _, ok := actualHeaders[headerKey]; ok {
|
||||
t.Errorf("Downstream response should not contain %v header received from upstream", headerKey)
|
||||
}
|
||||
|
||||
headerKey = "Replace-Me"
|
||||
headerValue := replacer.Replace("{hostname}")
|
||||
value, ok := actualHeaders[headerKey]
|
||||
if !ok {
|
||||
t.Errorf("Downstream response should contain %v header and not remove it", headerKey)
|
||||
} else if len(value) > 0 && headerValue != value[0] {
|
||||
t.Errorf("Downstream response should have header %v with value %v. Instead value was %v", headerKey, headerValue, value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
|
@ -410,7 +545,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
|
|||
return &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
|
||||
ExtraHeaders: http.Header{
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
}
|
||||
|
|
|
@ -154,7 +154,9 @@ var InsecureTransport http.RoundTripper = &http.Transport{
|
|||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) error {
|
||||
type respUpdateFn func(resp *http.Response)
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
|
@ -169,6 +171,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request) e
|
|||
res, err := transport.RoundTrip(outreq)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if respUpdateFn != nil {
|
||||
respUpdateFn(res)
|
||||
}
|
||||
|
||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
||||
|
|
|
@ -20,7 +20,8 @@ var (
|
|||
|
||||
type staticUpstream struct {
|
||||
from string
|
||||
proxyHeaders http.Header
|
||||
upstreamHeaders http.Header
|
||||
downstreamHeaders http.Header
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
insecureSkipVerify bool
|
||||
|
@ -43,7 +44,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
|
|||
for c.Next() {
|
||||
upstream := &staticUpstream{
|
||||
from: "",
|
||||
proxyHeaders: make(http.Header),
|
||||
upstreamHeaders: make(http.Header),
|
||||
downstreamHeaders: make(http.Header),
|
||||
Hosts: nil,
|
||||
Policy: &Random{},
|
||||
FailTimeout: 10 * time.Second,
|
||||
|
@ -102,7 +104,8 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
|||
Fails: 0,
|
||||
FailTimeout: u.FailTimeout,
|
||||
Unhealthy: false,
|
||||
ExtraHeaders: u.proxyHeaders,
|
||||
UpstreamHeaders: u.upstreamHeaders,
|
||||
DownstreamHeaders: u.downstreamHeaders,
|
||||
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
|
||||
return func(uh *UpstreamHost) bool {
|
||||
if uh.Unhealthy {
|
||||
|
@ -182,15 +185,23 @@ func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
|
|||
}
|
||||
u.HealthCheck.Interval = dur
|
||||
}
|
||||
case "header_upstream":
|
||||
fallthrough
|
||||
case "proxy_header":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.proxyHeaders.Add(header, value)
|
||||
u.upstreamHeaders.Add(header, value)
|
||||
case "header_downstream":
|
||||
var header, value string
|
||||
if !c.Args(&header, &value) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.downstreamHeaders.Add(header, value)
|
||||
case "websocket":
|
||||
u.proxyHeaders.Add("Connection", "{>Connection}")
|
||||
u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
|
||||
u.upstreamHeaders.Add("Connection", "{>Connection}")
|
||||
u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
|
||||
case "without":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
|
|
Loading…
Reference in New Issue
Block a user