package caddyhttp import ( "bytes" "io" "net/http" "strings" "testing" ) type responseWriterSpy interface { http.ResponseWriter Written() string CalledReadFrom() bool } var ( _ responseWriterSpy = (*baseRespWriter)(nil) _ responseWriterSpy = (*readFromRespWriter)(nil) ) // a barebones http.ResponseWriter mock type baseRespWriter []byte func (brw *baseRespWriter) Write(d []byte) (int, error) { *brw = append(*brw, d...) return len(d), nil } func (brw *baseRespWriter) Header() http.Header { return nil } func (brw *baseRespWriter) WriteHeader(statusCode int) {} func (brw *baseRespWriter) Written() string { return string(*brw) } func (brw *baseRespWriter) CalledReadFrom() bool { return false } // an http.ResponseWriter mock that supports ReadFrom type readFromRespWriter struct { baseRespWriter called bool } func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) { rf.called = true return io.Copy(&rf.baseRespWriter, r) } func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called } func TestResponseWriterWrapperReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy wantReadFrom bool }{ "no ReadFrom": { responseWriter: &baseRespWriter{}, wantReadFrom: false, }, "has ReadFrom": { responseWriter: &readFromRespWriter{}, wantReadFrom: true, }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { // what we expect middlewares to do: type myWrapper struct { *ResponseWriterWrapper } wrapped := myWrapper{ ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter}, } const srcData = "boo!" // hides everything but Read, since strings.Reader implements WriteTo it would // take precedence over our ReadFrom. src := struct{ io.Reader }{strings.NewReader(srcData)} if _, err := io.Copy(wrapped, src); err != nil { t.Errorf("%s: Copy() err = %v", name, err) } if got := tt.responseWriter.Written(); got != srcData { t.Errorf("%s: data = %q, want %q", name, got, srcData) } if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { if tt.wantReadFrom { t.Errorf("%s: ReadFrom() should have been called", name) } else { t.Errorf("%s: ReadFrom() should not have been called", name) } } }) } } func TestResponseWriterWrapperUnwrap(t *testing.T) { w := &ResponseWriterWrapper{&baseRespWriter{}} if _, ok := w.Unwrap().(*baseRespWriter); !ok { t.Errorf("Unwrap() doesn't return the underlying ResponseWriter") } } func TestResponseRecorderReadFrom(t *testing.T) { tests := map[string]struct { responseWriter responseWriterSpy shouldBuffer bool wantReadFrom bool }{ "buffered plain": { responseWriter: &baseRespWriter{}, shouldBuffer: true, wantReadFrom: false, }, "streamed plain": { responseWriter: &baseRespWriter{}, shouldBuffer: false, wantReadFrom: false, }, "buffered ReadFrom": { responseWriter: &readFromRespWriter{}, shouldBuffer: true, wantReadFrom: false, }, "streamed ReadFrom": { responseWriter: &readFromRespWriter{}, shouldBuffer: false, wantReadFrom: true, }, } for name, tt := range tests { t.Run(name, func(t *testing.T) { var buf bytes.Buffer rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool { return tt.shouldBuffer }) const srcData = "boo!" // hides everything but Read, since strings.Reader implements WriteTo it would // take precedence over our ReadFrom. src := struct{ io.Reader }{strings.NewReader(srcData)} if _, err := io.Copy(rr, src); err != nil { t.Errorf("Copy() err = %v", err) } wantStreamed := srcData wantBuffered := "" if tt.shouldBuffer { wantStreamed = "" wantBuffered = srcData } if got := tt.responseWriter.Written(); got != wantStreamed { t.Errorf("streamed data = %q, want %q", got, wantStreamed) } if got := buf.String(); got != wantBuffered { t.Errorf("buffered data = %q, want %q", got, wantBuffered) } if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom { if tt.wantReadFrom { t.Errorf("ReadFrom() should have been called") } else { t.Errorf("ReadFrom() should not have been called") } } }) } }