diff --git a/caddyhttp/proxy/policy.go b/caddyhttp/proxy/policy.go index 543b5339b..28c81bd84 100644 --- a/caddyhttp/proxy/policy.go +++ b/caddyhttp/proxy/policy.go @@ -23,6 +23,7 @@ func init() { RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) RegisterPolicy("ip_hash", func() Policy { return &IPHash{} }) RegisterPolicy("first", func() Policy { return &First{} }) + RegisterPolicy("uri_hash", func() Policy { return &URIHash{} }) } // Random is a policy that selects up hosts from a pool at random. @@ -106,23 +107,10 @@ func (r *RoundRobin) Select(pool HostPool, request *http.Request) *UpstreamHost return nil } -// IPHash is a policy that selects hosts based on hashing the request IP -type IPHash struct{} - -func hash(s string) uint32 { - h := fnv.New32a() - h.Write([]byte(s)) - return h.Sum32() -} - -// Select selects an up host from the pool based on hashing the request IP -func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost { +// hostByHashing returns an available host from pool based on a hashable string +func hostByHashing(pool HostPool, s string) *UpstreamHost { poolLen := uint32(len(pool)) - clientIP, _, err := net.SplitHostPort(request.RemoteAddr) - if err != nil { - clientIP = request.RemoteAddr - } - index := hash(clientIP) % poolLen + index := hash(s) % poolLen for i := uint32(0); i < poolLen; i++ { index += i host := pool[index%poolLen] @@ -133,6 +121,33 @@ func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost { return nil } +// hash calculates a hash based on string s +func hash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// IPHash is a policy that selects hosts based on hashing the request IP +type IPHash struct{} + +// Select selects an up host from the pool based on hashing the request IP +func (r *IPHash) Select(pool HostPool, request *http.Request) *UpstreamHost { + clientIP, _, err := net.SplitHostPort(request.RemoteAddr) + if err != nil { + clientIP = request.RemoteAddr + } + return hostByHashing(pool, clientIP) +} + +// URIHash is a policy that selects the host based on hashing the request URI +type URIHash struct{} + +// Select selects the host based on hashing the URI +func (r *URIHash) Select(pool HostPool, request *http.Request) *UpstreamHost { + return hostByHashing(pool, request.RequestURI) +} + // First is a policy that selects the first available host type First struct{} diff --git a/caddyhttp/proxy/policy_test.go b/caddyhttp/proxy/policy_test.go index 1db7c29f3..5cc7e85c6 100644 --- a/caddyhttp/proxy/policy_test.go +++ b/caddyhttp/proxy/policy_test.go @@ -243,3 +243,62 @@ func TestFirstPolicy(t *testing.T) { t.Error("Expected first policy host to be the second host.") } } + +func TestUriPolicy(t *testing.T) { + pool := testPool() + uriPolicy := &URIHash{} + + request := httptest.NewRequest(http.MethodGet, "/test", nil) + h := uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].Unhealthy = 1 + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + // We should be able to resize the host pool and still be able to predict + // where a request will be routed with the same URI's used above + pool = []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://localhost:99998", // this shouldn't + }, + } + + request = httptest.NewRequest(http.MethodGet, "/test", nil) + h = uriPolicy.Select(pool, request) + if h != pool[0] { + t.Error("Expected uri policy host to be the first host.") + } + + pool[0].Unhealthy = 1 + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the first host.") + } + + request = httptest.NewRequest(http.MethodGet, "/test_2", nil) + h = uriPolicy.Select(pool, request) + if h != pool[1] { + t.Error("Expected uri policy host to be the second host.") + } + + pool[0].Unhealthy = 1 + pool[1].Unhealthy = 1 + h = uriPolicy.Select(pool, request) + if h != nil { + t.Error("Expected uri policy policy host to be nil.") + } +}