diff --git a/config/setup/proxy.go b/config/setup/proxy.go index b993763e9..42aebf9d7 100644 --- a/config/setup/proxy.go +++ b/config/setup/proxy.go @@ -1,19 +1,13 @@ package setup import ( - "net/http" - "net/url" - "strconv" - "strings" - "time" - "github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware/proxy" ) // Proxy configures a new Proxy middleware instance. func Proxy(c *Controller) (middleware.Middleware, error) { - if upstreams, err := newStaticUpstreams(c); err == nil { + if upstreams, err := proxy.NewStaticUpstreams(c.Dispenser); err == nil { return func(next middleware.Handler) middleware.Handler { return proxy.Proxy{Next: next, Upstreams: upstreams} }, nil @@ -21,125 +15,3 @@ func Proxy(c *Controller) (middleware.Middleware, error) { return nil, err } } - -// newStaticUpstreams parses the configuration input and sets up -// static upstreams for the proxy middleware. -func newStaticUpstreams(c *Controller) ([]proxy.Upstream, error) { - var upstreams []proxy.Upstream - - for c.Next() { - upstream := &proxy.StaticUpstream{ - From: "", - Hosts: nil, - Policy: &proxy.Random{}, - FailTimeout: 10 * time.Second, - MaxFails: 1, - } - var proxyHeaders http.Header - if !c.Args(&upstream.From) { - return upstreams, c.ArgErr() - } - to := c.RemainingArgs() - if len(to) == 0 { - return upstreams, c.ArgErr() - } - - for c.NextBlock() { - switch c.Val() { - case "policy": - if !c.NextArg() { - return upstreams, c.ArgErr() - } - switch c.Val() { - case "random": - upstream.Policy = &proxy.Random{} - case "round_robin": - upstream.Policy = &proxy.RoundRobin{} - case "least_conn": - upstream.Policy = &proxy.LeastConn{} - default: - return upstreams, c.ArgErr() - } - case "fail_timeout": - if !c.NextArg() { - return upstreams, c.ArgErr() - } - if dur, err := time.ParseDuration(c.Val()); err == nil { - upstream.FailTimeout = dur - } else { - return upstreams, err - } - case "max_fails": - if !c.NextArg() { - return upstreams, c.ArgErr() - } - if n, err := strconv.Atoi(c.Val()); err == nil { - upstream.MaxFails = int32(n) - } else { - return upstreams, err - } - case "health_check": - if !c.NextArg() { - return upstreams, c.ArgErr() - } - upstream.HealthCheck.Path = c.Val() - upstream.HealthCheck.Interval = 30 * time.Second - if c.NextArg() { - if dur, err := time.ParseDuration(c.Val()); err == nil { - upstream.HealthCheck.Interval = dur - } else { - return upstreams, err - } - } - case "proxy_header": - var header, value string - if !c.Args(&header, &value) { - return upstreams, c.ArgErr() - } - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } - proxyHeaders.Add(header, value) - } - } - - upstream.Hosts = make([]*proxy.UpstreamHost, len(to)) - for i, host := range to { - if !strings.HasPrefix(host, "http") { - host = "http://" + host - } - uh := &proxy.UpstreamHost{ - Name: host, - Conns: 0, - Fails: 0, - FailTimeout: upstream.FailTimeout, - Unhealthy: false, - ExtraHeaders: proxyHeaders, - CheckDown: func(upstream *proxy.StaticUpstream) proxy.UpstreamHostDownFunc { - return func(uh *proxy.UpstreamHost) bool { - if uh.Unhealthy { - return true - } - if uh.Fails >= upstream.MaxFails && - upstream.MaxFails != 0 { - return true - } - return false - } - }(upstream), - } - if baseUrl, err := url.Parse(uh.Name); err == nil { - uh.ReverseProxy = proxy.NewSingleHostReverseProxy(baseUrl) - } else { - return upstreams, err - } - upstream.Hosts[i] = uh - } - - if upstream.HealthCheck.Path != "" { - go upstream.HealthCheckWorker(nil) - } - upstreams = append(upstreams, upstream) - } - return upstreams, nil -} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index e04f422de..bf36b1d85 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -23,7 +23,7 @@ type Proxy struct { // suitable upstream host, or nil if no such hosts are available. type Upstream interface { //The path this upstream host should be routed on - from() string + From() string // Selects an upstream host to be routed to. Select() *UpstreamHost } @@ -55,7 +55,7 @@ func (uh *UpstreamHost) Down() bool { func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { for _, upstream := range p.Upstreams { - if middleware.Path(r.URL.Path).Matches(upstream.from()) { + if middleware.Path(r.URL.Path).Matches(upstream.From()) { var replacer middleware.Replacer start := time.Now() requestHost := r.Host diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 55da12db1..b49beaa0c 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -1,14 +1,18 @@ package proxy import ( + "github.com/mholt/caddy/config/parse" "io" "io/ioutil" "net/http" + "net/url" + "strconv" + "strings" "time" ) -type StaticUpstream struct { - From string +type staticUpstream struct { + from string Hosts HostPool Policy Policy @@ -20,11 +24,133 @@ type StaticUpstream struct { } } -func (u *StaticUpstream) from() string { - return u.From +// newStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + + for c.Next() { + upstream := &staticUpstream{ + from: "", + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + var proxyHeaders http.Header + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + switch c.Val() { + case "policy": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + switch c.Val() { + case "random": + upstream.Policy = &Random{} + case "round_robin": + upstream.Policy = &RoundRobin{} + case "least_conn": + upstream.Policy = &LeastConn{} + default: + return upstreams, c.ArgErr() + } + case "fail_timeout": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + if dur, err := time.ParseDuration(c.Val()); err == nil { + upstream.FailTimeout = dur + } else { + return upstreams, err + } + case "max_fails": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + if n, err := strconv.Atoi(c.Val()); err == nil { + upstream.MaxFails = int32(n) + } else { + return upstreams, err + } + case "health_check": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + upstream.HealthCheck.Path = c.Val() + upstream.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + if dur, err := time.ParseDuration(c.Val()); err == nil { + upstream.HealthCheck.Interval = dur + } else { + return upstreams, err + } + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return upstreams, c.ArgErr() + } + if proxyHeaders == nil { + proxyHeaders = make(map[string][]string) + } + proxyHeaders.Add(header, value) + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + if !strings.HasPrefix(host, "http") { + host = "http://" + host + } + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + } + if baseUrl, err := url.Parse(uh.Name); err == nil { + uh.ReverseProxy = NewSingleHostReverseProxy(baseUrl) + } else { + return upstreams, err + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil } -func (u *StaticUpstream) healthCheck() { +func (u *staticUpstream) From() string { + return u.from +} + +func (u *staticUpstream) healthCheck() { for _, host := range u.Hosts { hostUrl := host.Name + u.HealthCheck.Path if r, err := http.Get(hostUrl); err == nil { @@ -37,7 +163,7 @@ func (u *StaticUpstream) healthCheck() { } } -func (u *StaticUpstream) HealthCheckWorker(stop chan struct{}) { +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { ticker := time.NewTicker(u.HealthCheck.Interval) u.healthCheck() for { @@ -52,7 +178,7 @@ func (u *StaticUpstream) HealthCheckWorker(stop chan struct{}) { } } -func (u *StaticUpstream) Select() *UpstreamHost { +func (u *staticUpstream) Select() *UpstreamHost { pool := u.Hosts if len(pool) == 1 { if pool[0].Down() {