caddy/modules/caddyhttp/responsewriter_test.go
fleandro dd9813c65b
caddyhttp: ensure ResponseWriterWrapper and ResponseRecorder use ReadFrom if the underlying response writer implements it. (#5022)
Doing so allows for splice/sendfile optimizations when available.
Fixes #4731

Co-authored-by: flga <flga@users.noreply.github.com>
Co-authored-by: Matthew Holt <mholt@users.noreply.github.com>
2022-09-07 21:13:35 +01:00

166 lines
4.0 KiB
Go

package caddyhttp
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"testing"
)
// responseWriterSpy is an http.ResponseWriter test double that can also
// report what was written to it and whether its ReadFrom path was used.
type responseWriterSpy interface {
	http.ResponseWriter

	// CalledReadFrom reports whether data arrived via a ReadFrom call.
	CalledReadFrom() bool
	// Written returns everything written so far, as a string.
	Written() string
}
// Compile-time checks that both mock writers satisfy responseWriterSpy.
var _ responseWriterSpy = (*baseRespWriter)(nil)
var _ responseWriterSpy = (*readFromRespWriter)(nil)
// baseRespWriter is a barebones http.ResponseWriter mock. It records every
// Write into its own backing slice, returns a nil header map, ignores status
// codes, and deliberately has no ReadFrom method.
type baseRespWriter []byte

// Write appends d to the recorded output and always succeeds.
func (w *baseRespWriter) Write(d []byte) (int, error) {
	n := len(d)
	*w = append(*w, d...)
	return n, nil
}

// Header returns a nil map; callers must not try to set headers on it.
func (w *baseRespWriter) Header() http.Header { return nil }

// WriteHeader discards the status code.
func (w *baseRespWriter) WriteHeader(int) {}

// Written returns the recorded output as a string.
func (w *baseRespWriter) Written() string { return string(*w) }

// CalledReadFrom is always false because this type has no ReadFrom method.
func (w *baseRespWriter) CalledReadFrom() bool { return false }
// readFromRespWriter is an http.ResponseWriter mock that also implements
// io.ReaderFrom and remembers whether that method was invoked.
type readFromRespWriter struct {
	baseRespWriter
	called bool
}

// ReadFrom marks the call and then copies r into the embedded baseRespWriter.
func (rw *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
	rw.called = true
	// Copy into the embedded writer rather than rw itself. *baseRespWriter has
	// no ReadFrom, so io.Copy cannot re-enter this method.
	n, err := io.Copy(&rw.baseRespWriter, r)
	return n, err
}

// CalledReadFrom reports whether ReadFrom has been invoked.
func (rw *readFromRespWriter) CalledReadFrom() bool { return rw.called }
// TestResponseWriterWrapperReadFrom copies data through a middleware-style
// wrapper that embeds *ResponseWriterWrapper. It checks that all data reaches
// the underlying writer, and that the underlying ReadFrom is used if and only
// if that writer provides one.
func TestResponseWriterWrapperReadFrom(t *testing.T) {
	cases := map[string]struct {
		responseWriter responseWriterSpy
		wantReadFrom   bool
	}{
		"no ReadFrom":  {responseWriter: &baseRespWriter{}, wantReadFrom: false},
		"has ReadFrom": {responseWriter: &readFromRespWriter{}, wantReadFrom: true},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			// Middlewares are expected to embed *ResponseWriterWrapper like
			// this, so the wrapper's methods are promoted to the outer type.
			type middlewareWriter struct {
				*ResponseWriterWrapper
			}
			dst := middlewareWriter{
				ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tc.responseWriter},
			}

			const payload = "boo!"
			// Expose only Read. strings.Reader also implements io.WriterTo,
			// and io.Copy would prefer that over the destination's ReadFrom.
			src := struct{ io.Reader }{strings.NewReader(payload)}

			// NOTE(review): this prints every subtest name to stdout on each
			// run and looks like leftover debugging. t.Run already labels the
			// subtest, so it could be removed together with the "fmt" import.
			fmt.Println(name)

			if _, err := io.Copy(dst, src); err != nil {
				t.Errorf("Copy() err = %v", err)
			}

			if got := tc.responseWriter.Written(); got != payload {
				t.Errorf("data = %q, want %q", got, payload)
			}

			switch called := tc.responseWriter.CalledReadFrom(); {
			case called == tc.wantReadFrom:
				// Expectation met; nothing to report.
			case tc.wantReadFrom:
				t.Errorf("ReadFrom() should have been called")
			default:
				t.Errorf("ReadFrom() should not have been called")
			}
		})
	}
}
// TestResponseRecorderReadFrom copies data into a ResponseRecorder in both
// buffered and streamed modes. It checks that the payload lands in exactly one
// sink (the caller's buffer or the underlying writer), and that the underlying
// ReadFrom is used only when streaming to a writer that provides one.
func TestResponseRecorderReadFrom(t *testing.T) {
	cases := map[string]struct {
		responseWriter responseWriterSpy
		shouldBuffer   bool
		wantReadFrom   bool
	}{
		"buffered plain":    {responseWriter: &baseRespWriter{}, shouldBuffer: true, wantReadFrom: false},
		"streamed plain":    {responseWriter: &baseRespWriter{}, shouldBuffer: false, wantReadFrom: false},
		"buffered ReadFrom": {responseWriter: &readFromRespWriter{}, shouldBuffer: true, wantReadFrom: false},
		"streamed ReadFrom": {responseWriter: &readFromRespWriter{}, shouldBuffer: false, wantReadFrom: true},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			var sink bytes.Buffer
			rec := NewResponseRecorder(tc.responseWriter, &sink, func(int, http.Header) bool {
				return tc.shouldBuffer
			})

			const payload = "boo!"
			// Expose only Read. strings.Reader also implements io.WriterTo,
			// and io.Copy would prefer that over the destination's ReadFrom.
			src := struct{ io.Reader }{strings.NewReader(payload)}

			if _, err := io.Copy(rec, src); err != nil {
				t.Errorf("Copy() err = %v", err)
			}

			// The payload should go to exactly one place: the buffer when
			// buffering, otherwise the underlying response writer.
			var wantStreamed, wantBuffered string
			if tc.shouldBuffer {
				wantBuffered = payload
			} else {
				wantStreamed = payload
			}

			if got := tc.responseWriter.Written(); got != wantStreamed {
				t.Errorf("streamed data = %q, want %q", got, wantStreamed)
			}
			if got := sink.String(); got != wantBuffered {
				t.Errorf("buffered data = %q, want %q", got, wantBuffered)
			}

			switch called := tc.responseWriter.CalledReadFrom(); {
			case called == tc.wantReadFrom:
				// Expectation met; nothing to report.
			case tc.wantReadFrom:
				t.Errorf("ReadFrom() should have been called")
			default:
				t.Errorf("ReadFrom() should not have been called")
			}
		})
	}
}