diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies.go b/modules/caddyhttp/reverseproxy/selectionpolicies.go index 293ff75e2..fcf7f90f6 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies.go @@ -111,8 +111,8 @@ func (r *WeightedRoundRobinSelection) UnmarshalCaddyfile(d *caddyfile.Dispenser) if err != nil { return d.Errf("invalid weight value '%s': %v", weight, err) } - if weightInt < 1 { - return d.Errf("invalid weight value '%s': weight should be non-zero and positive", weight) + if weightInt < 0 { + return d.Errf("invalid weight value '%s': weight should be non-negative", weight) } r.Weights = append(r.Weights, weightInt) } @@ -136,8 +136,15 @@ func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, return pool[0] } var index, totalWeight int + var weights []int + + for _, w := range r.Weights { + if w > 0 { + weights = append(weights, w) + } + } currentWeight := int(atomic.AddUint32(&r.index, 1)) % r.totalWeight - for i, weight := range r.Weights { + for i, weight := range weights { totalWeight += weight if currentWeight < totalWeight { index = i @@ -145,9 +152,9 @@ func (r *WeightedRoundRobinSelection) Select(pool UpstreamPool, _ *http.Request, } } - upstreams := make([]*Upstream, 0, len(r.Weights)) - for _, upstream := range pool { - if !upstream.Available() { + upstreams := make([]*Upstream, 0, len(weights)) + for i, upstream := range pool { + if !upstream.Available() || r.Weights[i] == 0 { continue } upstreams = append(upstreams, upstream) diff --git a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go index a4701ce86..580abbdde 100644 --- a/modules/caddyhttp/reverseproxy/selectionpolicies_test.go +++ b/modules/caddyhttp/reverseproxy/selectionpolicies_test.go @@ -131,6 +131,58 @@ func TestWeightedRoundRobinPolicy(t *testing.T) { } } +func TestWeightedRoundRobinPolicyWithZeroWeight(t *testing.T) { + pool := testPool() + wrrPolicy := WeightedRoundRobinSelection{ + Weights: []int{0, 2, 1}, + totalWeight: 3, + } + req, _ := http.NewRequest("GET", "/", nil) + + h := wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected first weighted round robin host to be second host in the pool.") + } + + h = wrrPolicy.Select(pool, req, nil) + if h != pool[2] { + t.Error("Expected second weighted round robin host to be third host in the pool.") + } + + h = wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expected third weighted round robin host to be second host in the pool.") + } + + // mark second host as down + pool[1].setHealthy(false) + h = wrrPolicy.Select(pool, req, nil) + if h != pool[2] { + t.Error("Expect select next available host.") + } + + h = wrrPolicy.Select(pool, req, nil) + if h != pool[2] { + t.Error("Expect select only available host.") + } + // mark second host as up + pool[1].setHealthy(true) + + h = wrrPolicy.Select(pool, req, nil) + if h != pool[1] { + t.Error("Expect select first host on availability.") + } + + // test next select in full cycle + expected := []*Upstream{pool[1], pool[2], pool[1], pool[1], pool[2], pool[1]} + for i, want := range expected { + got := wrrPolicy.Select(pool, req, nil) + if want != got { + t.Errorf("Selection %d: got host[%s], want host[%s]", i+1, got, want) + } + } +} + func TestLeastConnPolicy(t *testing.T) { pool := testPool() lcPolicy := LeastConnSelection{}