reverse_proxy: WIP refactor and support for FastCGI

This commit is contained in:
Matthew Holt 2019-09-02 22:01:02 -06:00
parent 2dc4fcc62b
commit 026df7c5cb
No known key found for this signature in database
GPG Key ID: 2A349DD577D586A5
17 changed files with 2752 additions and 899 deletions

View File

@ -29,6 +29,7 @@ import (
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/markdown"
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/requestbody"
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy/fastcgi"
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/rewrite"
_ "github.com/caddyserver/caddy/v2/modules/caddyhttp/templates"
_ "github.com/caddyserver/caddy/v2/modules/caddytls"

5
go.mod
View File

@ -8,7 +8,6 @@ require (
github.com/Masterminds/semver v1.4.2 // indirect
github.com/Masterminds/sprig v2.20.0+incompatible
github.com/andybalholm/brotli v0.0.0-20190704151324-71eb68cc467c
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.0
github.com/go-acme/lego v2.6.0+incompatible
github.com/google/go-cmp v0.3.1 // indirect
@ -17,7 +16,6 @@ require (
github.com/imdario/mergo v0.3.7 // indirect
github.com/klauspost/compress v1.7.1-0.20190613161414-0b31f265a57b
github.com/klauspost/cpuid v1.2.1
github.com/kr/pretty v0.1.0 // indirect
github.com/mholt/certmagic v0.6.2
github.com/mitchellh/go-ps v0.0.0-20170309133038-4fdf99ab2936
github.com/rs/cors v1.6.0
@ -26,8 +24,7 @@ require (
github.com/starlight-go/starlight v0.0.0-20181207205707-b06f321544f3
go.starlark.net v0.0.0-20190604130855-6ddc71c0ba77
golang.org/x/net v0.0.0-20190603091049-60506f45cf65
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e // indirect
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect
)

13
go.sum
View File

@ -12,8 +12,6 @@ github.com/cenkalti/backoff v2.1.1+incompatible h1:tKJnvO2kl0zmb/jA5UKAt4VoEVw1q
github.com/cenkalti/backoff v2.1.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/go-acme/lego v2.5.0+incompatible/go.mod h1:yzMNe9CasVUhkquNvti5nAtPmG94USbYxYrZfTkIn0M=
@ -34,11 +32,6 @@ github.com/klauspost/compress v1.7.1-0.20190613161414-0b31f265a57b/go.mod h1:RyI
github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/klauspost/cpuid v1.2.1 h1:vJi+O/nMdFt0vqm8NZBI6wzALWdA2X+egi0ogNyrC/w=
github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mholt/certmagic v0.6.2 h1:yy9cKm3rtxdh12SW4E51lzG3Eo6N59LEOfBQ0CTnMms=
github.com/mholt/certmagic v0.6.2/go.mod h1:g4cOPxcjV0oFq3qwpjSA30LReKD8AoIfwAY9VvG35NY=
github.com/miekg/dns v1.1.3 h1:1g0r1IvskvgL8rR+AcHzUA+oFmGcQlaIm4IqakufeMM=
@ -71,16 +64,14 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sys v0.0.0-20190124100055-b90733256f2e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e h1:ZytStCyV048ZqDsWHiYDdoI2Vd4msMcrDECFxS+tL9c=
golang.org/x/sys v0.0.0-20190228124157-a34e9553db1e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=

View File

@ -24,6 +24,8 @@ import (
"time"
)
// TODO: Can we use the new UsagePool type?
// Listen returns a listener suitable for use in a Caddy module.
// Always be sure to close listeners when you are done with them.
func Listen(network, addr string) (net.Listener, error) {

View File

@ -518,6 +518,20 @@ func (ws WeakString) String() string {
return string(ws)
}
// StatusCodeMatches returns true if a real HTTP status code matches
// the configured status code, which may be either a real HTTP status
// code or an integer representing a class of codes (e.g. 4 for all
// 4xx statuses).
func StatusCodeMatches(actual, configured int) bool {
if actual == configured {
return true
}
if configured < 100 && actual >= configured*100 && actual < (configured+1)*100 {
return true
}
return false
}
const (
// DefaultHTTPPort is the default port for HTTP.
DefaultHTTPPort = 80

View File

@ -616,10 +616,7 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
return true
}
for _, code := range rm.StatusCode {
if statusCode == code {
return true
}
if code < 100 && statusCode >= code*100 && statusCode < (code+1)*100 {
if StatusCodeMatches(statusCode, code) {
return true
}
}

View File

@ -0,0 +1,578 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Forked Jan. 2015 from http://bitbucket.org/PinIdea/fcgi_client
// (which is forked from https://code.google.com/p/go-fastcgi-client/).
// This fork contains several fixes and improvements by Matt Holt and
// other contributors to the Caddy project.
// Copyright 2012 Junqing Tan <ivan@mysqlab.net> and The Go Authors
// Use of this source code is governed by a BSD-style
// Part of source code is from Go fcgi package
package fastcgi
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"errors"
"io"
"io/ioutil"
"mime/multipart"
"net"
"net/http"
"net/http/httputil"
"net/textproto"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// FCGIListenSockFileno describes listen socket file number.
const FCGIListenSockFileno uint8 = 0
// FCGIHeaderLen describes header length.
const FCGIHeaderLen uint8 = 8
// Version1 describes the version.
const Version1 uint8 = 1
// FCGINullRequestID describes the null request ID.
const FCGINullRequestID uint8 = 0
// FCGIKeepConn describes keep connection mode.
const FCGIKeepConn uint8 = 1
const (
// BeginRequest is the begin request flag.
BeginRequest uint8 = iota + 1
// AbortRequest is the abort request flag.
AbortRequest
// EndRequest is the end request flag.
EndRequest
// Params is the parameters flag.
Params
// Stdin is the standard input flag.
Stdin
// Stdout is the standard output flag.
Stdout
// Stderr is the standard error flag.
Stderr
// Data is the data flag.
Data
// GetValues is the get values flag.
GetValues
// GetValuesResult is the get values result flag.
GetValuesResult
// UnknownType is the unknown type flag.
UnknownType
// MaxType is the maximum type flag.
MaxType = UnknownType
)
const (
// Responder is the responder flag.
Responder uint8 = iota + 1
// Authorizer is the authorizer flag.
Authorizer
// Filter is the filter flag.
Filter
)
const (
// RequestComplete is the completed request flag.
RequestComplete uint8 = iota
// CantMultiplexConns is the multiplexed connections flag.
CantMultiplexConns
// Overloaded is the overloaded flag.
Overloaded
// UnknownRole is the unknown role flag.
UnknownRole
)
const (
// MaxConns is the maximum connections flag.
MaxConns string = "MAX_CONNS"
// MaxRequests is the maximum requests flag.
MaxRequests string = "MAX_REQS"
// MultiplexConns is the multiplex connections flag.
MultiplexConns string = "MPXS_CONNS"
)
const (
maxWrite = 65500 // 65530 may work, but for compatibility
maxPad = 255
)
type header struct {
Version uint8
Type uint8
ID uint16
ContentLength uint16
PaddingLength uint8
Reserved uint8
}
// for padding so we don't have to allocate all the time
// not synchronized because we don't care what the contents are
var pad [maxPad]byte
func (h *header) init(recType uint8, reqID uint16, contentLength int) {
h.Version = 1
h.Type = recType
h.ID = reqID
h.ContentLength = uint16(contentLength)
h.PaddingLength = uint8(-contentLength & 7)
}
type record struct {
h header
rbuf []byte
}
func (rec *record) read(r io.Reader) (buf []byte, err error) {
if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
return
}
if rec.h.Version != 1 {
err = errors.New("fcgi: invalid header version")
return
}
if rec.h.Type == EndRequest {
err = io.EOF
return
}
n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
if len(rec.rbuf) < n {
rec.rbuf = make([]byte, n)
}
if _, err = io.ReadFull(r, rec.rbuf[:n]); err != nil {
return
}
buf = rec.rbuf[:int(rec.h.ContentLength)]
return
}
// FCGIClient implements a FastCGI client, which is a standard for
// interfacing external applications with Web servers.
type FCGIClient struct {
mutex sync.Mutex
rwc io.ReadWriteCloser
h header
buf bytes.Buffer
stderr bytes.Buffer
keepAlive bool
reqID uint16
}
// DialWithDialerContext connects to the fcgi responder at the specified network address, using custom net.Dialer
// and a context.
// See func net.Dial for a description of the network and address parameters.
func DialWithDialerContext(ctx context.Context, network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
var conn net.Conn
conn, err = dialer.DialContext(ctx, network, address)
if err != nil {
return
}
fcgi = &FCGIClient{
rwc: conn,
keepAlive: false,
reqID: 1,
}
return
}
// DialContext is like Dial but passes ctx to dialer.Dial.
func DialContext(ctx context.Context, network, address string) (fcgi *FCGIClient, err error) {
// TODO: why not set timeout here?
return DialWithDialerContext(ctx, network, address, net.Dialer{})
}
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func Dial(network, address string) (fcgi *FCGIClient, err error) {
return DialContext(context.Background(), network, address)
}
// Close closes fcgi connection
func (c *FCGIClient) Close() {
c.rwc.Close()
}
func (c *FCGIClient) writeRecord(recType uint8, content []byte) (err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.buf.Reset()
c.h.init(recType, c.reqID, len(content))
if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
return err
}
if _, err := c.buf.Write(content); err != nil {
return err
}
if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
return err
}
_, err = c.rwc.Write(c.buf.Bytes())
return err
}
func (c *FCGIClient) writeBeginRequest(role uint16, flags uint8) error {
b := [8]byte{byte(role >> 8), byte(role), flags}
return c.writeRecord(BeginRequest, b[:])
}
func (c *FCGIClient) writeEndRequest(appStatus int, protocolStatus uint8) error {
b := make([]byte, 8)
binary.BigEndian.PutUint32(b, uint32(appStatus))
b[4] = protocolStatus
return c.writeRecord(EndRequest, b)
}
func (c *FCGIClient) writePairs(recType uint8, pairs map[string]string) error {
w := newWriter(c, recType)
b := make([]byte, 8)
nn := 0
for k, v := range pairs {
m := 8 + len(k) + len(v)
if m > maxWrite {
// param data size exceed 65535 bytes"
vl := maxWrite - 8 - len(k)
v = v[:vl]
}
n := encodeSize(b, uint32(len(k)))
n += encodeSize(b[n:], uint32(len(v)))
m = n + len(k) + len(v)
if (nn + m) > maxWrite {
w.Flush()
nn = 0
}
nn += m
if _, err := w.Write(b[:n]); err != nil {
return err
}
if _, err := w.WriteString(k); err != nil {
return err
}
if _, err := w.WriteString(v); err != nil {
return err
}
}
w.Close()
return nil
}
func encodeSize(b []byte, size uint32) int {
if size > 127 {
size |= 1 << 31
binary.BigEndian.PutUint32(b, size)
return 4
}
b[0] = byte(size)
return 1
}
// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
// Closed.
type bufWriter struct {
closer io.Closer
*bufio.Writer
}
func (w *bufWriter) Close() error {
if err := w.Writer.Flush(); err != nil {
w.closer.Close()
return err
}
return w.closer.Close()
}
func newWriter(c *FCGIClient, recType uint8) *bufWriter {
s := &streamWriter{c: c, recType: recType}
w := bufio.NewWriterSize(s, maxWrite)
return &bufWriter{s, w}
}
// streamWriter abstracts out the separation of a stream into discrete records.
// It only writes maxWrite bytes at a time.
type streamWriter struct {
c *FCGIClient
recType uint8
}
func (w *streamWriter) Write(p []byte) (int, error) {
nn := 0
for len(p) > 0 {
n := len(p)
if n > maxWrite {
n = maxWrite
}
if err := w.c.writeRecord(w.recType, p[:n]); err != nil {
return nn, err
}
nn += n
p = p[n:]
}
return nn, nil
}
func (w *streamWriter) Close() error {
// send empty record to close the stream
return w.c.writeRecord(w.recType, nil)
}
type streamReader struct {
c *FCGIClient
buf []byte
}
func (w *streamReader) Read(p []byte) (n int, err error) {
if len(p) > 0 {
if len(w.buf) == 0 {
// filter outputs for error log
for {
rec := &record{}
var buf []byte
buf, err = rec.read(w.c.rwc)
if err != nil {
return
}
// standard error output
if rec.h.Type == Stderr {
w.c.stderr.Write(buf)
continue
}
w.buf = buf
break
}
}
n = len(p)
if n > len(w.buf) {
n = len(w.buf)
}
copy(p, w.buf[:n])
w.buf = w.buf[n:]
}
return
}
// Do made the request and returns a io.Reader that translates the data read
// from fcgi responder out of fcgi packet before returning it.
func (c *FCGIClient) Do(p map[string]string, req io.Reader) (r io.Reader, err error) {
err = c.writeBeginRequest(uint16(Responder), 0)
if err != nil {
return
}
err = c.writePairs(Params, p)
if err != nil {
return
}
body := newWriter(c, Stdin)
if req != nil {
_, _ = io.Copy(body, req)
}
body.Close()
r = &streamReader{c: c}
return
}
// clientCloser is a io.ReadCloser. It wraps a io.Reader with a Closer
// that closes FCGIClient connection.
type clientCloser struct {
*FCGIClient
io.Reader
}
func (f clientCloser) Close() error { return f.rwc.Close() }
// Request returns a HTTP Response with Header and Body
// from fcgi responder
func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) {
r, err := c.Do(p, req)
if err != nil {
return
}
rb := bufio.NewReader(r)
tp := textproto.NewReader(rb)
resp = new(http.Response)
// Parse the response headers.
mimeHeader, err := tp.ReadMIMEHeader()
if err != nil && err != io.EOF {
return
}
resp.Header = http.Header(mimeHeader)
if resp.Header.Get("Status") != "" {
statusParts := strings.SplitN(resp.Header.Get("Status"), " ", 2)
resp.StatusCode, err = strconv.Atoi(statusParts[0])
if err != nil {
return
}
if len(statusParts) > 1 {
resp.Status = statusParts[1]
}
} else {
resp.StatusCode = http.StatusOK
}
// TODO: fixTransferEncoding ?
resp.TransferEncoding = resp.Header["Transfer-Encoding"]
resp.ContentLength, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if chunked(resp.TransferEncoding) {
resp.Body = clientCloser{c, httputil.NewChunkedReader(rb)}
} else {
resp.Body = clientCloser{c, ioutil.NopCloser(rb)}
}
return
}
// Get issues a GET request to the fcgi responder.
func (c *FCGIClient) Get(p map[string]string, body io.Reader, l int64) (resp *http.Response, err error) {
p["REQUEST_METHOD"] = "GET"
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
return c.Request(p, body)
}
// Head issues a HEAD request to the fcgi responder.
func (c *FCGIClient) Head(p map[string]string) (resp *http.Response, err error) {
p["REQUEST_METHOD"] = "HEAD"
p["CONTENT_LENGTH"] = "0"
return c.Request(p, nil)
}
// Options issues an OPTIONS request to the fcgi responder.
func (c *FCGIClient) Options(p map[string]string) (resp *http.Response, err error) {
p["REQUEST_METHOD"] = "OPTIONS"
p["CONTENT_LENGTH"] = "0"
return c.Request(p, nil)
}
// Post issues a POST request to the fcgi responder. with request body
// in the format that bodyType specified
func (c *FCGIClient) Post(p map[string]string, method string, bodyType string, body io.Reader, l int64) (resp *http.Response, err error) {
if p == nil {
p = make(map[string]string)
}
p["REQUEST_METHOD"] = strings.ToUpper(method)
if len(p["REQUEST_METHOD"]) == 0 || p["REQUEST_METHOD"] == "GET" {
p["REQUEST_METHOD"] = "POST"
}
p["CONTENT_LENGTH"] = strconv.FormatInt(l, 10)
if len(bodyType) > 0 {
p["CONTENT_TYPE"] = bodyType
} else {
p["CONTENT_TYPE"] = "application/x-www-form-urlencoded"
}
return c.Request(p, body)
}
// PostForm issues a POST to the fcgi responder, with form
// as a string key to a list values (url.Values)
func (c *FCGIClient) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) {
body := bytes.NewReader([]byte(data.Encode()))
return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len()))
}
// PostFile issues a POST to the fcgi responder in multipart(RFC 2046) standard,
// with form as a string key to a list values (url.Values),
// and/or with file as a string key to a list file path.
func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[string]string) (resp *http.Response, err error) {
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
bodyType := writer.FormDataContentType()
for key, val := range data {
for _, v0 := range val {
err = writer.WriteField(key, v0)
if err != nil {
return
}
}
}
for key, val := range file {
fd, e := os.Open(val)
if e != nil {
return nil, e
}
defer fd.Close()
part, e := writer.CreateFormFile(key, filepath.Base(val))
if e != nil {
return nil, e
}
_, err = io.Copy(part, fd)
if err != nil {
return
}
}
err = writer.Close()
if err != nil {
return
}
return c.Post(p, "POST", bodyType, buf, int64(buf.Len()))
}
// SetReadTimeout sets the read timeout for future calls that read from the
// fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
return conn.SetReadDeadline(time.Now().Add(t))
}
return nil
}
// SetWriteTimeout sets the write timeout for future calls that send data to
// the fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetWriteTimeout(t time.Duration) error {
if conn, ok := c.rwc.(net.Conn); ok && t != 0 {
return conn.SetWriteDeadline(time.Now().Add(t))
}
return nil
}
// Checks whether chunked is part of the encodings stack
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }

View File

@ -0,0 +1,301 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE: These tests were adapted from the original
// repository from which this package was forked.
// The tests are slow (~10s) and in dire need of rewriting.
// As such, the tests have been disabled to speed up
// automated builds until they can be properly written.
package fastcgi
import (
"bytes"
"crypto/md5"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/http/fcgi"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
)
// test fcgi protocol includes:
// Get, Post, Post in multipart/form-data, and Post with files
// each key should be the md5 of the value or the file uploaded
// specify remote fcgi responder ip:port to test with php
// test failed if the remote fcgi(script) failed md5 verification
// and output "FAILED" in response
const (
scriptFile = "/tank/www/fcgic_test.php"
//ipPort = "remote-php-serv:59000"
ipPort = "127.0.0.1:59000"
)
var globalt *testing.T
type FastCGIServer struct{}
func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
if err := req.ParseMultipartForm(100000000); err != nil {
log.Printf("[ERROR] failed to parse: %v", err)
}
stat := "PASSED"
fmt.Fprintln(resp, "-")
fileNum := 0
{
length := 0
for k0, v0 := range req.Form {
h := md5.New()
_, _ = io.WriteString(h, v0[0])
_md5 := fmt.Sprintf("%x", h.Sum(nil))
length += len(k0)
length += len(v0[0])
// echo error when key != _md5(val)
if _md5 != k0 {
fmt.Fprintln(resp, "server:err ", _md5, k0)
stat = "FAILED"
}
}
if req.MultipartForm != nil {
fileNum = len(req.MultipartForm.File)
for kn, fns := range req.MultipartForm.File {
//fmt.Fprintln(resp, "server:filekey ", kn )
length += len(kn)
for _, f := range fns {
fd, err := f.Open()
if err != nil {
log.Println("server:", err)
return
}
h := md5.New()
l0, err := io.Copy(h, fd)
if err != nil {
log.Println(err)
return
}
length += int(l0)
defer fd.Close()
md5 := fmt.Sprintf("%x", h.Sum(nil))
//fmt.Fprintln(resp, "server:filemd5 ", md5 )
if kn != md5 {
fmt.Fprintln(resp, "server:err ", md5, kn)
stat = "FAILED"
}
//fmt.Fprintln(resp, "server:filename ", f.Filename )
}
}
}
fmt.Fprintln(resp, "server:got data length", length)
}
fmt.Fprintln(resp, "-"+stat+"-POST(", len(req.Form), ")-FILE(", fileNum, ")--")
}
func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
fcgi, err := Dial("tcp", ipPort)
if err != nil {
log.Println("err:", err)
return
}
length := 0
var resp *http.Response
switch reqType {
case 0:
if len(data) > 0 {
length = len(data)
rd := bytes.NewReader(data)
resp, err = fcgi.Post(fcgiParams, "", "", rd, int64(rd.Len()))
} else if len(posts) > 0 {
values := url.Values{}
for k, v := range posts {
values.Set(k, v)
length += len(k) + 2 + len(v)
}
resp, err = fcgi.PostForm(fcgiParams, values)
} else {
rd := bytes.NewReader(data)
resp, err = fcgi.Get(fcgiParams, rd, int64(rd.Len()))
}
default:
values := url.Values{}
for k, v := range posts {
values.Set(k, v)
length += len(k) + 2 + len(v)
}
for k, v := range files {
fi, _ := os.Lstat(v)
length += len(k) + int(fi.Size())
}
resp, err = fcgi.PostFile(fcgiParams, values, files)
}
if err != nil {
log.Println("err:", err)
return
}
defer resp.Body.Close()
content, _ = ioutil.ReadAll(resp.Body)
log.Println("c: send data length ≈", length, string(content))
fcgi.Close()
time.Sleep(1 * time.Second)
if bytes.Contains(content, []byte("FAILED")) {
globalt.Error("Server return failed message")
}
return
}
func generateRandFile(size int) (p string, m string) {
p = filepath.Join(os.TempDir(), "fcgict"+strconv.Itoa(rand.Int()))
// open output file
fo, err := os.Create(p)
if err != nil {
panic(err)
}
// close fo on exit and check for its returned error
defer func() {
if err := fo.Close(); err != nil {
panic(err)
}
}()
h := md5.New()
for i := 0; i < size/16; i++ {
buf := make([]byte, 16)
binary.PutVarint(buf, rand.Int63())
if _, err := fo.Write(buf); err != nil {
log.Printf("[ERROR] failed to write buffer: %v\n", err)
}
if _, err := h.Write(buf); err != nil {
log.Printf("[ERROR] failed to write buffer: %v\n", err)
}
}
m = fmt.Sprintf("%x", h.Sum(nil))
return
}
func DisabledTest(t *testing.T) {
// TODO: test chunked reader
globalt = t
rand.Seed(time.Now().UTC().UnixNano())
// server
go func() {
listener, err := net.Listen("tcp", ipPort)
if err != nil {
log.Println("listener creation failed: ", err)
}
srv := new(FastCGIServer)
if err := fcgi.Serve(listener, srv); err != nil {
log.Print("[ERROR] failed to start server: ", err)
}
}()
time.Sleep(1 * time.Second)
// init
fcgiParams := make(map[string]string)
fcgiParams["REQUEST_METHOD"] = "GET"
fcgiParams["SERVER_PROTOCOL"] = "HTTP/1.1"
//fcgi_params["GATEWAY_INTERFACE"] = "CGI/1.1"
fcgiParams["SCRIPT_FILENAME"] = scriptFile
// simple GET
log.Println("test:", "get")
sendFcgi(0, fcgiParams, nil, nil, nil)
// simple post data
log.Println("test:", "post")
sendFcgi(0, fcgiParams, []byte("c4ca4238a0b923820dcc509a6f75849b=1&7b8b965ad4bca0e41ab51de7b31363a1=n"), nil, nil)
log.Println("test:", "post data (more than 60KB)")
data := ""
for i := 0x00; i < 0xff; i++ {
v0 := strings.Repeat(string(i), 256)
h := md5.New()
_, _ = io.WriteString(h, v0)
k0 := fmt.Sprintf("%x", h.Sum(nil))
data += k0 + "=" + url.QueryEscape(v0) + "&"
}
sendFcgi(0, fcgiParams, []byte(data), nil, nil)
log.Println("test:", "post form (use url.Values)")
p0 := make(map[string]string, 1)
p0["c4ca4238a0b923820dcc509a6f75849b"] = "1"
p0["7b8b965ad4bca0e41ab51de7b31363a1"] = "n"
sendFcgi(1, fcgiParams, nil, p0, nil)
log.Println("test:", "post forms (256 keys, more than 1MB)")
p1 := make(map[string]string, 1)
for i := 0x00; i < 0xff; i++ {
v0 := strings.Repeat(string(i), 4096)
h := md5.New()
_, _ = io.WriteString(h, v0)
k0 := fmt.Sprintf("%x", h.Sum(nil))
p1[k0] = v0
}
sendFcgi(1, fcgiParams, nil, p1, nil)
log.Println("test:", "post file (1 file, 500KB)) ")
f0 := make(map[string]string, 1)
path0, m0 := generateRandFile(500000)
f0[m0] = path0
sendFcgi(1, fcgiParams, nil, p1, f0)
log.Println("test:", "post multiple files (2 files, 5M each) and forms (256 keys, more than 1MB data")
path1, m1 := generateRandFile(5000000)
f0[m1] = path1
sendFcgi(1, fcgiParams, nil, p1, f0)
log.Println("test:", "post only files (2 files, 5M each)")
sendFcgi(1, fcgiParams, nil, nil, f0)
log.Println("test:", "post only 1 file")
delete(f0, "m0")
sendFcgi(1, fcgiParams, nil, nil, f0)
if err := os.Remove(path0); err != nil {
log.Println("[ERROR] failed to remove path: ", err)
}
if err := os.Remove(path1); err != nil {
log.Println("[ERROR] failed to remove path: ", err)
}
}

View File

@ -0,0 +1,342 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fastcgi
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/caddyserver/caddy/v2/modules/caddytls"
"github.com/caddyserver/caddy/v2"
)
func init() {
caddy.RegisterModule(Transport{})
}
type Transport struct {
//////////////////////////////
// TODO: taken from v1 Handler type
SoftwareName string
SoftwareVersion string
ServerName string
ServerPort string
//////////////////////////
// TODO: taken from v1 Rule type
// The base path to match. Required.
// Path string
// upstream load balancer
// balancer
// Always process files with this extension with fastcgi.
// Ext string
// Use this directory as the fastcgi root directory. Defaults to the root
// directory of the parent virtual host.
Root string
// The path in the URL will be split into two, with the first piece ending
// with the value of SplitPath. The first piece will be assumed as the
// actual resource (CGI script) name, and the second piece will be set to
// PATH_INFO for the CGI script to use.
SplitPath string
// If the URL ends with '/' (which indicates a directory), these index
// files will be tried instead.
IndexFiles []string
// Environment Variables
EnvVars [][2]string
// Ignored paths
IgnoredSubPaths []string
// The duration used to set a deadline when connecting to an upstream.
DialTimeout time.Duration
// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
// The duration used to set a deadline when sending to the FastCGI server.
WriteTimeout time.Duration
}
// CaddyModule returns the Caddy module information.
func (Transport) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.transport.fastcgi",
New: func() caddy.Module { return new(Transport) },
}
}
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// Create environment for CGI script
env, err := t.buildEnv(r)
if err != nil {
return nil, fmt.Errorf("building environment: %v", err)
}
// TODO:
// Connect to FastCGI gateway
// address, err := f.Address()
// if err != nil {
// return http.StatusBadGateway, err
// }
// network, address := parseAddress(address)
network, address := "tcp", r.URL.Host // TODO:
ctx := context.Background()
if t.DialTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
defer cancel()
}
fcgiBackend, err := DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dialing backend: %v", err)
}
// fcgiBackend is closed when response body is closed (see clientCloser)
// read/write timeouts
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
return nil, fmt.Errorf("setting read timeout: %v", err)
}
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
return nil, fmt.Errorf("setting write timeout: %v", err)
}
var resp *http.Response
var contentLength int64
// if ContentLength is already set
if r.ContentLength > 0 {
contentLength = r.ContentLength
} else {
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
}
switch r.Method {
case "HEAD":
resp, err = fcgiBackend.Head(env)
case "GET":
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
case "OPTIONS":
resp, err = fcgiBackend.Options(env)
default:
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
}
// TODO:
return resp, err
// Stuff brought over from v1 that might not be necessary here:
// if resp != nil && resp.Body != nil {
// defer resp.Body.Close()
// }
// if err != nil {
// if err, ok := err.(net.Error); ok && err.Timeout() {
// return http.StatusGatewayTimeout, err
// } else if err != io.EOF {
// return http.StatusBadGateway, err
// }
// }
// // Write response header
// writeHeader(w, resp)
// // Write the response body
// _, err = io.Copy(w, resp.Body)
// if err != nil {
// return http.StatusBadGateway, err
// }
// // Log any stderr output from upstream
// if fcgiBackend.stderr.Len() != 0 {
// // Remove trailing newline, error logger already does this.
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
// }
// // Normally we would return the status code if it is an error status (>= 400),
// // however, upstream FastCGI apps don't know about our contract and have
// // probably already written an error page. So we just return 0, indicating
// // that the response body is already written. However, we do return any
// // error value so it can be logged.
// // Note that the proxy middleware works the same way, returning status=0.
// return 0, err
}
// buildEnv returns a set of CGI environment variables for the request.
func (t Transport) buildEnv(r *http.Request) (map[string]string, error) {
var env map[string]string
// Separate remote IP and port; more lenient than net.SplitHostPort
var ip, port string
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
ip = r.RemoteAddr[:idx]
port = r.RemoteAddr[idx+1:]
} else {
ip = r.RemoteAddr
}
// Remove [] from IPv6 addresses
ip = strings.Replace(ip, "[", "", 1)
ip = strings.Replace(ip, "]", "", 1)
// TODO: respect index files? or leave that to matcher/rewrite (I prefer that)?
fpath := r.URL.Path
// Split path in preparation for env variables.
// Previous canSplit checks ensure this can never be -1.
// TODO: I haven't brought over canSplit; make sure this doesn't break
splitPos := t.splitPos(fpath)
// Request has the extension; path was split successfully
docURI := fpath[:splitPos+len(t.SplitPath)]
pathInfo := fpath[splitPos+len(t.SplitPath):]
scriptName := fpath
// Strip PATH_INFO from SCRIPT_NAME
scriptName = strings.TrimSuffix(scriptName, pathInfo)
// SCRIPT_FILENAME is the absolute path of SCRIPT_NAME
scriptFilename := filepath.Join(t.Root, scriptName)
// Add vhost path prefix to scriptName. Otherwise, some PHP software will
// have difficulty discovering its URL.
pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string)
scriptName = path.Join(pathPrefix, scriptName)
// TODO: Disabled for now
// // Get the request URI from context. The context stores the original URI in case
// // it was changed by a middleware such as rewrite. By default, we pass the
// // original URI in as the value of REQUEST_URI (the user can overwrite this
// // if desired). Most PHP apps seem to want the original URI. Besides, this is
// // how nginx defaults: http://stackoverflow.com/a/12485156/1048862
// reqURL, _ := r.Context().Value(httpserver.OriginalURLCtxKey).(url.URL)
// // Retrieve name of remote user that was set by some downstream middleware such as basicauth.
// remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string)
requestScheme := "http"
if r.TLS != nil {
requestScheme = "https"
}
// Some variables are unused but cleared explicitly to prevent
// the parent environment from interfering.
env = map[string]string{
// Variables defined in CGI 1.1 spec
"AUTH_TYPE": "", // Not used
"CONTENT_LENGTH": r.Header.Get("Content-Length"),
"CONTENT_TYPE": r.Header.Get("Content-Type"),
"GATEWAY_INTERFACE": "CGI/1.1",
"PATH_INFO": pathInfo,
"QUERY_STRING": r.URL.RawQuery,
"REMOTE_ADDR": ip,
"REMOTE_HOST": ip, // For speed, remote host lookups disabled
"REMOTE_PORT": port,
"REMOTE_IDENT": "", // Not used
// "REMOTE_USER": remoteUser, // TODO:
"REQUEST_METHOD": r.Method,
"REQUEST_SCHEME": requestScheme,
"SERVER_NAME": t.ServerName,
"SERVER_PORT": t.ServerPort,
"SERVER_PROTOCOL": r.Proto,
"SERVER_SOFTWARE": t.SoftwareName + "/" + t.SoftwareVersion,
// Other variables
// "DOCUMENT_ROOT": rule.Root,
"DOCUMENT_URI": docURI,
"HTTP_HOST": r.Host, // added here, since not always part of headers
// "REQUEST_URI": reqURL.RequestURI(), // TODO:
"SCRIPT_FILENAME": scriptFilename,
"SCRIPT_NAME": scriptName,
}
// compliance with the CGI specification requires that
// PATH_TRANSLATED should only exist if PATH_INFO is defined.
// Info: https://www.ietf.org/rfc/rfc3875 Page 14
if env["PATH_INFO"] != "" {
env["PATH_TRANSLATED"] = filepath.Join(t.Root, pathInfo) // Info: http://www.oreilly.com/openbook/cgi/ch02_04.html
}
// Some web apps rely on knowing HTTPS or not
if r.TLS != nil {
env["HTTPS"] = "on"
// and pass the protocol details in a manner compatible with apache's mod_ssl
// (which is why these have a SSL_ prefix and not TLS_).
v, ok := tlsProtocolStrings[r.TLS.Version]
if ok {
env["SSL_PROTOCOL"] = v
}
// and pass the cipher suite in a manner compatible with apache's mod_ssl
for k, v := range caddytls.SupportedCipherSuites {
if v == r.TLS.CipherSuite {
env["SSL_CIPHER"] = k
break
}
}
}
// Add env variables from config (with support for placeholders in values)
repl := r.Context().Value(caddy.ReplacerCtxKey).(caddy.Replacer)
for _, envVar := range t.EnvVars {
env[envVar[0]] = repl.ReplaceAll(envVar[1], "")
}
// Add all HTTP headers to env variables
for field, val := range r.Header {
header := strings.ToUpper(field)
header = headerNameReplacer.Replace(header)
env["HTTP_"+header] = strings.Join(val, ", ")
}
return env, nil
}
// splitPos returns the index where path should
// be split based on t.SplitPath.
func (t Transport) splitPos(path string) int {
// TODO:
// if httpserver.CaseSensitivePath {
// return strings.Index(path, r.SplitPath)
// }
return strings.Index(strings.ToLower(path), strings.ToLower(t.SplitPath))
}
// TODO:
// Map of supported protocols to Apache ssl_mod format
// Note that these are slightly different from SupportedProtocols in caddytls/config.go
var tlsProtocolStrings = map[uint16]string{
tls.VersionTLS10: "TLSv1",
tls.VersionTLS11: "TLSv1.1",
tls.VersionTLS12: "TLSv1.2",
tls.VersionTLS13: "TLSv1.3",
}
var headerNameReplacer = strings.NewReplacer(" ", "_", "-", "_")

View File

@ -1,86 +0,0 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"net/http"
"time"
)
// Upstream represents the interface that must be satisfied to use the healthchecker.
type Upstream interface {
SetHealthiness(bool)
}
// HealthChecker represents a worker that periodically evaluates if proxy upstream host is healthy.
type HealthChecker struct {
upstream Upstream
Ticker *time.Ticker
HTTPClient *http.Client
StopChan chan bool
}
// ScheduleChecks periodically runs health checks against an upstream host.
func (h *HealthChecker) ScheduleChecks(url string) {
// check if a host is healthy on start vs waiting for timer
h.upstream.SetHealthiness(h.IsHealthy(url))
stop := make(chan bool)
h.StopChan = stop
go func() {
for {
select {
case <-h.Ticker.C:
h.upstream.SetHealthiness(h.IsHealthy(url))
case <-stop:
return
}
}
}()
}
// Stop stops the healthchecker from makeing further requests.
func (h *HealthChecker) Stop() {
h.Ticker.Stop()
close(h.StopChan)
}
// IsHealthy attempts to check if a upstream host is healthy.
func (h *HealthChecker) IsHealthy(url string) bool {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return false
}
resp, err := h.HTTPClient.Do(req)
if err != nil {
return false
}
if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return false
}
return true
}
// NewHealthCheckWorker returns a new instance of a HealthChecker.
func NewHealthCheckWorker(u Upstream, interval time.Duration, client *http.Client) *HealthChecker {
return &HealthChecker{
upstream: u,
Ticker: time.NewTicker(interval),
HTTPClient: client,
}
}

View File

@ -0,0 +1,133 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"net"
"net/http"
"time"
"github.com/caddyserver/caddy/v2"
)
func init() {
caddy.RegisterModule(HTTPTransport{})
}
// TODO: This is the default transport, basically just http.Transport, but we define JSON struct tags...
type HTTPTransport struct {
// TODO: Actually this is where the TLS config should go, technically...
// as well as keepalives and dial timeouts...
// TODO: It's possible that other transports (like fastcgi) might be
// able to borrow/use at least some of these config fields; if so,
// move them into a type called CommonTransport and embed it
TLS *TLSConfig `json:"tls,omitempty"`
KeepAlive *KeepAlive `json:"keep_alive,omitempty"`
Compression *bool `json:"compression,omitempty"`
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // TODO: NOTE: we use our health check stuff to enforce max REQUESTS per host, but this is connections
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
FallbackDelay caddy.Duration `json:"dial_fallback_delay,omitempty"`
ResponseHeaderTimeout caddy.Duration `json:"response_header_timeout,omitempty"`
ExpectContinueTimeout caddy.Duration `json:"expect_continue_timeout,omitempty"`
MaxResponseHeaderSize int64 `json:"max_response_header_size,omitempty"`
WriteBufferSize int `json:"write_buffer_size,omitempty"`
ReadBufferSize int `json:"read_buffer_size,omitempty"`
// TODO: ProxyConnectHeader?
RoundTripper http.RoundTripper `json:"-"`
}
// CaddyModule returns the Caddy module information.
func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.transport.http",
New: func() caddy.Module { return new(HTTPTransport) },
}
}
func (h *HTTPTransport) Provision(ctx caddy.Context) error {
dialer := &net.Dialer{
Timeout: time.Duration(h.DialTimeout),
FallbackDelay: time.Duration(h.FallbackDelay),
// TODO: Resolver
}
rt := &http.Transport{
DialContext: dialer.DialContext,
MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
MaxResponseHeaderBytes: h.MaxResponseHeaderSize,
WriteBufferSize: h.WriteBufferSize,
ReadBufferSize: h.ReadBufferSize,
}
if h.TLS != nil {
rt.TLSHandshakeTimeout = time.Duration(h.TLS.HandshakeTimeout)
// TODO: rest of TLS config
}
if h.KeepAlive != nil {
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
if enabled := h.KeepAlive.Enabled; enabled != nil {
rt.DisableKeepAlives = !*enabled
}
rt.MaxIdleConns = h.KeepAlive.MaxIdleConns
rt.MaxIdleConnsPerHost = h.KeepAlive.MaxIdleConnsPerHost
rt.IdleConnTimeout = time.Duration(h.KeepAlive.IdleConnTimeout)
}
if h.Compression != nil {
rt.DisableCompression = !*h.Compression
}
h.RoundTripper = rt
return nil
}
func (h HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return h.RoundTripper.RoundTrip(req)
}
type TLSConfig struct {
CAPool []string `json:"ca_pool,omitempty"`
ClientCertificate string `json:"client_certificate,omitempty"`
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
HandshakeTimeout caddy.Duration `json:"handshake_timeout,omitempty"`
}
type KeepAlive struct {
Enabled *bool `json:"enabled,omitempty"`
ProbeInterval caddy.Duration `json:"probe_interval,omitempty"`
MaxIdleConns int `json:"max_idle_conns,omitempty"`
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
}
var (
defaultDialer = net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
// TODO: does this need to be configured to enable HTTP/2?
defaultTransport = &http.Transport{
DialContext: defaultDialer.DialContext,
TLSHandshakeTimeout: 5 * time.Second,
IdleConnTimeout: 2 * time.Minute,
}
)

View File

@ -1,53 +0,0 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
func init() {
caddy.RegisterModule(new(LoadBalanced))
httpcaddyfile.RegisterHandlerDirective("reverse_proxy", parseCaddyfile) // TODO: "proxy"?
}
// CaddyModule returns the Caddy module information.
func (*LoadBalanced) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy",
New: func() caddy.Module { return new(LoadBalanced) },
}
}
// parseCaddyfile sets up the handler from Caddyfile tokens. Syntax:
//
// proxy [<matcher>] <to>
//
// TODO: This needs to be finished. It definitely needs to be able to open a block...
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
lb := new(LoadBalanced)
for h.Next() {
allTo := h.RemainingArgs()
if len(allTo) == 0 {
return nil, h.ArgErr()
}
for _, to := range allTo {
lb.Upstreams = append(lb.Upstreams, &UpstreamConfig{Host: to})
}
}
return lb, nil
}

868
modules/caddyhttp/reverseproxy/reverseproxy.go Executable file → Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,351 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import (
"fmt"
"hash/fnv"
weakrand "math/rand"
"net"
"net/http"
"sync/atomic"
"time"
"github.com/caddyserver/caddy/v2"
)
func init() {
caddy.RegisterModule(RandomSelection{})
caddy.RegisterModule(RandomChoiceSelection{})
caddy.RegisterModule(LeastConnSelection{})
caddy.RegisterModule(RoundRobinSelection{})
caddy.RegisterModule(FirstSelection{})
caddy.RegisterModule(IPHashSelection{})
caddy.RegisterModule(URIHashSelection{})
caddy.RegisterModule(HeaderHashSelection{})
weakrand.Seed(time.Now().UTC().UnixNano())
}
// RandomSelection is a policy that selects
// an available host at random.
type RandomSelection struct{}
// CaddyModule returns the Caddy module information.
func (RandomSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.random",
New: func() caddy.Module { return new(RandomSelection) },
}
}
// Select returns an available host, if any.
func (r RandomSelection) Select(pool HostPool, request *http.Request) *Upstream {
// use reservoir sampling because the number of available
// hosts isn't known: https://en.wikipedia.org/wiki/Reservoir_sampling
var randomHost *Upstream
var count int
for _, upstream := range pool {
if !upstream.Available() {
continue
}
// (n % 1 == 0) holds for all n, therefore a
// upstream will always be chosen if there is at
// least one available
count++
if (weakrand.Int() % count) == 0 {
randomHost = upstream
}
}
return randomHost
}
// RandomChoiceSelection is a policy that selects
// two or more available hosts at random, then
// chooses the one with the least load.
type RandomChoiceSelection struct {
Choose int `json:"choose,omitempty"`
}
// CaddyModule returns the Caddy module information.
func (RandomChoiceSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.random_choice",
New: func() caddy.Module { return new(RandomChoiceSelection) },
}
}
func (r *RandomChoiceSelection) Provision(ctx caddy.Context) error {
if r.Choose == 0 {
r.Choose = 2
}
return nil
}
func (r RandomChoiceSelection) Validate() error {
if r.Choose < 2 {
return fmt.Errorf("choose must be at least 2")
}
return nil
}
// Select returns an available host, if any.
func (r RandomChoiceSelection) Select(pool HostPool, _ *http.Request) *Upstream {
k := r.Choose
if k > len(pool) {
k = len(pool)
}
choices := make([]*Upstream, k)
for i, upstream := range pool {
if !upstream.Available() {
continue
}
j := weakrand.Intn(i)
if j < k {
choices[j] = upstream
}
}
return leastRequests(choices)
}
// LeastConnSelection is a policy that selects the
// host with the least active requests. If multiple
// hosts have the same fewest number, one is chosen
// randomly. The term "conn" or "connection" is used
// in this policy name due to its similar meaning in
// other software, but our load balancer actually
// counts active requests rather than connections,
// since these days requests are multiplexed onto
// shared connections.
type LeastConnSelection struct{}
// CaddyModule returns the Caddy module information.
func (LeastConnSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.least_conn",
New: func() caddy.Module { return new(LeastConnSelection) },
}
}
// Select selects the up host with the least number of connections in the
// pool. If more than one host has the same least number of connections,
// one of the hosts is chosen at random.
func (LeastConnSelection) Select(pool HostPool, _ *http.Request) *Upstream {
var bestHost *Upstream
var count int
var leastReqs int
for _, host := range pool {
if !host.Available() {
continue
}
numReqs := host.NumRequests()
if numReqs < leastReqs {
leastReqs = numReqs
count = 0
}
// among hosts with same least connections, perform a reservoir
// sample: https://en.wikipedia.org/wiki/Reservoir_sampling
if numReqs == leastReqs {
count++
if (weakrand.Int() % count) == 0 {
bestHost = host
}
}
}
return bestHost
}
// RoundRobinSelection is a policy that selects
// a host based on round-robin ordering.
type RoundRobinSelection struct {
robin uint32
}
// CaddyModule returns the Caddy module information.
func (RoundRobinSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.round_robin",
New: func() caddy.Module { return new(RoundRobinSelection) },
}
}
// Select returns an available host, if any.
func (r *RoundRobinSelection) Select(pool HostPool, _ *http.Request) *Upstream {
n := uint32(len(pool))
if n == 0 {
return nil
}
for i := uint32(0); i < n; i++ {
atomic.AddUint32(&r.robin, 1)
host := pool[r.robin%n]
if host.Available() {
return host
}
}
return nil
}
// FirstSelection is a policy that selects
// the first available host.
type FirstSelection struct{}
// CaddyModule returns the Caddy module information.
func (FirstSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.first",
New: func() caddy.Module { return new(FirstSelection) },
}
}
// Select returns an available host, if any.
func (FirstSelection) Select(pool HostPool, _ *http.Request) *Upstream {
for _, host := range pool {
if host.Available() {
return host
}
}
return nil
}
// IPHashSelection is a policy that selects a host
// based on hashing the remote IP of the request.
type IPHashSelection struct{}
// CaddyModule returns the Caddy module information.
func (IPHashSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.ip_hash",
New: func() caddy.Module { return new(IPHashSelection) },
}
}
// Select returns an available host, if any.
func (IPHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
clientIP = req.RemoteAddr
}
return hostByHashing(pool, clientIP)
}
// URIHashSelection is a policy that selects a
// host by hashing the request URI.
type URIHashSelection struct{}
// CaddyModule returns the Caddy module information.
func (URIHashSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.uri_hash",
New: func() caddy.Module { return new(URIHashSelection) },
}
}
// Select returns an available host, if any.
func (URIHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
return hostByHashing(pool, req.RequestURI)
}
// HeaderHashSelection is a policy that selects
// a host based on a given request header.
type HeaderHashSelection struct {
Field string `json:"field,omitempty"`
}
// CaddyModule returns the Caddy module information.
func (HeaderHashSelection) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
Name: "http.handlers.reverse_proxy.selection_policies.header",
New: func() caddy.Module { return new(HeaderHashSelection) },
}
}
// Select returns an available host, if any.
func (s HeaderHashSelection) Select(pool HostPool, req *http.Request) *Upstream {
if s.Field == "" {
return nil
}
val := req.Header.Get(s.Field)
if val == "" {
return RandomSelection{}.Select(pool, req)
}
return hostByHashing(pool, val)
}
// leastRequests returns the host with the
// least number of active requests to it.
// If more than one host has the same
// least number of active requests, then
// one of those is chosen at random.
func leastRequests(upstreams []*Upstream) *Upstream {
if len(upstreams) == 0 {
return nil
}
var best []*Upstream
var bestReqs int
for _, upstream := range upstreams {
reqs := upstream.NumRequests()
if reqs == 0 {
return upstream
}
if reqs <= bestReqs {
bestReqs = reqs
best = append(best, upstream)
}
}
return best[weakrand.Intn(len(best))]
}
// hostByHashing returns an available host
// from pool based on a hashable string s.
func hostByHashing(pool []*Upstream, s string) *Upstream {
poolLen := uint32(len(pool))
if poolLen == 0 {
return nil
}
index := hash(s) % poolLen
for i := uint32(0); i < poolLen; i++ {
index += i
upstream := pool[index%poolLen]
if upstream.Available() {
return upstream
}
}
return nil
}
// hash calculates a fast hash based on s.
func hash(s string) uint32 {
h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// Interface guards
var (
_ Selector = (*RandomSelection)(nil)
_ Selector = (*RandomChoiceSelection)(nil)
_ Selector = (*LeastConnSelection)(nil)
_ Selector = (*RoundRobinSelection)(nil)
_ Selector = (*FirstSelection)(nil)
_ Selector = (*IPHashSelection)(nil)
_ Selector = (*URIHashSelection)(nil)
_ Selector = (*HeaderHashSelection)(nil)
_ caddy.Validator = (*RandomChoiceSelection)(nil)
_ caddy.Provisioner = (*RandomChoiceSelection)(nil)
)

View File

@ -0,0 +1,363 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
// TODO: finish migrating these
// import (
// "net/http"
// "net/http/httptest"
// "os"
// "testing"
// )
// var workableServer *httptest.Server
// func TestMain(m *testing.M) {
// workableServer = httptest.NewServer(http.HandlerFunc(
// func(w http.ResponseWriter, r *http.Request) {
// // do nothing
// }))
// r := m.Run()
// workableServer.Close()
// os.Exit(r)
// }
// type customPolicy struct{}
// func (customPolicy) Select(pool HostPool, _ *http.Request) Host {
// return pool[0]
// }
// func testPool() HostPool {
// pool := []*UpstreamHost{
// {
// Name: workableServer.URL, // this should resolve (healthcheck test)
// },
// {
// Name: "http://localhost:99998", // this shouldn't
// },
// {
// Name: "http://C",
// },
// }
// return HostPool(pool)
// }
// func TestRoundRobinPolicy(t *testing.T) {
// pool := testPool()
// rrPolicy := &RoundRobin{}
// request, _ := http.NewRequest("GET", "/", nil)
// h := rrPolicy.Select(pool, request)
// // First selected host is 1, because counter starts at 0
// // and increments before host is selected
// if h != pool[1] {
// t.Error("Expected first round robin host to be second host in the pool.")
// }
// h = rrPolicy.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected second round robin host to be third host in the pool.")
// }
// h = rrPolicy.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected third round robin host to be first host in the pool.")
// }
// // mark host as down
// pool[1].Unhealthy = 1
// h = rrPolicy.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected to skip down host.")
// }
// // mark host as up
// pool[1].Unhealthy = 0
// h = rrPolicy.Select(pool, request)
// if h == pool[2] {
// t.Error("Expected to balance evenly among healthy hosts")
// }
// // mark host as full
// pool[1].Conns = 1
// pool[1].MaxConns = 1
// h = rrPolicy.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected to skip full host.")
// }
// }
// func TestLeastConnPolicy(t *testing.T) {
// pool := testPool()
// lcPolicy := &LeastConn{}
// request, _ := http.NewRequest("GET", "/", nil)
// pool[0].Conns = 10
// pool[1].Conns = 10
// h := lcPolicy.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected least connection host to be third host.")
// }
// pool[2].Conns = 100
// h = lcPolicy.Select(pool, request)
// if h != pool[0] && h != pool[1] {
// t.Error("Expected least connection host to be first or second host.")
// }
// }
// func TestCustomPolicy(t *testing.T) {
// pool := testPool()
// customPolicy := &customPolicy{}
// request, _ := http.NewRequest("GET", "/", nil)
// h := customPolicy.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected custom policy host to be the first host.")
// }
// }
// func TestIPHashPolicy(t *testing.T) {
// pool := testPool()
// ipHash := &IPHash{}
// request, _ := http.NewRequest("GET", "/", nil)
// // We should be able to predict where every request is routed.
// request.RemoteAddr = "172.0.0.1:80"
// h := ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// request.RemoteAddr = "172.0.0.2:80"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// request.RemoteAddr = "172.0.0.3:80"
// h = ipHash.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected ip hash policy host to be the third host.")
// }
// request.RemoteAddr = "172.0.0.4:80"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// // we should get the same results without a port
// request.RemoteAddr = "172.0.0.1"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// request.RemoteAddr = "172.0.0.2"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// request.RemoteAddr = "172.0.0.3"
// h = ipHash.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected ip hash policy host to be the third host.")
// }
// request.RemoteAddr = "172.0.0.4"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// // we should get a healthy host if the original host is unhealthy and a
// // healthy host is available
// request.RemoteAddr = "172.0.0.1"
// pool[1].Unhealthy = 1
// h = ipHash.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected ip hash policy host to be the third host.")
// }
// request.RemoteAddr = "172.0.0.2"
// h = ipHash.Select(pool, request)
// if h != pool[2] {
// t.Error("Expected ip hash policy host to be the third host.")
// }
// pool[1].Unhealthy = 0
// request.RemoteAddr = "172.0.0.3"
// pool[2].Unhealthy = 1
// h = ipHash.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected ip hash policy host to be the first host.")
// }
// request.RemoteAddr = "172.0.0.4"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// // We should be able to resize the host pool and still be able to predict
// // where a request will be routed with the same IP's used above
// pool = []*UpstreamHost{
// {
// Name: workableServer.URL, // this should resolve (healthcheck test)
// },
// {
// Name: "http://localhost:99998", // this shouldn't
// },
// }
// pool = HostPool(pool)
// request.RemoteAddr = "172.0.0.1:80"
// h = ipHash.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected ip hash policy host to be the first host.")
// }
// request.RemoteAddr = "172.0.0.2:80"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// request.RemoteAddr = "172.0.0.3:80"
// h = ipHash.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected ip hash policy host to be the first host.")
// }
// request.RemoteAddr = "172.0.0.4:80"
// h = ipHash.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected ip hash policy host to be the second host.")
// }
// // We should get nil when there are no healthy hosts
// pool[0].Unhealthy = 1
// pool[1].Unhealthy = 1
// h = ipHash.Select(pool, request)
// if h != nil {
// t.Error("Expected ip hash policy host to be nil.")
// }
// }
// func TestFirstPolicy(t *testing.T) {
// pool := testPool()
// firstPolicy := &First{}
// req := httptest.NewRequest(http.MethodGet, "/", nil)
// h := firstPolicy.Select(pool, req)
// if h != pool[0] {
// t.Error("Expected first policy host to be the first host.")
// }
// pool[0].Unhealthy = 1
// h = firstPolicy.Select(pool, req)
// if h != pool[1] {
// t.Error("Expected first policy host to be the second host.")
// }
// }
// func TestUriPolicy(t *testing.T) {
// pool := testPool()
// uriPolicy := &URIHash{}
// request := httptest.NewRequest(http.MethodGet, "/test", nil)
// h := uriPolicy.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected uri policy host to be the first host.")
// }
// pool[0].Unhealthy = 1
// h = uriPolicy.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected uri policy host to be the first host.")
// }
// request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
// h = uriPolicy.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected uri policy host to be the second host.")
// }
// // We should be able to resize the host pool and still be able to predict
// // where a request will be routed with the same URI's used above
// pool = []*UpstreamHost{
// {
// Name: workableServer.URL, // this should resolve (healthcheck test)
// },
// {
// Name: "http://localhost:99998", // this shouldn't
// },
// }
// request = httptest.NewRequest(http.MethodGet, "/test", nil)
// h = uriPolicy.Select(pool, request)
// if h != pool[0] {
// t.Error("Expected uri policy host to be the first host.")
// }
// pool[0].Unhealthy = 1
// h = uriPolicy.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected uri policy host to be the first host.")
// }
// request = httptest.NewRequest(http.MethodGet, "/test_2", nil)
// h = uriPolicy.Select(pool, request)
// if h != pool[1] {
// t.Error("Expected uri policy host to be the second host.")
// }
// pool[0].Unhealthy = 1
// pool[1].Unhealthy = 1
// h = uriPolicy.Select(pool, request)
// if h != nil {
// t.Error("Expected uri policy policy host to be nil.")
// }
// }
// func TestHeaderPolicy(t *testing.T) {
// pool := testPool()
// tests := []struct {
// Name string
// Policy *Header
// RequestHeaderName string
// RequestHeaderValue string
// NilHost bool
// HostIndex int
// }{
// {"empty config", &Header{""}, "", "", true, 0},
// {"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0},
// {"empty config+header", &Header{""}, "Affinity", "", true, 0},
// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1},
// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2},
// {"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0},
// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1},
// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0},
// {"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2},
// {"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1},
// }
// for idx, test := range tests {
// request, _ := http.NewRequest("GET", "/", nil)
// if test.RequestHeaderName != "" {
// request.Header.Add(test.RequestHeaderName, test.RequestHeaderValue)
// }
// host := test.Policy.Select(pool, request)
// if test.NilHost && host != nil {
// t.Errorf("%d: Expected host to be nil", idx)
// }
// if !test.NilHost && host == nil {
// t.Errorf("%d: Did not expect host to be nil", idx)
// }
// if !test.NilHost && host != pool[test.HostIndex] {
// t.Errorf("%d: Expected Header policy to be host %d", idx, test.HostIndex)
// }
// }
// }

View File

@ -1,450 +0,0 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package reverseproxy implements a load-balanced reverse proxy.
package reverseproxy
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)
// CircuitBreaker defines the functionality of a circuit breaker module.
type CircuitBreaker interface {
Ok() bool
RecordMetric(statusCode int, latency time.Duration)
}
type noopCircuitBreaker struct{}
func (ncb noopCircuitBreaker) RecordMetric(statusCode int, latency time.Duration) {}
func (ncb noopCircuitBreaker) Ok() bool {
return true
}
const (
// TypeBalanceRoundRobin represents the value to use for configuring a load balanced reverse proxy to use round robin load balancing.
TypeBalanceRoundRobin = iota
// TypeBalanceRandom represents the value to use for configuring a load balanced reverse proxy to use random load balancing.
TypeBalanceRandom
// TODO: add random with two choices
// msgNoHealthyUpstreams is returned if there are no upstreams that are healthy to proxy a request to
msgNoHealthyUpstreams = "No healthy upstreams."
// by default perform health checks every 30 seconds
defaultHealthCheckDur = time.Second * 30
// used when an upstream is unhealthy, health checks can be configured to perform at a faster rate
defaultFastHealthCheckDur = time.Second * 1
)
var (
// defaultTransport is the default transport to use for the reverse proxy.
defaultTransport = &http.Transport{
Dial: (&net.Dialer{
Timeout: 5 * time.Second,
}).Dial,
TLSHandshakeTimeout: 5 * time.Second,
}
// defaultHTTPClient is the default http client to use for the healthchecker.
defaultHTTPClient = &http.Client{
Timeout: time.Second * 10,
Transport: defaultTransport,
}
// typeMap maps caddy load balance configuration to the internal representation of the loadbalance algorithm type.
typeMap = map[string]int{
"round_robin": TypeBalanceRoundRobin,
"random": TypeBalanceRandom,
}
)
// NewLoadBalancedReverseProxy returns a collection of Upstreams that are to be loadbalanced.
func NewLoadBalancedReverseProxy(lb *LoadBalanced, ctx caddy.Context) error {
// set defaults
if lb.NoHealthyUpstreamsMessage == "" {
lb.NoHealthyUpstreamsMessage = msgNoHealthyUpstreams
}
if lb.TryInterval == "" {
lb.TryInterval = "20s"
}
// set request retry interval
ti, err := time.ParseDuration(lb.TryInterval)
if err != nil {
return fmt.Errorf("NewLoadBalancedReverseProxy: %v", err.Error())
}
lb.tryInterval = ti
// set load balance algorithm
t, ok := typeMap[lb.LoadBalanceType]
if !ok {
t = TypeBalanceRandom
}
lb.loadBalanceType = t
// setup each upstream
var us []*upstream
for _, uc := range lb.Upstreams {
// pass the upstream decr and incr methods to keep track of unhealthy nodes
nu, err := newUpstream(uc, lb.decrUnhealthy, lb.incrUnhealthy)
if err != nil {
return err
}
// setup any configured circuit breakers
var cbModule = "http.handlers.reverse_proxy.circuit_breaker"
var cb CircuitBreaker
if uc.CircuitBreaker != nil {
if _, err := caddy.GetModule(cbModule); err == nil {
val, err := ctx.LoadModule(cbModule, uc.CircuitBreaker)
if err == nil {
cbv, ok := val.(CircuitBreaker)
if ok {
cb = cbv
} else {
fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
cb = noopCircuitBreaker{}
}
} else {
fmt.Printf("\nerr: %v; cannot load circuit_breaker, using noop", err.Error())
cb = noopCircuitBreaker{}
}
} else {
fmt.Println("circuit_breaker module not loaded, using noop")
cb = noopCircuitBreaker{}
}
} else {
cb = noopCircuitBreaker{}
}
nu.CB = cb
// start a healthcheck worker which will periodically check to see if an upstream is healthy
// to proxy requests to.
nu.healthChecker = NewHealthCheckWorker(nu, defaultHealthCheckDur, defaultHTTPClient)
// TODO :- if path is empty why does this empty the entire Target?
// nu.Target.Path = uc.HealthCheckPath
nu.healthChecker.ScheduleChecks(nu.Target.String())
lb.HealthCheckers = append(lb.HealthCheckers, nu.healthChecker)
us = append(us, nu)
}
lb.upstreams = us
return nil
}
// LoadBalanced represents a collection of upstream hosts that are loadbalanced. It
// contains multiple features like health checking and circuit breaking functionality
// for upstreams.
type LoadBalanced struct {
mu sync.Mutex
numUnhealthy int32
selectedServer int // used during round robin load balancing
loadBalanceType int
tryInterval time.Duration
upstreams []*upstream
// The following struct fields are set by caddy configuration.
// TryInterval is the max duration for which request retrys will be performed for a request.
TryInterval string `json:"try_interval,omitempty"`
// Upstreams are the configs for upstream hosts
Upstreams []*UpstreamConfig `json:"upstreams,omitempty"`
// LoadBalanceType is the string representation of what loadbalancing algorithm to use. i.e. "random" or "round_robin".
LoadBalanceType string `json:"load_balance_type,omitempty"`
// NoHealthyUpstreamsMessage is returned as a response when there are no healthy upstreams to loadbalance to.
NoHealthyUpstreamsMessage string `json:"no_healthy_upstreams_message,omitempty"`
// TODO :- store healthcheckers as package level state where each upstream gets a single healthchecker
// currently a healthchecker is created for each upstream defined, even if a healthchecker was previously created
// for that upstream
HealthCheckers []*HealthChecker `json:"health_checkers,omitempty"`
}
// Cleanup stops all health checkers on a loadbalanced reverse proxy.
func (lb *LoadBalanced) Cleanup() error {
for _, hc := range lb.HealthCheckers {
hc.Stop()
}
return nil
}
// Provision sets up a new loadbalanced reverse proxy.
func (lb *LoadBalanced) Provision(ctx caddy.Context) error {
return NewLoadBalancedReverseProxy(lb, ctx)
}
// ServeHTTP implements the caddyhttp.MiddlewareHandler interface to
// dispatch an HTTP request to the proper server.
func (lb *LoadBalanced) ServeHTTP(w http.ResponseWriter, r *http.Request, _ caddyhttp.Handler) error {
// ensure requests don't hang if an upstream does not respond or is not eventually healthy
var u *upstream
var done bool
retryTimer := time.NewTicker(lb.tryInterval)
defer retryTimer.Stop()
go func() {
select {
case <-retryTimer.C:
done = true
}
}()
// keep trying to get an available upstream to process the request
for {
switch lb.loadBalanceType {
case TypeBalanceRandom:
u = lb.random()
case TypeBalanceRoundRobin:
u = lb.roundRobin()
}
// if we can't get an upstream and our retry interval has ended return an error response
if u == nil && done {
w.WriteHeader(http.StatusBadGateway)
fmt.Fprint(w, lb.NoHealthyUpstreamsMessage)
return fmt.Errorf(msgNoHealthyUpstreams)
}
// attempt to get an available upstream
if u == nil {
continue
}
start := time.Now()
// if we get an error retry until we get a healthy upstream
res, err := u.ReverseProxy.ServeHTTP(w, r)
if err != nil {
if err == context.Canceled {
return nil
}
continue
}
// record circuit breaker metrics
go u.CB.RecordMetric(res.StatusCode, time.Now().Sub(start))
return nil
}
}
// incrUnhealthy increments the amount of unhealthy nodes in a loadbalancer.
func (lb *LoadBalanced) incrUnhealthy() {
atomic.AddInt32(&lb.numUnhealthy, 1)
}
// decrUnhealthy decrements the amount of unhealthy nodes in a loadbalancer.
func (lb *LoadBalanced) decrUnhealthy() {
atomic.AddInt32(&lb.numUnhealthy, -1)
}
// roundRobin implements a round robin load balancing algorithm to select
// which server to forward requests to.
func (lb *LoadBalanced) roundRobin() *upstream {
if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
return nil
}
selected := lb.upstreams[lb.selectedServer]
lb.mu.Lock()
lb.selectedServer++
if lb.selectedServer >= len(lb.upstreams) {
lb.selectedServer = 0
}
lb.mu.Unlock()
if selected.IsHealthy() && selected.CB.Ok() {
return selected
}
return nil
}
// random implements a random server selector for load balancing.
func (lb *LoadBalanced) random() *upstream {
if atomic.LoadInt32(&lb.numUnhealthy) == int32(len(lb.upstreams)) {
return nil
}
n := rand.Int() % len(lb.upstreams)
selected := lb.upstreams[n]
if selected.IsHealthy() && selected.CB.Ok() {
return selected
}
return nil
}
// UpstreamConfig represents the config of an upstream.
type UpstreamConfig struct {
// Host is the host name of the upstream server.
Host string `json:"host,omitempty"`
// FastHealthCheckDuration is the duration for which a health check is performed when a node is considered unhealthy.
FastHealthCheckDuration string `json:"fast_health_check_duration,omitempty"`
CircuitBreaker json.RawMessage `json:"circuit_breaker,omitempty"`
// // CircuitBreakerConfig is the config passed to setup a circuit breaker.
// CircuitBreakerConfig *circuitbreaker.Config `json:"circuit_breaker,omitempty"`
circuitbreaker CircuitBreaker
// HealthCheckDuration is the default duration for which a health check is performed.
HealthCheckDuration string `json:"health_check_duration,omitempty"`
// HealthCheckPath is the path at the upstream host to use for healthchecks.
HealthCheckPath string `json:"health_check_path,omitempty"`
}
// upstream represents an upstream host.
type upstream struct {
Healthy int32 // 0 = false, 1 = true
Target *url.URL
ReverseProxy *ReverseProxy
Incr func()
Decr func()
CB CircuitBreaker
healthChecker *HealthChecker
healthCheckDur time.Duration
fastHealthCheckDur time.Duration
}
// newUpstream returns a new upstream.
func newUpstream(uc *UpstreamConfig, d func(), i func()) (*upstream, error) {
host := strings.TrimSpace(uc.Host)
protoIdx := strings.Index(host, "://")
if protoIdx == -1 || len(host[:protoIdx]) == 0 {
return nil, fmt.Errorf("protocol is required for host")
}
hostURL, err := url.Parse(host)
if err != nil {
return nil, err
}
// parse healthcheck durations
hcd, err := time.ParseDuration(uc.HealthCheckDuration)
if err != nil {
hcd = defaultHealthCheckDur
}
fhcd, err := time.ParseDuration(uc.FastHealthCheckDuration)
if err != nil {
fhcd = defaultFastHealthCheckDur
}
u := upstream{
healthCheckDur: hcd,
fastHealthCheckDur: fhcd,
Target: hostURL,
Decr: d,
Incr: i,
Healthy: int32(0), // assume is unhealthy on start
}
u.ReverseProxy = newReverseProxy(hostURL, u.SetHealthiness)
return &u, nil
}
// SetHealthiness sets whether an upstream is healthy or not. The health check worker is updated to
// perform checks faster if a node is unhealthy.
func (u *upstream) SetHealthiness(ok bool) {
h := atomic.LoadInt32(&u.Healthy)
var wasHealthy bool
if h == 1 {
wasHealthy = true
} else {
wasHealthy = false
}
if ok {
u.healthChecker.Ticker = time.NewTicker(u.healthCheckDur)
if !wasHealthy {
atomic.AddInt32(&u.Healthy, 1)
u.Decr()
}
} else {
u.healthChecker.Ticker = time.NewTicker(u.fastHealthCheckDur)
if wasHealthy {
atomic.AddInt32(&u.Healthy, -1)
u.Incr()
}
}
}
// IsHealthy returns whether an Upstream is healthy or not.
func (u *upstream) IsHealthy() bool {
i := atomic.LoadInt32(&u.Healthy)
if i == 1 {
return true
}
return false
}
// newReverseProxy returns a new reverse proxy handler.
func newReverseProxy(target *url.URL, setHealthiness func(bool)) *ReverseProxy {
errorHandler := func(w http.ResponseWriter, r *http.Request, err error) {
// we don't need to worry about cancelled contexts since this doesn't necessarilly mean that
// the upstream is unhealthy.
if err != context.Canceled {
setHealthiness(false)
}
}
rp := NewSingleHostReverseProxy(target)
rp.ErrorHandler = errorHandler
rp.Transport = defaultTransport // use default transport that times out in 5 seconds
return rp
}
// Interface guards
var (
_ caddyhttp.MiddlewareHandler = (*LoadBalanced)(nil)
_ caddy.Provisioner = (*LoadBalanced)(nil)
_ caddy.CleanerUpper = (*LoadBalanced)(nil)
)

86
usagepool.go Normal file
View File

@ -0,0 +1,86 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddy
import (
"fmt"
"sync"
"sync/atomic"
)
// UsagePool is a thread-safe map that pools values
// based on usage; a LoadOrStore operation increments
// the usage, and a Delete decrements from the usage.
// If the usage count reaches 0, the value will be
// removed from the map. There is no way to overwrite
// existing keys in the pool without first deleting
// it as many times as it was stored. Deleting too
// many times will panic.
//
// An empty UsagePool is NOT safe to use; always call
// NewUsagePool() to make a new value.
type UsagePool struct {
pool *sync.Map
}
// NewUsagePool returns a new usage pool.
func NewUsagePool() *UsagePool {
return &UsagePool{pool: new(sync.Map)}
}
// Delete decrements the usage count for key and removes the
// value from the underlying map if the usage is 0. It returns
// true if the usage count reached 0 and the value was deleted.
// It panics if the usage count drops below 0; always call
// Delete precisely as many times as LoadOrStore.
func (up *UsagePool) Delete(key interface{}) (deleted bool) {
usageVal, ok := up.pool.Load(key)
if !ok {
return false
}
upv := usageVal.(*usagePoolVal)
newUsage := atomic.AddInt32(&upv.usage, -1)
if newUsage == 0 {
up.pool.Delete(key)
return true
} else if newUsage < 0 {
panic(fmt.Sprintf("deleted more than stored: %#v (usage: %d)",
upv.value, upv.usage))
}
return false
}
// LoadOrStore puts val in the pool and returns false if key does
// not already exist; otherwise if the key exists, it loads the
// existing value, increments the usage for that value, and returns
// the value along with true.
func (up *UsagePool) LoadOrStore(key, val interface{}) (actual interface{}, loaded bool) {
usageVal := &usagePoolVal{
usage: 1,
value: val,
}
actual, loaded = up.pool.LoadOrStore(key, usageVal)
if loaded {
upv := actual.(*usagePoolVal)
actual = upv.value
atomic.AddInt32(&upv.usage, 1)
}
return
}
type usagePoolVal struct {
usage int32 // accessed atomically; must be 64-bit aligned for 32-bit systems
value interface{}
}