diff --git a/caddy.go b/caddy.go index 032d8721c..8da6d4db8 100644 --- a/caddy.go +++ b/caddy.go @@ -518,6 +518,11 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r } if !Quiet { for _, srvln := range inst.servers { + // only show FD notice if the listener is not nil. + // This can happen when only serving UDP or TCP + if srvln.listener == nil { + continue + } if !IsLoopback(srvln.listener.Addr().String()) { checkFdlimit() break diff --git a/caddyhttp/httpserver/middleware.go b/caddyhttp/httpserver/middleware.go index a754e77ce..d470811a9 100644 --- a/caddyhttp/httpserver/middleware.go +++ b/caddyhttp/httpserver/middleware.go @@ -214,6 +214,9 @@ func SameNext(next1, next2 Handler) bool { // Context key constants. const ( + // ReplacerCtxKey is the context key for a per-request replacer. + ReplacerCtxKey caddy.CtxKey = "replacer" + // RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth). RemoteUserCtxKey caddy.CtxKey = "remote_user" diff --git a/caddyhttp/httpserver/replacer.go b/caddyhttp/httpserver/replacer.go index 629067b69..371526bf6 100644 --- a/caddyhttp/httpserver/replacer.go +++ b/caddyhttp/httpserver/replacer.go @@ -102,20 +102,30 @@ func (lw *limitWriter) String() string { // emptyValue should be the string that is used in place // of empty string (can still be empty string). func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer { - rb := newLimitWriter(MaxLogBodySize) - if r.Body != nil { - r.Body = struct { - io.Reader - io.Closer - }{io.TeeReader(r.Body, rb), io.Closer(r.Body)} + repl := &replacer{ + request: r, + responseRecorder: rr, + emptyValue: emptyValue, } - return &replacer{ - request: r, - requestBody: rb, - responseRecorder: rr, - customReplacements: make(map[string]string), - emptyValue: emptyValue, + + // extract customReplacements from a request replacer when present. + if existing, ok := r.Context().Value(ReplacerCtxKey).(*replacer); ok { + repl.requestBody = existing.requestBody + repl.customReplacements = existing.customReplacements + } else { + // if there is no existing replacer, build one from scratch. + rb := newLimitWriter(MaxLogBodySize) + if r.Body != nil { + r.Body = struct { + io.Reader + io.Closer + }{io.TeeReader(r.Body, rb), io.Closer(r.Body)} + } + repl.requestBody = rb + repl.customReplacements = make(map[string]string) } + + return repl } func canLogRequest(r *http.Request) bool { diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index b0f85a056..92f2b6fd7 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -356,6 +356,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy) r = r.WithContext(c) + // Setup a replacer for the request that keeps track of placeholder + // values across plugins. + replacer := NewReplacer(r, nil, "") + c = context.WithValue(r.Context(), ReplacerCtxKey, replacer) + r = r.WithContext(c) + w.Header().Set("Server", caddy.AppName) status, _ := s.serveHTTP(w, r) diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 3d88af213..e66ad7766 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -82,7 +82,8 @@ type UpstreamHost struct { // This is an int32 so that we can use atomic operations to do concurrent // reads & writes to this value. The default value of 0 indicates that it // is healthy and any non-zero value indicates unhealthy. - Unhealthy int32 + Unhealthy int32 + HealthCheckResult atomic.Value } // Down checks whether the upstream host is down or not. diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 78fbb6bdd..2fac5aabc 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -26,7 +26,9 @@ package proxy import ( + "context" "crypto/tls" + "fmt" "io" "net" "net/http" @@ -91,6 +93,8 @@ type ReverseProxy struct { // response body. // If zero, no periodic flushing is done. FlushInterval time.Duration + + srvResolver srvResolver } // Though the relevant directive prefix is just "unix:", url.Parse @@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err } } +func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) { + service := locator + if strings.HasPrefix(locator, "srv://") { + service = locator[6:] + } else if strings.HasPrefix(locator, "srv+https://") { + service = locator[12:] + } + + return func(network, addr string) (conn net.Conn, err error) { + _, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service) + if err != nil { + return nil, err + } + return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port)) + } +} + func singleJoiningSlash(a, b string) string { aslash := strings.HasSuffix(a, "/") bslash := strings.HasPrefix(b, "/") @@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * // scheme and host have to be faked req.URL.Scheme = "http" req.URL.Host = "socket" + } else if target.Scheme == "srv" { + req.URL.Scheme = "http" + req.URL.Host = target.Host + } else if target.Scheme == "srv+https" { + req.URL.Scheme = "https" + req.URL.Host = target.Host } else { req.URL.Scheme = target.Scheme req.URL.Host = target.Host @@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * } } - rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events + rp := &ReverseProxy{ + Director: director, + FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events + srvResolver: net.DefaultResolver, + } + if target.Scheme == "unix" { rp.Transport = &http.Transport{ Dial: socketDial(target.String()), @@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * HandshakeTimeout: defaultCryptoHandshakeTimeout, }, } - } else if keepalive != http.DefaultMaxIdleConnsPerHost { - // if keepalive is equal to the default, - // just use default transport, to avoid creating - // a brand new transport + } else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") { + dialFunc := defaultDialer.Dial + if strings.HasPrefix(target.Scheme, "srv") { + dialFunc = rp.srvDialerFunc(target.String()) + } + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: defaultDialer.Dial, + Dial: dialFunc, TLSHandshakeTimeout: defaultCryptoHandshakeTimeout, ExpectContinueTimeout: 1 * time.Second, } diff --git a/caddyhttp/proxy/reverseproxy_test.go b/caddyhttp/proxy/reverseproxy_test.go new file mode 100644 index 000000000..2d1d80df4 --- /dev/null +++ b/caddyhttp/proxy/reverseproxy_test.go @@ -0,0 +1,94 @@ +// Copyright 2015 Light Code Labs, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" +) + +const ( + expectedResponse = "response from request proxied to upstream" + expectedStatus = http.StatusOK +) + +var upstreamHost *httptest.Server + +func setupTest() { + upstreamHost = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test-path" { + w.WriteHeader(expectedStatus) + w.Write([]byte(expectedResponse)) + } else { + w.WriteHeader(404) + w.Write([]byte("Not found")) + } + })) +} + +func tearDownTest() { + upstreamHost.Close() +} + +func TestSingleSRVHostReverseProxy(t *testing.T) { + setupTest() + defer tearDownTest() + + target, err := url.Parse("srv://test.upstream.service") + if err != nil { + t.Errorf("Failed to parse target URL. %s", err.Error()) + } + + upstream, err := url.Parse(upstreamHost.URL) + if err != nil { + t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error()) + } + pp, err := strconv.Atoi(upstream.Port()) + if err != nil { + t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error()) + } + port := uint16(pp) + + rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost) + rp.srvResolver = testResolver{ + result: []*net.SRV{ + {Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1}, + }, + } + + resp := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://test.host/test-path", nil) + if err != nil { + t.Errorf("Failed to create new request. %s", err.Error()) + } + + err = rp.ServeHTTP(resp, req, nil) + if err != nil { + t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error()) + } + + if resp.Body.String() != expectedResponse { + t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String()) + } + + if resp.Code != expectedStatus { + t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code) + } +} diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index bab8b462c..ae15a6dcb 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -16,6 +16,7 @@ package proxy import ( "bytes" + "context" "fmt" "io" "io/ioutil" @@ -65,6 +66,11 @@ type staticUpstream struct { IgnoredSubPaths []string insecureSkipVerify bool MaxFails int32 + resolver srvResolver +} + +type srvResolver interface { + LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) } // NewStaticUpstreams parses the configuration input and sets up @@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) TryInterval: 250 * time.Millisecond, MaxConns: 0, KeepAlive: http.DefaultMaxIdleConnsPerHost, + resolver: net.DefaultResolver, } if !c.Args(&upstream.from) { @@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) } var to []string + hasSrv := false + for _, t := range c.RemainingArgs() { + if len(to) > 0 && hasSrv { + return upstreams, c.Err("only one upstream is supported when using SRV locator") + } + + if strings.HasPrefix(t, "srv://") || strings.HasPrefix(t, "srv+https://") { + if len(to) > 0 { + return upstreams, c.Err("service locator upstreams can not be mixed with host names") + } + + hasSrv = true + } + parsed, err := parseUpstream(t) if err != nil { return upstreams, err @@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error) if !c.NextArg() { return upstreams, c.ArgErr() } + + if hasSrv { + return upstreams, c.Err("upstream directive is not supported when backend is service locator") + } + parsed, err := parseUpstream(c.Val()) if err != nil { return upstreams, err } to = append(to, parsed...) default: - if err := parseBlock(&c, upstream); err != nil { + if err := parseBlock(&c, upstream, hasSrv); err != nil { return upstreams, err } } @@ -165,7 +191,9 @@ func (u *staticUpstream) From() string { func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { if !strings.HasPrefix(host, "http") && !strings.HasPrefix(host, "unix:") && - !strings.HasPrefix(host, "quic:") { + !strings.HasPrefix(host, "quic:") && + !strings.HasPrefix(host, "srv://") && + !strings.HasPrefix(host, "srv+https://") { host = "http://" + host } uh := &UpstreamHost{ @@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { }(u), WithoutPathPrefix: u.WithoutPathPrefix, MaxConns: u.MaxConns, + HealthCheckResult: atomic.Value{}, } baseURL, err := url.Parse(uh.Name) @@ -205,50 +234,65 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { } func parseUpstream(u string) ([]string, error) { - if !strings.HasPrefix(u, "unix:") { - colonIdx := strings.LastIndex(u, ":") - protoIdx := strings.Index(u, "://") - - if colonIdx != -1 && colonIdx != protoIdx { - us := u[:colonIdx] - ue := "" - portsEnd := len(u) - if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 { - portsEnd = colonIdx + nextSlash - ue = u[portsEnd:] - } - ports := u[len(us)+1 : portsEnd] - - if separators := strings.Count(ports, "-"); separators == 1 { - portsStr := strings.Split(ports, "-") - pIni, err := strconv.Atoi(portsStr[0]) - if err != nil { - return nil, err - } - - pEnd, err := strconv.Atoi(portsStr[1]) - if err != nil { - return nil, err - } - - if pEnd <= pIni { - return nil, fmt.Errorf("port range [%s] is invalid", ports) - } - - hosts := []string{} - for p := pIni; p <= pEnd; p++ { - hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue)) - } - return hosts, nil - } - } + if strings.HasPrefix(u, "unix:") { + return []string{u}, nil } - return []string{u}, nil + isSrv := strings.HasPrefix(u, "srv://") || strings.HasPrefix(u, "srv+https://") + colonIdx := strings.LastIndex(u, ":") + protoIdx := strings.Index(u, "://") + if colonIdx == -1 || colonIdx == protoIdx { + return []string{u}, nil + } + + if isSrv { + return nil, fmt.Errorf("service locator %s can not have port specified", u) + } + + us := u[:colonIdx] + ue := "" + portsEnd := len(u) + if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 { + portsEnd = colonIdx + nextSlash + ue = u[portsEnd:] + } + + ports := u[len(us)+1 : portsEnd] + separators := strings.Count(ports, "-") + + if separators == 0 { + return []string{u}, nil + } + + if separators > 1 { + return nil, fmt.Errorf("port range [%s] has %d separators", ports, separators) + } + + portsStr := strings.Split(ports, "-") + pIni, err := strconv.Atoi(portsStr[0]) + if err != nil { + return nil, err + } + + pEnd, err := strconv.Atoi(portsStr[1]) + if err != nil { + return nil, err + } + + if pEnd <= pIni { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } + + hosts := []string{} + for p := pIni; p <= pEnd; p++ { + hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue)) + } + + return hosts, nil } -func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { +func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error { switch c.Val() { case "policy": if !c.NextArg() { @@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { if !c.NextArg() { return c.ArgErr() } + + if hasSrv { + return c.Err("health_check_port directive is not allowed when upstream is SRV locator") + } + port := c.Val() n, err := strconv.Atoi(port) if err != nil { @@ -420,54 +469,94 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error { return nil } +func (u *staticUpstream) resolveHost(h string) ([]string, bool, error) { + names := []string{} + proto := "http" + if !strings.HasPrefix(h, "srv://") && !strings.HasPrefix(h, "srv+https://") { + return []string{h}, false, nil + } + + if strings.HasPrefix(h, "srv+https://") { + proto = "https" + } + + _, addrs, err := u.resolver.LookupSRV(context.Background(), "", "", h) + if err != nil { + return names, true, err + } + + for _, addr := range addrs { + names = append(names, fmt.Sprintf("%s://%s:%d", proto, addr.Target, addr.Port)) + } + + return names, true, nil +} + func (u *staticUpstream) healthCheck() { for _, host := range u.Hosts { - hostURL := host.Name - if u.HealthCheck.Port != "" { - hostURL = replacePort(host.Name, u.HealthCheck.Port) - } - hostURL += u.HealthCheck.Path - - unhealthy := func() bool { - // set up request, needed to be able to modify headers - // possible errors are bad HTTP methods or un-parsable urls - req, err := http.NewRequest("GET", hostURL, nil) - if err != nil { - return true - } - // set host for request going upstream - if u.HealthCheck.Host != "" { - req.Host = u.HealthCheck.Host - } - r, err := u.HealthCheck.Client.Do(req) - if err != nil { - return true - } - defer func() { - io.Copy(ioutil.Discard, r.Body) - r.Body.Close() - }() - if r.StatusCode < 200 || r.StatusCode >= 400 { - return true - } - if u.HealthCheck.ContentString == "" { // don't check for content string - return false - } - // TODO ReadAll will be replaced if deemed necessary - // See https://github.com/mholt/caddy/pull/1691 - buf, err := ioutil.ReadAll(r.Body) - if err != nil { - return true - } - if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) { - return false - } - return true - }() - if unhealthy { + candidates, isSrv, err := u.resolveHost(host.Name) + if err != nil { + host.HealthCheckResult.Store(err.Error()) atomic.StoreInt32(&host.Unhealthy, 1) + continue + } + + unhealthyCount := 0 + for _, addr := range candidates { + hostURL := addr + if !isSrv && u.HealthCheck.Port != "" { + hostURL = replacePort(hostURL, u.HealthCheck.Port) + } + hostURL += u.HealthCheck.Path + + unhealthy := func() bool { + // set up request, needed to be able to modify headers + // possible errors are bad HTTP methods or un-parsable urls + req, err := http.NewRequest("GET", hostURL, nil) + if err != nil { + return true + } + // set host for request going upstream + if u.HealthCheck.Host != "" { + req.Host = u.HealthCheck.Host + } + r, err := u.HealthCheck.Client.Do(req) + if err != nil { + return true + } + defer func() { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + }() + if r.StatusCode < 200 || r.StatusCode >= 400 { + return true + } + if u.HealthCheck.ContentString == "" { // don't check for content string + return false + } + // TODO ReadAll will be replaced if deemed necessary + // See https://github.com/mholt/caddy/pull/1691 + buf, err := ioutil.ReadAll(r.Body) + if err != nil { + return true + } + if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) { + return false + } + return true + }() + + if unhealthy { + unhealthyCount++ + } + } + + if unhealthyCount == len(candidates) { + atomic.StoreInt32(&host.Unhealthy, 1) + host.HealthCheckResult.Store("Failed") } else { atomic.StoreInt32(&host.Unhealthy, 0) + host.HealthCheckResult.Store("OK") } } } diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index ce662d19c..23fd4831b 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -15,10 +15,15 @@ package proxy import ( + "context" + "errors" "fmt" "net" "net/http" "net/http/httptest" + "net/url" + "reflect" + "strconv" "strings" "sync/atomic" "testing" @@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) { u := staticUpstream{} c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)) for c.Next() { - parseBlock(&c, &u) + parseBlock(&c, &u, false) } if u.HealthCheck.Interval.String() != test.interval { t.Errorf( @@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) { } } } + +func TestParseSRVBlock(t *testing.T) { + tests := []struct { + config string + shouldErr bool + }{ + {"proxy / srv://bogus.service", false}, + {"proxy / srv://bogus.service:80", true}, + {"proxy / srv://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv://bogus.service http://bogus.service.fallback", true}, + {"proxy / http://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv://bogus.service bogus.service.fallback", true}, + {`proxy / srv://bogus.service { + upstream srv://bogus.service + }`, true}, + {"proxy / srv+https://bogus.service", false}, + {"proxy / srv+https://bogus.service:80", true}, + {"proxy / srv+https://bogus.service srv://bogus.service.fallback", true}, + {"proxy / srv+https://bogus.service http://bogus.service.fallback", true}, + {"proxy / http://bogus.service srv+https://bogus.service.fallback", true}, + {"proxy / srv+https://bogus.service bogus.service.fallback", true}, + {`proxy / srv+https://bogus.service { + upstream srv://bogus.service + }`, true}, + {`proxy / srv+https://bogus.service { + health_check_port 96 + }`, true}, + } + + for i, test := range tests { + _, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "") + if err == nil && test.shouldErr { + t.Errorf("Case %d - Expected an error. got nothing", i) + } + + if err != nil && !test.shouldErr { + t.Errorf("Case %d - Expected no error. got %s", i, err.Error()) + } + } +} + +type testResolver struct { + errOn string + result []*net.SRV +} + +func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) { + if service == r.errOn { + return "", nil, errors.New("an error occurred") + } + + return "", r.result, nil +} + +func TestResolveHost(t *testing.T) { + upstream := &staticUpstream{ + resolver: testResolver{ + errOn: "srv://problematic.service.name", + result: []*net.SRV{ + {Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + }, + } + + tests := []struct { + host string + expect []string + isSrv bool + shouldErr bool + }{ + // Static DNS records + {"http://subdomain.domain.service", + []string{"http://subdomain.domain.service"}, + false, + false}, + {"https://subdomain.domain.service", + []string{"https://subdomain.domain.service"}, + false, + false}, + {"http://subdomain.domain.service:76", + []string{"http://subdomain.domain.service:76"}, + false, + false}, + {"https://subdomain.domain.service:65", + []string{"https://subdomain.domain.service:65"}, + false, + false}, + + // SRV lookups + {"srv://service.name", []string{ + "http://target-1.fqdn:85", + "http://target-2.fqdn:33", + "http://target-3.fqdn:94", + }, true, false}, + {"srv+https://service.name", []string{ + "https://target-1.fqdn:85", + "https://target-2.fqdn:33", + "https://target-3.fqdn:94", + }, true, false}, + {"srv://problematic.service.name", []string{}, true, true}, + } + + for i, test := range tests { + results, isSrv, err := upstream.resolveHost(test.host) + if err == nil && test.shouldErr { + t.Errorf("Test %d - expected an error, got none", i) + } + + if err != nil && !test.shouldErr { + t.Errorf("Test %d - unexpected error %s", i, err.Error()) + } + + if test.isSrv && !isSrv { + t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i) + } + + if isSrv && !test.isSrv { + t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i) + } + + if !reflect.DeepEqual(results, test.expect) { + t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect) + } + } +} + +func TestSRVHealthCheck(t *testing.T) { + serverURL, err := url.Parse(workableServer.URL) + if err != nil { + t.Errorf("Failed to parse test server URL: %s", err.Error()) + } + + pp, err := strconv.Atoi(serverURL.Port()) + if err != nil { + t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error()) + } + + port := uint16(pp) + + allGoodResolver := testResolver{ + result: []*net.SRV{ + {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, + }, + } + + partialFailureResolver := testResolver{ + result: []*net.SRV{ + {Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + } + + fullFailureResolver := testResolver{ + result: []*net.SRV{ + {Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1}, + {Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1}, + {Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1}, + }, + } + + resolutionErrorResolver := testResolver{ + errOn: "srv://tag.service.consul", + result: []*net.SRV{}, + } + + upstream := &staticUpstream{ + Hosts: []*UpstreamHost{ + {Name: "srv://tag.service.consul"}, + }, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + tests := []struct { + resolver testResolver + shouldFail bool + shouldErr bool + }{ + {allGoodResolver, false, false}, + {partialFailureResolver, false, false}, + {fullFailureResolver, true, false}, + {resolutionErrorResolver, true, true}, + } + + for i, test := range tests { + upstream.resolver = test.resolver + upstream.healthCheck() + if upstream.Hosts[0].Down() && !test.shouldFail { + t.Errorf("Test %d - expected all healthchecks to pass, all failing", i) + } + + if test.shouldFail && !upstream.Hosts[0].Down() { + t.Errorf("Test %d - expected all healthchecks to fail, all passing", i) + } + + status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load()) + + if test.shouldFail && !test.shouldErr && status != "Failed" { + t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status) + } + + if !test.shouldFail && status != "OK" { + t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status) + } + + if test.shouldErr && status != "an error occurred" { + t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status) + } + } +} diff --git a/caddytls/client.go b/caddytls/client.go index 3a9adae6c..26ef6a3c5 100644 --- a/caddytls/client.go +++ b/caddytls/client.go @@ -39,6 +39,7 @@ type ACMEClient struct { AllowPrompts bool config *Config acmeClient *acme.Client + locker Locker } // newACMEClient creates a new ACMEClient given an email and whether @@ -120,6 +121,10 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) AllowPrompts: allowPrompts, config: config, acmeClient: client, + locker: &syncLock{ + nameLocks: make(map[string]*sync.WaitGroup), + nameLocksMu: sync.Mutex{}, + }, } if config.DNSProvider == "" { @@ -210,7 +215,7 @@ func (c *ACMEClient) Obtain(name string) error { return err } - waiter, err := storage.TryLock(name) + waiter, err := c.locker.TryLock(name) if err != nil { return err } @@ -220,7 +225,7 @@ func (c *ACMEClient) Obtain(name string) error { return nil // we assume the process with the lock succeeded, rather than hammering this execution path again } defer func() { - if err := storage.Unlock(name); err != nil { + if err := c.locker.Unlock(name); err != nil { log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err) } }() @@ -286,7 +291,7 @@ func (c *ACMEClient) Renew(name string) error { return err } - waiter, err := storage.TryLock(name) + waiter, err := c.locker.TryLock(name) if err != nil { return err } @@ -296,7 +301,7 @@ func (c *ACMEClient) Renew(name string) error { return nil // we assume the process with the lock succeeded, rather than hammering this execution path again } defer func() { - if err := storage.Unlock(name); err != nil { + if err := c.locker.Unlock(name); err != nil { log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err) } }() diff --git a/caddytls/filestorage.go b/caddytls/filestorage.go index 7e8d730e9..67084ef45 100644 --- a/caddytls/filestorage.go +++ b/caddytls/filestorage.go @@ -22,7 +22,6 @@ import ( "os" "path/filepath" "strings" - "sync" "github.com/mholt/caddy" ) @@ -40,8 +39,7 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme") // instance is guaranteed to be non-nil if there is no error. func NewFileStorage(caURL *url.URL) (Storage, error) { return &FileStorage{ - Path: filepath.Join(storageBasePath, caURL.Host), - nameLocks: make(map[string]*sync.WaitGroup), + Path: filepath.Join(storageBasePath, caURL.Host), }, nil } @@ -49,9 +47,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) { // directory. It is used to get file paths in a consistent, // cross-platform way or persisting ACME assets on the file system. type FileStorage struct { - Path string - nameLocks map[string]*sync.WaitGroup - nameLocksMu sync.Mutex + Path string } // sites gets the directory that stores site certificate and keys. @@ -254,36 +250,6 @@ func (s *FileStorage) StoreUser(email string, data *UserData) error { return nil } -// TryLock attempts to get a lock for name, otherwise it returns -// a Waiter value to wait until the other process is finished. -func (s *FileStorage) TryLock(name string) (Waiter, error) { - s.nameLocksMu.Lock() - defer s.nameLocksMu.Unlock() - wg, ok := s.nameLocks[name] - if ok { - // lock already obtained, let caller wait on it - return wg, nil - } - // caller gets lock - wg = new(sync.WaitGroup) - wg.Add(1) - s.nameLocks[name] = wg - return nil, nil -} - -// Unlock unlocks name. -func (s *FileStorage) Unlock(name string) error { - s.nameLocksMu.Lock() - defer s.nameLocksMu.Unlock() - wg, ok := s.nameLocks[name] - if !ok { - return fmt.Errorf("FileStorage: no lock to release for %s", name) - } - wg.Done() - delete(s.nameLocks, name) - return nil -} - // MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the // most recently written sub directory in the users' directory. It is named // after the email address. This corresponds to the most recent call to diff --git a/caddytls/storage.go b/caddytls/storage.go index 666d3f064..8587dd026 100644 --- a/caddytls/storage.go +++ b/caddytls/storage.go @@ -39,24 +39,9 @@ type UserData struct { Key []byte } -// Storage is an interface abstracting all storage used by Caddy's TLS -// subsystem. Implementations of this interface store both site and -// user data. -type Storage interface { - // SiteExists returns true if this site exists in storage. - // Site data is considered present when StoreSite has been called - // successfully (without DeleteSite having been called, of course). - SiteExists(domain string) (bool, error) - - // TryLock is called before Caddy attempts to obtain or renew a - // certificate for a certain name and store it. From the perspective - // of this method and its companion Unlock, the actions of - // obtaining/renewing and then storing the certificate are atomic, - // and both should occur within a lock. This prevents multiple - // processes -- maybe distributed ones -- from stepping on each - // other's space in the same shared storage, and from spamming - // certificate providers with multiple, redundant requests. - // +// Locker provides support for mutual exclusion +type Locker interface { + // TryLock will return immediatedly with or without acquiring the lock. // If a lock could be obtained, (nil, nil) is returned and you may // continue normally. If not (meaning another process is already // working on that name), a Waiter value will be returned upon @@ -75,6 +60,16 @@ type Storage interface { // the obtain/renew and store are finished, even if there was // an error (or a timeout). Unlock(name string) error +} + +// Storage is an interface abstracting all storage used by Caddy's TLS +// subsystem. Implementations of this interface store both site and +// user data. +type Storage interface { + // SiteExists returns true if this site exists in storage. + // Site data is considered present when StoreSite has been called + // successfully (without DeleteSite having been called, of course). + SiteExists(domain string) (bool, error) // LoadSite obtains the site data from storage for the given domain and // returns it. If data for the domain does not exist, an error value diff --git a/caddytls/sync_locker.go b/caddytls/sync_locker.go new file mode 100644 index 000000000..693f3b875 --- /dev/null +++ b/caddytls/sync_locker.go @@ -0,0 +1,57 @@ +// Copyright 2015 Light Code Labs, LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package caddytls + +import ( + "fmt" + "sync" +) + +var _ Locker = &syncLock{} + +type syncLock struct { + nameLocks map[string]*sync.WaitGroup + nameLocksMu sync.Mutex +} + +// TryLock attempts to get a lock for name, otherwise it returns +// a Waiter value to wait until the other process is finished. +func (s *syncLock) TryLock(name string) (Waiter, error) { + s.nameLocksMu.Lock() + defer s.nameLocksMu.Unlock() + wg, ok := s.nameLocks[name] + if ok { + // lock already obtained, let caller wait on it + return wg, nil + } + // caller gets lock + wg = new(sync.WaitGroup) + wg.Add(1) + s.nameLocks[name] = wg + return nil, nil +} + +// Unlock unlocks name. +func (s *syncLock) Unlock(name string) error { + s.nameLocksMu.Lock() + defer s.nameLocksMu.Unlock() + wg, ok := s.nameLocks[name] + if !ok { + return fmt.Errorf("FileStorage: no lock to release for %s", name) + } + wg.Done() + delete(s.nameLocks, name) + return nil +} diff --git a/caddytls/tls_test.go b/caddytls/tls_test.go index 7eb5c3b20..2b592cf56 100644 --- a/caddytls/tls_test.go +++ b/caddytls/tls_test.go @@ -16,7 +16,6 @@ package caddytls import ( "os" - "sync" "testing" "github.com/xenolf/lego/acme" @@ -94,7 +93,7 @@ func TestQualifiesForManagedTLS(t *testing.T) { } func TestSaveCertResource(t *testing.T) { - storage := &FileStorage{Path: "./le_test_save", nameLocks: make(map[string]*sync.WaitGroup)} + storage := &FileStorage{Path: "./le_test_save"} defer func() { err := os.RemoveAll(storage.Path) if err != nil { @@ -140,7 +139,7 @@ func TestSaveCertResource(t *testing.T) { } func TestExistingCertAndKey(t *testing.T) { - storage := &FileStorage{Path: "./le_test_existing", nameLocks: make(map[string]*sync.WaitGroup)} + storage := &FileStorage{Path: "./le_test_existing"} defer func() { err := os.RemoveAll(storage.Path) if err != nil { diff --git a/caddytls/user_test.go b/caddytls/user_test.go index 23a45edc6..f82480fbd 100644 --- a/caddytls/user_test.go +++ b/caddytls/user_test.go @@ -21,7 +21,6 @@ import ( "crypto/rand" "io" "strings" - "sync" "testing" "time" @@ -196,7 +195,7 @@ func TestGetEmail(t *testing.T) { } } -var testStorage = &FileStorage{Path: "./testdata", nameLocks: make(map[string]*sync.WaitGroup)} +var testStorage = &FileStorage{Path: "./testdata"} func (s *FileStorage) clean() error { return os.RemoveAll(s.Path)