Merge pull request #1174 from tw4452852/1173

header: implement http.Hijacker for responseWriterWrapper
This commit is contained in:
Matt Holt 2016-10-13 23:09:21 -06:00 committed by GitHub
commit 036633b64a
6 changed files with 97 additions and 9 deletions

View File

@ -5,7 +5,6 @@ package gzip
import ( import (
"bufio" "bufio"
"compress/gzip" "compress/gzip"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -144,7 +143,7 @@ func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := w.ResponseWriter.(http.Hijacker); ok { if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack() return hj.Hijack()
} }
return nil, nil, fmt.Errorf("not a Hijacker") return nil, nil, httpserver.NonHijackerError{Underlying: w.ResponseWriter}
} }
// Flush implements http.Flusher. It simply wraps the underlying // Flush implements http.Flusher. It simply wraps the underlying
@ -153,7 +152,7 @@ func (w *gzipResponseWriter) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok { if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush() f.Flush()
} else { } else {
panic("not a Flusher") // should be recovered at the beginning of middleware stack panic(httpserver.NonFlusherError{Underlying: w.ResponseWriter}) // should be recovered at the beginning of middleware stack
} }
} }
@ -163,5 +162,5 @@ func (w *gzipResponseWriter) CloseNotify() <-chan bool {
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify() return cn.CloseNotify()
} }
panic("not a CloseNotifier") panic(httpserver.NonCloseNotifierError{Underlying: w.ResponseWriter})
} }

View File

@ -4,6 +4,8 @@
package header package header
import ( import (
"bufio"
"net"
"net/http" "net/http"
"strings" "strings"
@ -113,3 +115,12 @@ func (rww *responseWriterWrapper) setHeader(key, value string) {
h.Set(key, value) h.Set(key, value)
}) })
} }
// Hijack implements http.Hijacker. It simply wraps the underlying
// ResponseWriter's Hijack method if there is one, or returns an error.
func (rww *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rww.w.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, httpserver.NonHijackerError{Underlying: rww.w}
}

View File

@ -0,0 +1,44 @@
package httpserver
import (
"fmt"
)
var (
_ error = NonHijackerError{}
_ error = NonFlusherError{}
_ error = NonCloseNotifierError{}
)
// NonHijackerError is more descriptive error caused by a non hijacker
type NonHijackerError struct {
// underlying type which doesn't implement Hijack
Underlying interface{}
}
// Implement Error
func (h NonHijackerError) Error() string {
return fmt.Sprintf("%T is not a hijacker", h.Underlying)
}
// NonFlusherError is more descriptive error caused by a non flusher
type NonFlusherError struct {
// underlying type which doesn't implement Flush
Underlying interface{}
}
// Implement Error
func (f NonFlusherError) Error() string {
return fmt.Sprintf("%T is not a flusher", f.Underlying)
}
// NonCloseNotifierError is more descriptive error caused by a non closeNotifier
type NonCloseNotifierError struct {
// underlying type which doesn't implement CloseNotify
Underlying interface{}
}
// Implement Error
func (c NonCloseNotifierError) Error() string {
return fmt.Sprintf("%T is not a closeNotifier", c.Underlying)
}

View File

@ -2,7 +2,6 @@ package httpserver
import ( import (
"bufio" "bufio"
"errors"
"net" "net"
"net/http" "net/http"
"time" "time"
@ -75,7 +74,7 @@ func (r *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := r.ResponseWriter.(http.Hijacker); ok { if hj, ok := r.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack() return hj.Hijack()
} }
return nil, nil, errors.New("not a Hijacker") return nil, nil, NonHijackerError{Underlying: r.ResponseWriter}
} }
// Flush implements http.Flusher. It simply wraps the underlying // Flush implements http.Flusher. It simply wraps the underlying
@ -84,7 +83,7 @@ func (r *ResponseRecorder) Flush() {
if f, ok := r.ResponseWriter.(http.Flusher); ok { if f, ok := r.ResponseWriter.(http.Flusher); ok {
f.Flush() f.Flush()
} else { } else {
panic("not a Flusher") // should be recovered at the beginning of middleware stack panic(NonFlusherError{Underlying: r.ResponseWriter}) // should be recovered at the beginning of middleware stack
} }
} }
@ -94,5 +93,5 @@ func (r *ResponseRecorder) CloseNotify() <-chan bool {
if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
return cn.CloseNotify() return cn.CloseNotify()
} }
panic("not a CloseNotifier") panic(NonCloseNotifierError{Underlying: r.ResponseWriter})
} }

View File

@ -97,6 +97,39 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
} }
} }
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic
defer func() {
r := recover()
if _, ok := r.(httpserver.NonHijackerError); !ok {
t.Error("not get the expected panic")
}
}()
var connCount int32
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL)
// Create client request
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
r.Header = http.Header{
"Connection": {"Upgrade"},
"Upgrade": {"websocket"},
"Origin": {wsNop.URL},
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
"Sec-WebSocket-Version": {"13"},
}
nonHijacker := httptest.NewRecorder()
p.ServeHTTP(nonHijacker, r)
}
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
// No-op websocket backend simply allows the WS connection to be // No-op websocket backend simply allows the WS connection to be
// accepted then it will be immediately closed. Perfect for testing. // accepted then it will be immediately closed. Perfect for testing.

View File

@ -21,6 +21,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/mholt/caddy/caddyhttp/httpserver"
) )
var bufferPool = sync.Pool{New: createBuffer} var bufferPool = sync.Pool{New: createBuffer}
@ -195,7 +197,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
res.Body.Close() res.Body.Close()
hj, ok := rw.(http.Hijacker) hj, ok := rw.(http.Hijacker)
if !ok { if !ok {
return nil panic(httpserver.NonHijackerError{Underlying: rw})
} }
conn, _, err := hj.Hijack() conn, _, err := hj.Hijack()