diff --git a/caddytest/integration/listener_test.go b/caddytest/integration/listener_test.go new file mode 100644 index 000000000..30642b1ae --- /dev/null +++ b/caddytest/integration/listener_test.go @@ -0,0 +1,94 @@ +package integration + +import ( + "bytes" + "fmt" + "math/rand" + "net" + "net/http" + "strings" + "testing" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +func setupListenerWrapperTest(t *testing.T, handlerFunc http.HandlerFunc) *caddytest.Tester { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen: %s", err) + } + + mux := http.NewServeMux() + mux.Handle("/", handlerFunc) + srv := &http.Server{ + Handler: mux, + } + go srv.Serve(l) + t.Cleanup(func() { + _ = srv.Close() + _ = l.Close() + }) + tester := caddytest.NewTester(t) + tester.InitServer(fmt.Sprintf(` + { + skip_install_trust + admin localhost:2999 + http_port 9080 + https_port 9443 + local_certs + servers :9443 { + listener_wrappers { + http_redirect + tls + } + } + } + localhost { + reverse_proxy %s + } + `, l.Addr().String()), "caddyfile") + return tester +} + +func TestHTTPRedirectWrapperWithLargeUpload(t *testing.T) { + const uploadSize = (1024 * 1024) + 1 // 1 MB + 1 byte + // 1 more than an MB + body := make([]byte, uploadSize) + rand.New(rand.NewSource(0)).Read(body) + + tester := setupListenerWrapperTest(t, func(writer http.ResponseWriter, request *http.Request) { + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(request.Body) + if err != nil { + t.Fatalf("failed to read body: %s", err) + } + + if !bytes.Equal(buf.Bytes(), body) { + t.Fatalf("body not the same") + } + + writer.WriteHeader(http.StatusNoContent) + }) + resp, err := tester.Client.Post("https://localhost:9443", "application/octet-stream", bytes.NewReader(body)) + if err != nil { + t.Fatalf("failed to post: %s", err) + } + + if resp.StatusCode != http.StatusNoContent { + t.Fatalf("unexpected status: %d != %d", resp.StatusCode, http.StatusNoContent) + } +} + +func TestLargeHttpRequest(t *testing.T) { + tester := setupListenerWrapperTest(t, func(writer http.ResponseWriter, request *http.Request) { + t.Fatal("not supposed to handle a request") + }) + + // We never read the body in any way, set an extra long header instead. + req, _ := http.NewRequest("POST", "http://localhost:9443", nil) + req.Header.Set("Long-Header", strings.Repeat("X", 1024*1024)) + _, err := tester.Client.Do(req) + if err == nil { + t.Fatal("not supposed to succeed") + } +} diff --git a/modules/caddyhttp/httpredirectlistener.go b/modules/caddyhttp/httpredirectlistener.go index 082dc7ce8..ce9ac0308 100644 --- a/modules/caddyhttp/httpredirectlistener.go +++ b/modules/caddyhttp/httpredirectlistener.go @@ -16,11 +16,11 @@ package caddyhttp import ( "bufio" + "bytes" "fmt" "io" "net" "net/http" - "sync" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" @@ -86,15 +86,17 @@ func (l *httpRedirectListener) Accept() (net.Conn, error) { } return &httpRedirectConn{ - Conn: c, - r: bufio.NewReader(io.LimitReader(c, maxHeaderBytes)), + Conn: c, + limit: maxHeaderBytes, + r: bufio.NewReader(c), }, nil } type httpRedirectConn struct { net.Conn - once sync.Once - r *bufio.Reader + once bool + limit int64 + r *bufio.Reader } // Read tries to peek at the first few bytes of the request, and if we get @@ -102,53 +104,58 @@ type httpRedirectConn struct { // like an HTTP request, then we perform a HTTP->HTTPS redirect on the same // port as the original connection. func (c *httpRedirectConn) Read(p []byte) (int, error) { - var errReturn error - c.once.Do(func() { - firstBytes, err := c.r.Peek(5) - if err != nil { - return - } + if c.once { + return c.r.Read(p) + } + // no need to use sync.Once - net.Conn is not read from concurrently. + c.once = true - // If the request doesn't look like HTTP, then it's probably - // TLS bytes and we don't need to do anything. - if !firstBytesLookLikeHTTP(firstBytes) { - return - } - - // Parse the HTTP request, so we can get the Host and URL to redirect to. - req, err := http.ReadRequest(c.r) - if err != nil { - return - } - - // Build the redirect response, using the same Host and URL, - // but replacing the scheme with https. - headers := make(http.Header) - headers.Add("Location", "https://"+req.Host+req.URL.String()) - resp := &http.Response{ - Proto: "HTTP/1.0", - Status: "308 Permanent Redirect", - StatusCode: 308, - ProtoMajor: 1, - ProtoMinor: 0, - Header: headers, - } - - err = resp.Write(c.Conn) - if err != nil { - errReturn = fmt.Errorf("couldn't write HTTP->HTTPS redirect") - return - } - - errReturn = fmt.Errorf("redirected HTTP request on HTTPS port") - c.Conn.Close() - }) - - if errReturn != nil { - return 0, errReturn + firstBytes, err := c.r.Peek(5) + if err != nil { + return 0, err } - return c.r.Read(p) + // If the request doesn't look like HTTP, then it's probably + // TLS bytes, and we don't need to do anything. + if !firstBytesLookLikeHTTP(firstBytes) { + return c.r.Read(p) + } + + // From now on, we can be almost certain the request is HTTP. + // The returned error will be non nil and caller are expected to + // close the connection. + + // Set the read limit, io.MultiReader is needed because + // when resetting, *bufio.Reader discards buffered data. + buffered, _ := c.r.Peek(c.r.Buffered()) + mr := io.MultiReader(bytes.NewReader(buffered), c.Conn) + c.r.Reset(io.LimitReader(mr, c.limit)) + + // Parse the HTTP request, so we can get the Host and URL to redirect to. + req, err := http.ReadRequest(c.r) + if err != nil { + return 0, fmt.Errorf("couldn't read HTTP request") + } + + // Build the redirect response, using the same Host and URL, + // but replacing the scheme with https. + headers := make(http.Header) + headers.Add("Location", "https://"+req.Host+req.URL.String()) + resp := &http.Response{ + Proto: "HTTP/1.0", + Status: "308 Permanent Redirect", + StatusCode: 308, + ProtoMajor: 1, + ProtoMinor: 0, + Header: headers, + } + + err = resp.Write(c.Conn) + if err != nil { + return 0, fmt.Errorf("couldn't write HTTP->HTTPS redirect") + } + + return 0, fmt.Errorf("redirected HTTP request on HTTPS port") } // firstBytesLookLikeHTTP reports whether a TLS record header