caddy/caddyhttp/proxy/proxy.go
Angel Santiago 59bf71c293 proxy: Cleanly shutdown health checks on restart (#1524)
* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts.

* Make fakeUpstream implement new Upstream methods.

Implement new Upstream method for fakeWSUpstream as well.

* Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called.

* Go back to using a stop channel since the context adds no additional benefit.
Only register stop function for onShutdown since it's called as part of restart.

* Remove assignment to atomic value

* Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test.

* Linting: change counter to just use the default zero value instead of setting it

* Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
2017-04-02 14:58:15 -06:00

350 lines
11 KiB
Go

// Package proxy is middleware that proxies HTTP requests.
package proxy
import (
"errors"
"net"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Proxy represents a middleware instance that can proxy requests.
type Proxy struct {
	// Next is the handler the request is passed to when none of the
	// Upstreams matches it.
	Next httpserver.Handler
	// Upstreams are the candidate upstream configurations; for each request
	// the matching one with the longest From() path is used.
	Upstreams []Upstream
}
// Upstream manages a pool of proxy upstream hosts.
type Upstream interface {
	// From returns the base path this upstream is routed on. Among the
	// upstreams matching a request, the one with the longest From wins.
	From() string
	// Select picks an upstream host to route the given request to. It
	// should return a suitable upstream host, or nil if no such host is
	// available.
	Select(*http.Request) *UpstreamHost
	// AllowedPath reports whether the given request path is allowed, i.e.
	// is not one of this upstream's ignored paths.
	AllowedPath(string) bool
	// GetTryDuration returns how long to keep trying to select upstream
	// hosts in the case of cascading failures; 0 means no retries.
	GetTryDuration() time.Duration
	// GetTryInterval returns how long to wait between selecting upstream
	// hosts in the case of cascading failures.
	GetTryInterval() time.Duration
	// GetHostCount returns the number of upstream hosts. With more than one
	// host and retries enabled, request bodies are buffered for replay.
	GetHostCount() int
	// Stop stops the upstream from proxying requests so that its
	// goroutines can shut down cleanly.
	Stop() error
}
// UpstreamHostDownFunc can be used to customize how Down behaves. When set
// as an UpstreamHost's CheckDown, it replaces the default check entirely
// and should return true if the given host is to be considered down.
type UpstreamHostDownFunc func(*UpstreamHost) bool
// UpstreamHost represents a single proxy upstream host together with its
// per-host proxying configuration and runtime counters.
type UpstreamHost struct {
	// Conns counts requests currently being proxied to this host.
	// This field is read & written to concurrently, so all access must use
	// atomic operations.
	Conns int64 // must be first field to be 64-bit aligned on 32-bit systems
	// MaxConns caps Conns (see Full); a value <= 0 means no limit.
	MaxConns int64
	Name string // hostname (possibly a full URL) of this upstream host
	// UpstreamHeaders holds header rules applied to requests sent to this
	// host: "+Name" adds values, "-Name" deletes, "Name" sets.
	UpstreamHeaders http.Header
	// DownstreamHeaders holds header rules, in the same format as
	// UpstreamHeaders, applied to responses coming back from this host.
	DownstreamHeaders http.Header
	// FailTimeout is how long each failed request counts in Fails;
	// a value <= 0 disables failure counting.
	FailTimeout time.Duration
	// CheckDown, if non-nil, replaces the default logic of Down.
	CheckDown UpstreamHostDownFunc
	// WithoutPathPrefix is passed to NewSingleHostReverseProxy when a proxy
	// is built on the fly; presumably a path prefix to strip before
	// proxying — NOTE(review): confirm in NewSingleHostReverseProxy.
	WithoutPathPrefix string
	// ReverseProxy serves requests to this host; when nil, ServeHTTP builds
	// one from Name for each attempt.
	ReverseProxy *ReverseProxy
	// Fails counts recent failures (see FailTimeout). It is accessed with
	// atomic operations.
	Fails int32
	// This is an int32 so that we can use atomic operations to do concurrent
	// reads & writes to this value. The default value of 0 indicates that it
	// is healthy and any non-zero value indicates unhealthy.
	Unhealthy int32
}
// Down reports whether the upstream host should be treated as down.
//
// A custom uh.CheckDown, when set, makes the decision alone. Otherwise the
// host is down if it has been flagged unhealthy or currently has any
// recorded failures; both counters are read atomically.
func (uh *UpstreamHost) Down() bool {
	if custom := uh.CheckDown; custom != nil {
		return custom(uh)
	}
	if atomic.LoadInt32(&uh.Unhealthy) != 0 {
		return true
	}
	return atomic.LoadInt32(&uh.Fails) > 0
}
// Full reports whether the upstream host has reached its connection cap.
// A MaxConns of zero or less means the host is unlimited and never full.
func (uh *UpstreamHost) Full() bool {
	if uh.MaxConns <= 0 {
		return false
	}
	inFlight := atomic.LoadInt64(&uh.Conns)
	return inFlight >= uh.MaxConns
}
// Available reports whether a request may be proxied to this host right
// now: the host must be neither down nor at its connection limit.
func (uh *UpstreamHost) Available() bool {
	if uh.Down() {
		return false
	}
	return !uh.Full()
}
// ServeHTTP satisfies the httpserver.Handler interface.
//
// It proxies r to a host of the most specific matching upstream, or hands
// the request to p.Next when no upstream matches. When an attempt fails and
// retrying is enabled (a non-zero try duration), further hosts are tried,
// waiting the upstream's try interval between attempts, until one succeeds
// or the try duration has elapsed.
//
// It returns (0, nil) once a backend has served the request, 413 when the
// request body exceeded its size limit, 400/500 for local failures, and
// otherwise 502 together with the last backend error.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
	// start by selecting most specific matching upstream config
	upstream := p.match(r)
	if upstream == nil {
		return p.Next.ServeHTTP(w, r)
	}

	// this replacer is used to fill in header field values
	replacer := httpserver.NewReplacer(r, nil, "")

	// outreq is the request that makes a roundtrip to the backend
	outreq := createUpstreamRequest(r)

	// baseHeader is the header set every attempt starts from. Each attempt
	// works on its own deep copy, so per-host modifications (upstream
	// basic-auth credentials, host.UpstreamHeaders rules) neither pile up
	// across retries nor write into r.Header if the map is shared with it.
	baseHeader := outreq.Header

	// If we have more than one upstream host defined and if retrying is enabled
	// by setting try_duration to a non-zero value, caddy will try to
	// retry the request at a different host if the first one failed.
	//
	// This requires us to possibly rewind and replay the request body though,
	// which in turn requires us to buffer the request body first.
	//
	// An unbuffered request is usually preferable, because it reduces latency
	// as well as memory usage. Furthermore it enables different kinds of
	// HTTP streaming applications like gRPC for instance.
	requiresBuffering := upstream.GetHostCount() > 1 && upstream.GetTryDuration() != 0
	if requiresBuffering {
		body, err := newBufferedBody(outreq.Body)
		if err != nil {
			return http.StatusBadRequest, errors.New("failed to read downstream request body")
		}
		if body != nil {
			outreq.Body = body
		}
	}

	// The keepRetrying function will return true if we should
	// loop and try to select another host, or false if we
	// should break and stop retrying.
	start := time.Now()
	keepRetrying := func() bool {
		// if we've tried long enough, break
		if time.Since(start) >= upstream.GetTryDuration() {
			return false
		}
		// otherwise, wait and try the next available host
		time.Sleep(upstream.GetTryInterval())
		return true
	}

	var backendErr error
	for {
		// since Select() should give us "up" hosts, keep retrying
		// hosts until timeout (or until we get a nil host).
		host := upstream.Select(r)
		if host == nil {
			if backendErr == nil {
				backendErr = errors.New("no hosts available upstream")
			}
			if !keepRetrying() {
				break
			}
			continue
		}

		// Start this attempt from a pristine copy of the base headers so
		// nothing set for a previous host (e.g. its Authorization) carries
		// over to this one.
		outreq.Header = make(http.Header, len(baseHeader))
		for name, values := range baseHeader {
			outreq.Header[name] = append([]string(nil), values...)
		}

		if rr, ok := w.(*httpserver.ResponseRecorder); ok && rr.Replacer != nil {
			rr.Replacer.Set("upstream", host.Name)
		}

		proxy := host.ReverseProxy

		// a backend's name may contain more than just the host,
		// so we parse it as a URL to try to isolate the host.
		if nameURL, err := url.Parse(host.Name); err == nil {
			outreq.Host = nameURL.Host
			if proxy == nil {
				proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
			}

			// use upstream credentials by default; only the client's own
			// Authorization header (from baseHeader) takes precedence
			if outreq.Header.Get("Authorization") == "" && nameURL.User != nil {
				pwd, _ := nameURL.User.Password()
				outreq.SetBasicAuth(nameURL.User.Username(), pwd)
			}
		} else {
			outreq.Host = host.Name
		}
		if proxy == nil {
			return http.StatusInternalServerError, errors.New("proxy for host '" + host.Name + "' is nil")
		}

		// set headers for request going upstream
		if host.UpstreamHeaders != nil {
			// modify headers for request that will be sent to the upstream host
			mutateHeadersByRules(outreq.Header, host.UpstreamHeaders, replacer)
			if hostHeaders, ok := outreq.Header["Host"]; ok && len(hostHeaders) > 0 {
				outreq.Host = hostHeaders[len(hostHeaders)-1]
			}
		}

		// prepare a function that will update response
		// headers coming back downstream
		var downHeaderUpdateFn respUpdateFn
		if host.DownstreamHeaders != nil {
			downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
		}

		// Before we retry the request we have to make sure
		// that the body is rewound to its beginning.
		if bb, ok := outreq.Body.(*bufferedBody); ok {
			if err := bb.rewind(); err != nil {
				return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
			}
		}

		// tell the proxy to serve the request
		//
		// NOTE:
		// The call to proxy.ServeHTTP can theoretically panic.
		// To prevent host.Conns from getting out-of-sync we thus have to
		// make sure that it's _always_ correctly decremented afterwards.
		func() {
			atomic.AddInt64(&host.Conns, 1)
			defer atomic.AddInt64(&host.Conns, -1)
			backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
		}()

		// if no errors, we're done here
		if backendErr == nil {
			return 0, nil
		}

		if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
			return http.StatusRequestEntityTooLarge, backendErr
		}

		// failover; remember this failure for some time if
		// request failure counting is enabled
		timeout := host.FailTimeout
		if timeout > 0 {
			atomic.AddInt32(&host.Fails, 1)
			go func(host *UpstreamHost, timeout time.Duration) {
				time.Sleep(timeout)
				atomic.AddInt32(&host.Fails, -1)
			}(host, timeout)
		}

		// if we've tried long enough, break
		if !keepRetrying() {
			break
		}
	}

	return http.StatusBadGateway, backendErr
}
// match returns the upstream config that best fits r: among the upstreams
// whose From() path the request path Matches and which allow that path,
// the one with the longest From() wins (the earliest on a tie). It returns
// nil if none qualify.
func (p Proxy) match(r *http.Request) Upstream {
	var (
		best    Upstream
		bestLen int
	)
	reqPath := httpserver.Path(r.URL.Path)
	for _, candidate := range p.Upstreams {
		base := candidate.From()
		eligible := reqPath.Matches(base) && candidate.AllowedPath(r.URL.Path)
		if eligible && len(base) > bestLen {
			best, bestLen = candidate, len(base)
		}
	}
	return best
}
// createUpstreamRequest shallow-copies r into a new request
// that can be sent upstream.
//
// The copy always gets its own deep-copied header map, so edits made here
// (hop-by-hop removal, X-Forwarded-For) and by the caller never modify
// r.Header. Its Body is nil when r has no content, hop-by-hop headers are
// stripped, and the client IP is appended to X-Forwarded-For.
//
// Derived from reverseproxy.go in the standard Go httputil package.
func createUpstreamRequest(r *http.Request) *http.Request {
	outreq := new(http.Request)
	*outreq = *r // shallow copy; the header map is replaced just below

	// We should set body to nil explicitly if request body is empty.
	// For server requests the Request Body is always non-nil.
	if r.ContentLength == 0 {
		outreq.Body = nil
	}

	// Give the outgoing request its own headers. If the map were shared
	// with r, the X-Forwarded-For update below (and any header rewriting
	// done by the caller) would silently change the downstream request.
	outreq.Header = make(http.Header, len(r.Header))
	copyHeader(outreq.Header, r.Header)

	// Remove hop-by-hop headers listed in the "Connection" header, across
	// every Connection line the client sent (it is a list-valued header).
	// See RFC 7230, section 6.1.
	for _, line := range outreq.Header["Connection"] {
		for _, f := range strings.Split(line, ",") {
			if f = strings.TrimSpace(f); f != "" {
				outreq.Header.Del(f)
			}
		}
	}

	// Remove hop-by-hop headers to the backend. Especially
	// important is "Connection" because we want a persistent
	// connection, regardless of what the client sent to us.
	for _, h := range hopHeaders {
		outreq.Header.Del(h)
	}

	if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		// If we aren't the first proxy, retain prior
		// X-Forwarded-For information as a comma+space
		// separated list and fold multiple headers into one.
		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		outreq.Header.Set("X-Forwarded-For", clientIP)
	}

	return outreq
}
// createRespHeaderUpdateFn returns a respUpdateFn that applies the given
// header rules (see mutateHeadersByRules) to a backend response's headers,
// passing rule values through replacer.
func createRespHeaderUpdateFn(rules http.Header, replacer httpserver.Replacer) respUpdateFn {
	applyRules := func(resp *http.Response) {
		mutateHeadersByRules(resp.Header, rules, replacer)
	}
	return applyRules
}
// mutateHeadersByRules applies header rules to headers in place. The rule
// key selects the operation:
//
//	"+Name"  add each rule value (after repl.Replace) that is non-empty
//	"-Name"  delete Name entirely
//	"Name"   set Name to the last rule value (after repl.Replace) if non-empty
//
// NOTE(review): rules is a map, so when several rules touch the same header
// the order in which they are applied is unspecified.
func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer) {
	for field, values := range rules {
		switch {
		case strings.HasPrefix(field, "+"):
			name := strings.TrimPrefix(field, "+")
			for _, raw := range values {
				if expanded := repl.Replace(raw); expanded != "" {
					headers.Add(name, expanded)
				}
			}
		case strings.HasPrefix(field, "-"):
			headers.Del(strings.TrimPrefix(field, "-"))
		case len(values) > 0:
			if expanded := repl.Replace(values[len(values)-1]); expanded != "" {
				headers.Set(field, expanded)
			}
		}
	}
}