package proxy

import (
	"bufio"
	"bytes"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"golang.org/x/net/websocket"
)

// TestWebSocketReverseProxyServeHTTPHandler drives the proxy's ServeHTTP
// directly with a WebSocket upgrade request and checks that the handshake
// response written to the hijacked connection is the expected 101 reply.
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
	// No-op websocket backend simply allows the WS connection to be
	// accepted then it will be immediately closed. Perfect for testing.
	wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
	defer wsNop.Close()

	// Get proxy to use for the test
	p := newWebSocketTestProxy(wsNop.URL)

	// Create client request
	r, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("Failed to create request: %v", err)
	}
	r.Header = http.Header{
		"Connection":            {"Upgrade"},
		"Upgrade":               {"websocket"},
		"Origin":                {wsNop.URL},
		"Sec-WebSocket-Key":     {"x3JJHMbDL1EzLkh9GBhXDw=="},
		"Sec-WebSocket-Version": {"13"},
	}

	// Capture the request; whatever is written to the hijacked
	// connection ends up in w.fakeConn.writeBuf.
	w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}

	// Booya! Do the test.
	p.ServeHTTP(w, r)

	// Make sure the backend accepted the WS connection.
	// Mostly interested in the Upgrade and Connection response headers
	// and the 101 status code.
	expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
	actual := w.fakeConn.writeBuf.Bytes()
	if !bytes.Equal(actual, expected) {
		t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
	}
}

// TestWebSocketReverseProxyFromWSClient is an end-to-end test: a real
// WebSocket client talks to an echo backend through the proxy and must
// receive back exactly the message it sent.
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
	// Echo server allows us to test that socket bytes are properly
	// being proxied.
	wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
		io.Copy(ws, ws)
	}))
	defer wsEcho.Close()

	// Get proxy to use for the test
	p := newWebSocketTestProxy(wsEcho.URL)

	// This is a full end-end test, so the proxy handler
	// has to be part of a server listening on a port. Our
	// WS client will connect to this test server, not
	// the echo server directly.
	echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		p.ServeHTTP(w, r)
	}))
	defer echoProxy.Close()

	// Set up WebSocket client. The variable is deliberately not named
	// "url" so that it does not shadow the imported net/url package.
	wsURL := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
	ws, err := websocket.Dial(wsURL, "", echoProxy.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer ws.Close()

	// Send test message; fail right away on a transport error instead
	// of letting it surface later as a confusing content mismatch.
	trialMsg := "Is it working?"
	if err := websocket.Message.Send(ws, trialMsg); err != nil {
		t.Fatalf("Failed to send message: %v", err)
	}

	// It should be echoed back to us
	var actualMsg string
	if err := websocket.Message.Receive(ws, &actualMsg); err != nil {
		t.Fatalf("Failed to receive message: %v", err)
	}
	if actualMsg != trialMsg {
		t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
	}
}

// newWebSocketTestProxy returns a test proxy that will
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy {
	return &Proxy{
		Upstreams: []Upstream{&fakeUpstream{name: backendAddr}},
	}
}

// fakeUpstream is a minimal upstream for tests that always resolves to
// the single backend address stored in name.
type fakeUpstream struct {
	name string
}

// From reports the base path this upstream applies to; always "/".
func (u *fakeUpstream) From() string {
	return "/"
}

// Select builds a fresh UpstreamHost pointing at u.name. The extra
// headers use the "{>Connection}" and "{>Upgrade}" placeholders.
func (u *fakeUpstream) Select() *UpstreamHost {
	// The parse error is deliberately ignored: in these tests name is
	// always the URL of an httptest server.
	uri, _ := url.Parse(u.name)
	return &UpstreamHost{
		Name:         u.name,
		ReverseProxy: NewSingleHostReverseProxy(uri, ""),
		ExtraHeaders: http.Header{
			"Connection": {"{>Connection}"},
			"Upgrade":    {"{>Upgrade}"}},
	}
}

// IsAllowedPath accepts every request path.
func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
	return true
}

// recorderHijacker is a ResponseRecorder that can
// be hijacked.
type recorderHijacker struct {
	// ResponseRecorder provides the ordinary response-writing methods.
	*httptest.ResponseRecorder
	// fakeConn is the connection handed out by Hijack.
	fakeConn *fakeConn
}

// Hijack returns the fake connection together with a buffered ReadWriter
// layered over it; it never fails. The ReadWriter is non-nil so that
// handlers which use the buffered side of a hijacked connection do not
// hit a nil pointer. Data written through it reaches fakeConn only after
// Flush; writes made directly on the returned net.Conn are unaffected.
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	conn := rh.fakeConn
	rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
	return conn, rw, nil
}

// fakeConn is an in-memory net.Conn stand-in: reads are served from
// readBuf and writes are accumulated in writeBuf. Addresses, deadlines
// and Close are no-ops.
type fakeConn struct {
	readBuf  bytes.Buffer
	writeBuf bytes.Buffer
}

// LocalAddr always returns nil; the fake has no address.
func (c *fakeConn) LocalAddr() net.Addr { return nil }

// RemoteAddr always returns nil; the fake has no peer.
func (c *fakeConn) RemoteAddr() net.Addr { return nil }

// SetDeadline ignores the deadline and reports success.
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }

// SetReadDeadline ignores the deadline and reports success.
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline ignores the deadline and reports success.
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }

// Close does nothing and reports success; both buffers stay readable.
func (c *fakeConn) Close() error { return nil }

// Read drains bytes from readBuf into b.
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }

// Write appends b to writeBuf.
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }