mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-21 12:47:32 +08:00
Merge pull request #1232 from mholt/fix-1229
proxy: record request Body for retry (fixes #1229)
This commit is contained in:
commit
12fd349916
40
caddyhttp/proxy/body.go
Normal file
40
caddyhttp/proxy/body.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
type bufferedBody struct {
|
||||
*bytes.Reader
|
||||
}
|
||||
|
||||
func (*bufferedBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewind allows bufferedBody to be read again.
|
||||
func (b *bufferedBody) rewind() error {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
_, err := b.Seek(0, io.SeekStart)
|
||||
return err
|
||||
}
|
||||
|
||||
// newBufferedBody returns *bufferedBody to use in place of src. Closes src
|
||||
// and returns Read error on src. All content from src is buffered.
|
||||
func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) {
|
||||
if src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := ioutil.ReadAll(src)
|
||||
src.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &bufferedBody{
|
||||
Reader: bytes.NewReader(b),
|
||||
}, nil
|
||||
}
|
69
caddyhttp/proxy/body_test.go
Normal file
69
caddyhttp/proxy/body_test.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBodyRetry(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.Copy(w, r.Body)
|
||||
r.Body.Close()
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
testcase := "test content"
|
||||
req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
body, err := newBufferedBody(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body != nil {
|
||||
req.Body = body
|
||||
}
|
||||
|
||||
// simulate fail request
|
||||
host := req.URL.Host
|
||||
req.URL.Host = "example.com"
|
||||
body.rewind()
|
||||
_, _ = http.DefaultTransport.RoundTrip(req)
|
||||
|
||||
// retry request
|
||||
req.URL.Host = host
|
||||
body.rewind()
|
||||
resp, err := http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if string(result) != testcase {
|
||||
t.Fatalf("result = %s, want %s", result, testcase)
|
||||
}
|
||||
|
||||
// try one more time for body reuse
|
||||
body.rewind()
|
||||
resp, err = http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
result, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if string(result) != testcase {
|
||||
t.Fatalf("result = %s, want %s", result, testcase)
|
||||
}
|
||||
}
|
|
@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
// outreq is the request that makes a roundtrip to the backend
|
||||
outreq := createUpstreamRequest(r)
|
||||
|
||||
// record and replace outreq body
|
||||
body, err := newBufferedBody(outreq.Body)
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, errors.New("failed to read downstream request body")
|
||||
}
|
||||
if body != nil {
|
||||
outreq.Body = body
|
||||
}
|
||||
|
||||
// The keepRetrying function will return true if we should
|
||||
// loop and try to select another host, or false if we
|
||||
// should break and stop retrying.
|
||||
|
@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
|||
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
||||
}
|
||||
|
||||
// rewind request body to its beginning
|
||||
if err := body.rewind(); err != nil {
|
||||
return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
|
||||
}
|
||||
|
||||
// tell the proxy to serve the request
|
||||
atomic.AddInt64(&host.Conns, 1)
|
||||
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
|
@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyRetry(t *testing.T) {
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
// set up proxy
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
io.Copy(w, r.Body)
|
||||
r.Body.Close()
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
|
||||
proxy / localhost:65535 localhost:65534 `+backend.URL+` {
|
||||
policy round_robin
|
||||
fail_timeout 5s
|
||||
max_fails 1
|
||||
try_duration 5s
|
||||
try_interval 250ms
|
||||
}
|
||||
`)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: su,
|
||||
}
|
||||
|
||||
// middle is required to simulate closable downstream request body
|
||||
middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = p.ServeHTTP(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}))
|
||||
defer middle.Close()
|
||||
|
||||
testcase := "test content"
|
||||
r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := http.DefaultTransport.RoundTrip(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != testcase {
|
||||
t.Fatalf("string(b) = %s, want %s", string(b), testcase)
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
|
|
Loading…
Reference in New Issue
Block a user