Only write error message/page if body not already written (fixes #567)

Based on work started in, and replaces, #614
This commit is contained in:
Matthew Holt 2016-02-24 19:50:46 -07:00
parent 737c7c4372
commit c37ad7f677
5 changed files with 22 additions and 34 deletions

View File

@ -43,7 +43,9 @@ func (h ErrorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, er
}
if status >= 400 {
h.errorPage(w, r, status)
if w.Header().Get("Content-Length") == "" {
h.errorPage(w, r, status)
}
return 0, err
}

View File

@ -9,6 +9,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
@ -78,6 +79,13 @@ func TestErrors(t *testing.T) {
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusNotFound, nil, "normal"),
expectedCode: 0,
expectedBody: "normal",
expectedLog: "",
expectedErr: nil,
},
{
next: genErrorHandler(http.StatusForbidden, nil, ""),
expectedCode: 0,
@ -158,6 +166,9 @@ func TestVisibleErrorWithPanic(t *testing.T) {
func genErrorHandler(status int, err error, body string) middleware.Handler {
return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
if len(body) > 0 {
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
}
fmt.Fprint(w, body)
return status, err
})

View File

@ -107,7 +107,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
}
var responseBody io.Reader = resp.Body
if r.Header.Get("Content-Length") == "" {
if resp.Header.Get("Content-Length") == "" {
// If the upstream app didn't set a Content-Length (shame on them),
// we need to do it to prevent error messages being appended to
// an already-written response, and other problematic behavior.
@ -137,6 +137,11 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
}
// Normally we should only return a status >= 400 if no response
// body is written yet, however, upstream apps don't know about
// this contract and we still want the correct code logged, so error
// handling code in our stack needs to check Content-Length before
// writing an error message... oh well.
return resp.StatusCode, err
}
}

View File

@ -26,7 +26,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// The error must be handled here so the log entry will record the response size.
if l.ErrorFunc != nil {
l.ErrorFunc(responseRecorder, r, status)
} else {
} else if responseRecorder.Header().Get("Content-Length") == "" { // ensure no body written since proxy backends may write an error page
// Default failover error handler
responseRecorder.WriteHeader(status)
fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))

View File

@ -319,7 +319,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, _ := vh.stack.ServeHTTP(w, r)
// Fallback error response in case error handling wasn't chained in
if status >= 400 {
if status >= 400 && w.Header().Get("Content-Length") == "" {
DefaultErrorFunc(w, r, status)
}
} else {
@ -417,36 +417,6 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File()
}
// copied from net/http/transport.go
/*
TODO - remove - not necessary?
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}*/
// ShutdownCallbacks executes all the shutdown callbacks
// for all the virtualhosts in servers, and returns all the
// errors generated during their execution. In other words,