reverseproxy: Always return new upstreams (fix #5736) (#5752)

* reverseproxy: Always return new upstreams (fix #5736)

* Fix healthcheck logger race
This commit is contained in:
Matt Holt 2023-08-17 11:33:40 -06:00 committed by GitHub
parent d6f86cccf5
commit 936ee918ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 19 deletions

View File

@ -356,6 +356,7 @@ func (h *Handler) Provision(ctx caddy.Context) error {
if h.HealthChecks != nil { if h.HealthChecks != nil {
// set defaults on passive health checks, if necessary // set defaults on passive health checks, if necessary
if h.HealthChecks.Passive != nil { if h.HealthChecks.Passive != nil {
h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive")
if h.HealthChecks.Passive.FailDuration > 0 && h.HealthChecks.Passive.MaxFails == 0 { if h.HealthChecks.Passive.FailDuration > 0 && h.HealthChecks.Passive.MaxFails == 0 {
h.HealthChecks.Passive.MaxFails = 1 h.HealthChecks.Passive.MaxFails = 1
} }
@ -1077,12 +1078,11 @@ func (h Handler) provisionUpstream(upstream *Upstream) {
// without MaxRequests), copy the value into this upstream, since the // without MaxRequests), copy the value into this upstream, since the
// value in the upstream (MaxRequests) is what is used during // value in the upstream (MaxRequests) is what is used during
// availability checks // availability checks
if h.HealthChecks != nil && h.HealthChecks.Passive != nil { if h.HealthChecks != nil &&
h.HealthChecks.Passive.logger = h.logger.Named("health_checker.passive") h.HealthChecks.Passive != nil &&
if h.HealthChecks.Passive.UnhealthyRequestCount > 0 && h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
upstream.MaxRequests == 0 { upstream.MaxRequests == 0 {
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
}
} }
// upstreams need independent access to the passive // upstreams need independent access to the passive

View File

@ -114,7 +114,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := srvs[suAddr] cached := srvs[suAddr]
srvsMu.RUnlock() srvsMu.RUnlock()
if cached.isFresh() { if cached.isFresh() {
return cached.upstreams, nil return allNew(cached.upstreams), nil
} }
// otherwise, obtain a write-lock to update the cached value // otherwise, obtain a write-lock to update the cached value
@ -126,7 +126,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock // have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suAddr] cached = srvs[suAddr]
if cached.isFresh() { if cached.isFresh() {
return cached.upstreams, nil return allNew(cached.upstreams), nil
} }
su.logger.Debug("refreshing SRV upstreams", su.logger.Debug("refreshing SRV upstreams",
@ -145,7 +145,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
su.logger.Warn("SRV records filtered", zap.Error(err)) su.logger.Warn("SRV records filtered", zap.Error(err))
} }
upstreams := make([]*Upstream, len(records)) upstreams := make([]Upstream, len(records))
for i, rec := range records { for i, rec := range records {
su.logger.Debug("discovered SRV record", su.logger.Debug("discovered SRV record",
zap.String("target", rec.Target), zap.String("target", rec.Target),
@ -153,7 +153,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
zap.Uint16("priority", rec.Priority), zap.Uint16("priority", rec.Priority),
zap.Uint16("weight", rec.Weight)) zap.Uint16("weight", rec.Weight))
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
upstreams[i] = &Upstream{Dial: addr} upstreams[i] = Upstream{Dial: addr}
} }
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
@ -170,7 +170,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams, upstreams: upstreams,
} }
return upstreams, nil return allNew(upstreams), nil
} }
func (su SRVUpstreams) String() string { func (su SRVUpstreams) String() string {
@ -206,7 +206,7 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string {
type srvLookup struct { type srvLookup struct {
srvUpstreams SRVUpstreams srvUpstreams SRVUpstreams
freshness time.Time freshness time.Time
upstreams []*Upstream upstreams []Upstream
} }
func (sl srvLookup) isFresh() bool { func (sl srvLookup) isFresh() bool {
@ -325,7 +325,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
cached := aAaaa[auStr] cached := aAaaa[auStr]
aAaaaMu.RUnlock() aAaaaMu.RUnlock()
if cached.isFresh() { if cached.isFresh() {
return cached.upstreams, nil return allNew(cached.upstreams), nil
} }
// otherwise, obtain a write-lock to update the cached value // otherwise, obtain a write-lock to update the cached value
@ -337,7 +337,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// have refreshed it in the meantime before we re-obtained our lock // have refreshed it in the meantime before we re-obtained our lock
cached = aAaaa[auStr] cached = aAaaa[auStr]
if cached.isFresh() { if cached.isFresh() {
return cached.upstreams, nil return allNew(cached.upstreams), nil
} }
name := repl.ReplaceAll(au.Name, "") name := repl.ReplaceAll(au.Name, "")
@ -348,15 +348,15 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return nil, err return nil, err
} }
upstreams := make([]*Upstream, len(ips)) upstreams := make([]Upstream, len(ips))
for i, ip := range ips { for i, ip := range ips {
upstreams[i] = &Upstream{ upstreams[i] = Upstream{
Dial: net.JoinHostPort(ip.String(), port), Dial: net.JoinHostPort(ip.String(), port),
} }
} }
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
if cached.freshness.IsZero() && len(srvs) >= 100 { if cached.freshness.IsZero() && len(aAaaa) >= 100 {
for randomKey := range aAaaa { for randomKey := range aAaaa {
delete(aAaaa, randomKey) delete(aAaaa, randomKey)
break break
@ -369,7 +369,7 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
upstreams: upstreams, upstreams: upstreams,
} }
return upstreams, nil return allNew(upstreams), nil
} }
func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) } func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) }
@ -377,7 +377,7 @@ func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port)
type aLookup struct { type aLookup struct {
aUpstreams AUpstreams aUpstreams AUpstreams
freshness time.Time freshness time.Time
upstreams []*Upstream upstreams []Upstream
} }
func (al aLookup) isFresh() bool { func (al aLookup) isFresh() bool {
@ -483,6 +483,14 @@ func (u *UpstreamResolver) ParseAddresses() error {
return nil return nil
} }
func allNew(upstreams []Upstream) []*Upstream {
results := make([]*Upstream, len(upstreams))
for i := range upstreams {
results[i] = &Upstream{Dial: upstreams[i].Dial}
}
return results
}
var ( var (
srvs = make(map[string]srvLookup) srvs = make(map[string]srvLookup)
srvsMu sync.RWMutex srvsMu sync.RWMutex