diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 3efcf6030..7be8af2ad 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -63,7 +63,6 @@ var tryDuration = 60 * time.Second // ServeHTTP satisfies the middleware.Handler interface. func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { - for _, upstream := range p.Upstreams { if middleware.Path(r.URL.Path).Matches(upstream.From()) && upstream.IsAllowedPath(r.URL.Path) { var replacer middleware.Replacer diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 18e2034b6..68b135679 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -3,6 +3,7 @@ package proxy import ( "bufio" "bytes" + "fmt" "io" "io/ioutil" "log" @@ -13,7 +14,9 @@ import ( "os" "strings" "testing" + "runtime" "time" + "path/filepath" "golang.org/x/net/websocket" ) @@ -160,6 +163,69 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) { } } +func TestUnixSocketProxy(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + trialMsg := "Is it working?" + + var proxySuccess bool + + // This is our fake "application" we want to proxy to + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request was proxied when this is called + proxySuccess = true + + fmt.Fprint(w, trialMsg) + })) + + // Get absolute path for unix: socket + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + t.Fatalf("Unable to get absolute path: %v", err) + } + + // Change httptest.Server listener to listen to unix: socket + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + defer ts.Close() + + url := strings.Replace(ts.URL, "http://", "unix:", 1) + p := newWebSocketTestProxy(url) + + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL) + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + actualMsg := fmt.Sprintf("%s", greeting) + + if !proxySuccess { + t.Errorf("Expected request to be proxied, but it wasn't") + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 32ca2378b..cb4ec8750 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -59,6 +59,18 @@ func singleJoiningSlash(a, b string) string { return a + b } +// Though the relevant directive prefix is just "unix:", url.Parse +// will - assuming the regular URL scheme - add additional slashes +// as if "unix" was a request protocol. +// What we need is just the path, so if "unix:/var/run/www.socket" +// was the proxy directive, the parsed hostName would be +// "unix:///var/run/www.socket", hence the ambiguous trimming. +func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) { + return func(network, addr string) (conn net.Conn, err error) { + return net.Dial("unix", hostName[len("unix://"):]) + } +} + // NewSingleHostReverseProxy returns a new ReverseProxy that rewrites // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", @@ -68,8 +80,15 @@ func singleJoiningSlash(a, b string) string { func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { targetQuery := target.RawQuery director := func(req *http.Request) { - req.URL.Scheme = target.Scheme - req.URL.Host = target.Host + if target.Scheme == "unix" { + // to make Dial work with unix URL, + // scheme and host have to be faked + req.URL.Scheme = "http" + req.URL.Host = "socket" + } else { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + } req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) if targetQuery == "" || req.URL.RawQuery == "" { req.URL.RawQuery = targetQuery + req.URL.RawQuery @@ -80,7 +99,13 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy { req.URL.Path = strings.TrimPrefix(req.URL.Path, without) } } - return &ReverseProxy{Director: director} + rp := &ReverseProxy{Director: director} + if target.Scheme == "unix" { + rp.Transport = &http.Transport{ + Dial: socketDial(target.String()), + } + } + return rp } func copyHeader(dst, src http.Header) { diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 9d87c07ab..faa11cd92 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -65,7 +65,8 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { upstream.Hosts = make([]*UpstreamHost, len(to)) for i, host := range to { - if !strings.HasPrefix(host, "http") { + if !strings.HasPrefix(host, "http") && + !strings.HasPrefix(host, "unix:") { host = "http://" + host } uh := &UpstreamHost{