mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-22 12:13:00 +08:00
Merge pull request #1174 from tw4452852/1173
header: implement http.Hijacker for responseWriterWrapper
This commit is contained in:
commit
036633b64a
|
@ -5,7 +5,6 @@ package gzip
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
@ -144,7 +143,7 @@ func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
if hj, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||||
return hj.Hijack()
|
return hj.Hijack()
|
||||||
}
|
}
|
||||||
return nil, nil, fmt.Errorf("not a Hijacker")
|
return nil, nil, httpserver.NonHijackerError{Underlying: w.ResponseWriter}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush implements http.Flusher. It simply wraps the underlying
|
// Flush implements http.Flusher. It simply wraps the underlying
|
||||||
|
@ -153,7 +152,7 @@ func (w *gzipResponseWriter) Flush() {
|
||||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
f.Flush()
|
f.Flush()
|
||||||
} else {
|
} else {
|
||||||
panic("not a Flusher") // should be recovered at the beginning of middleware stack
|
panic(httpserver.NonFlusherError{Underlying: w.ResponseWriter}) // should be recovered at the beginning of middleware stack
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,5 +162,5 @@ func (w *gzipResponseWriter) CloseNotify() <-chan bool {
|
||||||
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
|
if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok {
|
||||||
return cn.CloseNotify()
|
return cn.CloseNotify()
|
||||||
}
|
}
|
||||||
panic("not a CloseNotifier")
|
panic(httpserver.NonCloseNotifierError{Underlying: w.ResponseWriter})
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@
|
||||||
package header
|
package header
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -113,3 +115,12 @@ func (rww *responseWriterWrapper) setHeader(key, value string) {
|
||||||
h.Set(key, value)
|
h.Set(key, value)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker. It simply wraps the underlying
|
||||||
|
// ResponseWriter's Hijack method if there is one, or returns an error.
|
||||||
|
func (rww *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hj, ok := rww.w.(http.Hijacker); ok {
|
||||||
|
return hj.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, httpserver.NonHijackerError{Underlying: rww.w}
|
||||||
|
}
|
||||||
|
|
44
caddyhttp/httpserver/error.go
Normal file
44
caddyhttp/httpserver/error.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package httpserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ error = NonHijackerError{}
|
||||||
|
_ error = NonFlusherError{}
|
||||||
|
_ error = NonCloseNotifierError{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NonHijackerError is more descriptive error caused by a non hijacker
|
||||||
|
type NonHijackerError struct {
|
||||||
|
// underlying type which doesn't implement Hijack
|
||||||
|
Underlying interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Error
|
||||||
|
func (h NonHijackerError) Error() string {
|
||||||
|
return fmt.Sprintf("%T is not a hijacker", h.Underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NonFlusherError is more descriptive error caused by a non flusher
|
||||||
|
type NonFlusherError struct {
|
||||||
|
// underlying type which doesn't implement Flush
|
||||||
|
Underlying interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Error
|
||||||
|
func (f NonFlusherError) Error() string {
|
||||||
|
return fmt.Sprintf("%T is not a flusher", f.Underlying)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NonCloseNotifierError is more descriptive error caused by a non closeNotifier
|
||||||
|
type NonCloseNotifierError struct {
|
||||||
|
// underlying type which doesn't implement CloseNotify
|
||||||
|
Underlying interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Error
|
||||||
|
func (c NonCloseNotifierError) Error() string {
|
||||||
|
return fmt.Sprintf("%T is not a closeNotifier", c.Underlying)
|
||||||
|
}
|
|
@ -2,7 +2,6 @@ package httpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -75,7 +74,7 @@ func (r *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
if hj, ok := r.ResponseWriter.(http.Hijacker); ok {
|
if hj, ok := r.ResponseWriter.(http.Hijacker); ok {
|
||||||
return hj.Hijack()
|
return hj.Hijack()
|
||||||
}
|
}
|
||||||
return nil, nil, errors.New("not a Hijacker")
|
return nil, nil, NonHijackerError{Underlying: r.ResponseWriter}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush implements http.Flusher. It simply wraps the underlying
|
// Flush implements http.Flusher. It simply wraps the underlying
|
||||||
|
@ -84,7 +83,7 @@ func (r *ResponseRecorder) Flush() {
|
||||||
if f, ok := r.ResponseWriter.(http.Flusher); ok {
|
if f, ok := r.ResponseWriter.(http.Flusher); ok {
|
||||||
f.Flush()
|
f.Flush()
|
||||||
} else {
|
} else {
|
||||||
panic("not a Flusher") // should be recovered at the beginning of middleware stack
|
panic(NonFlusherError{Underlying: r.ResponseWriter}) // should be recovered at the beginning of middleware stack
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,5 +93,5 @@ func (r *ResponseRecorder) CloseNotify() <-chan bool {
|
||||||
if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
|
if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
|
||||||
return cn.CloseNotify()
|
return cn.CloseNotify()
|
||||||
}
|
}
|
||||||
panic("not a CloseNotifier")
|
panic(NonCloseNotifierError{Underlying: r.ResponseWriter})
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,6 +97,39 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||||
|
// Capture the expected panic
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if _, ok := r.(httpserver.NonHijackerError); !ok {
|
||||||
|
t.Error("not get the expected panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var connCount int32
|
||||||
|
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
|
||||||
|
defer wsNop.Close()
|
||||||
|
|
||||||
|
// Get proxy to use for the test
|
||||||
|
p := newWebSocketTestProxy(wsNop.URL)
|
||||||
|
|
||||||
|
// Create client request
|
||||||
|
r, err := http.NewRequest("GET", "/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
r.Header = http.Header{
|
||||||
|
"Connection": {"Upgrade"},
|
||||||
|
"Upgrade": {"websocket"},
|
||||||
|
"Origin": {wsNop.URL},
|
||||||
|
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
|
||||||
|
"Sec-WebSocket-Version": {"13"},
|
||||||
|
}
|
||||||
|
|
||||||
|
nonHijacker := httptest.NewRecorder()
|
||||||
|
p.ServeHTTP(nonHijacker, r)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||||
// No-op websocket backend simply allows the WS connection to be
|
// No-op websocket backend simply allows the WS connection to be
|
||||||
// accepted then it will be immediately closed. Perfect for testing.
|
// accepted then it will be immediately closed. Perfect for testing.
|
||||||
|
|
|
@ -21,6 +21,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool = sync.Pool{New: createBuffer}
|
var bufferPool = sync.Pool{New: createBuffer}
|
||||||
|
@ -195,7 +197,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
hj, ok := rw.(http.Hijacker)
|
hj, ok := rw.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
panic(httpserver.NonHijackerError{Underlying: rw})
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, _, err := hj.Hijack()
|
conn, _, err := hj.Hijack()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user