diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 621582a66..6e9c3c426 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -238,6 +238,116 @@ func TestUnixSocketProxy(t *testing.T) { } } +func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, messageFormat, r.URL.String()) + })) + + return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts +} + +func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, messageFormat, r.URL.String()) + })) + + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err) + } + + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, nil, fmt.Errorf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + + tsURL := strings.Replace(ts.URL, "http://", "unix:", 1) + + return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil +} + +func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) { + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + + // *httptest.Server is passed so it can be `defer`red properly + defer ts.Close() + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL + path) + if err != nil { + return "", fmt.Errorf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + return "", fmt.Errorf("Unable to read body: %v", err) + } + + return fmt.Sprintf("%s", greeting), nil +} + +func TestUnixSocketProxyPaths(t *testing.T) { + greeting := "Hello route %s" + + tests := []struct { + url string + prefix string + expected string + }{ + {"", "", fmt.Sprintf(greeting, "/")}, + {"/hello", "", fmt.Sprintf(greeting, "/hello")}, + {"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")}, + {"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")}, + {"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")}, + {"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")}, + {"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")}, + {"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")}, + {"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")}, + } + + for _, test := range tests { + p, ts := GetHTTPProxy(greeting, test.prefix) + + actualMsg, err := GetTestServerMessage(p, ts, test.url) + + if err != nil { + t.Fatalf("Getting server message failed - %v", err) + } + + if actualMsg != test.expected { + t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) + } + } + + if runtime.GOOS == "windows" { + return + } + + for _, test := range tests { + p, ts, err := GetSocketProxy(greeting, test.prefix) + + if err != nil { + t.Fatalf("Getting socket proxy failed - %v", err) + } + + actualMsg, err := GetTestServerMessage(p, ts, test.url) + + if err != nil { + t.Fatalf("Getting server message failed - %v", err) + } + + if actualMsg != test.expected { + t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg) + } + } +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ @@ -276,12 +386,19 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool { // proxy. func newWebSocketTestProxy(backendAddr string) *Proxy { return &Proxy{ - Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}}, + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}}, + } +} + +func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}}, } } type fakeWsUpstream struct { - name string + name string + without string } func (u *fakeWsUpstream) From() string { @@ -292,7 +409,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost { uri, _ := url.Parse(u.name) return &UpstreamHost{ Name: u.name, - ReverseProxy: NewSingleHostReverseProxy(uri, ""), + ReverseProxy: NewSingleHostReverseProxy(uri, u.without), ExtraHeaders: http.Header{ "Connection": {"{>Connection}"}, "Upgrade": {"{>Upgrade}"}}, diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index cb4ec8750..7e8815d2e 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -95,6 +95,18 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { } else { req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery } + // Trims the path of the socket from the URL path. + // This is done because req.URL passed to your proxied service + // will have the full path of the socket file prefixed to it. + // Calling /test on a server that proxies requests to + // unix:/var/run/www.socket will thus set the requested path + // to /var/run/www.socket/test, rendering paths useless. + if target.Scheme == "unix" { + // See comment on socketDial for the trim + socketPrefix := target.String()[len("unix://"):] + req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix) + } + // We are then safe to remove the `without` prefix. if without != "" { req.URL.Path = strings.TrimPrefix(req.URL.Path, without) }