diff --git a/config/setup/rewrite_test.go b/config/setup/rewrite_test.go index 9ff294ef0..5747dee30 100644 --- a/config/setup/rewrite_test.go +++ b/config/setup/rewrite_test.go @@ -4,8 +4,9 @@ import ( "testing" "fmt" - "github.com/mholt/caddy/middleware/rewrite" "regexp" + + "github.com/mholt/caddy/middleware/rewrite" ) func TestRewrite(t *testing.T) { diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 027f2266c..15350a993 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -147,14 +147,43 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr } defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) + if res.StatusCode == http.StatusSwitchingProtocols && res.Header.Get("Upgrade") == "websocket" { + hj, ok := rw.(http.Hijacker) + if !ok { + return nil + } + + conn, _, err := hj.Hijack() + if err != nil { + return err + } + + backendConn, err := net.Dial("tcp", outreq.Host) + if err != nil { + conn.Close() + return err + } + + outreq.Write(backendConn) + + go func() { + io.Copy(backendConn, conn) // write tcp stream to backend. + backendConn.Close() + }() + + io.Copy(conn, backendConn) // read tcp stream from backend. + conn.Close() + } else { + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + p.copyResponse(rw, res.Body) } - copyHeader(rw.Header(), res.Header) - - rw.WriteHeader(res.StatusCode) - p.copyResponse(rw, res.Body) return nil } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index a657a088e..011a58b86 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -12,7 +12,10 @@ import ( "github.com/mholt/caddy/config/parse" ) -var supportedPolicies map[string]func() Policy = make(map[string]func() Policy) +var ( + supportedPolicies map[string]func() Policy = make(map[string]func() Policy) + proxyHeaders http.Header = make(http.Header) +) type staticUpstream struct { from string @@ -40,7 +43,7 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { FailTimeout: 10 * time.Second, MaxFails: 1, } - var proxyHeaders http.Header + if !c.Args(&upstream.from) { return upstreams, c.ArgErr() } @@ -97,10 +100,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&header, &value) { return upstreams, c.ArgErr() } - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } proxyHeaders.Add(header, value) + case "websocket": + proxyHeaders.Add("Connection", "{>Connection}") + proxyHeaders.Add("Upgrade", "{>Upgrade}") } }