mirror of
https://github.com/caddyserver/caddy.git
synced 2024-11-22 11:00:53 +08:00
Reconcile upstream dial addresses and request host/URL information
My goodness that was complicated Blessed be request.Context Sort of
This commit is contained in:
parent
a60d54dbfd
commit
0830fbad03
18
listeners.go
18
listeners.go
|
@ -165,19 +165,19 @@ var (
|
||||||
listenersMu sync.Mutex
|
listenersMu sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseListenAddr parses addr, a string of the form "network/host:port"
|
// ParseNetworkAddress parses addr, a string of the form "network/host:port"
|
||||||
// (with any part optional) into its component parts. Because a port can
|
// (with any part optional) into its component parts. Because a port can
|
||||||
// also be a port range, there may be multiple addresses returned.
|
// also be a port range, there may be multiple addresses returned.
|
||||||
func ParseListenAddr(addr string) (network string, addrs []string, err error) {
|
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) {
|
||||||
var host, port string
|
var host, port string
|
||||||
network, host, port, err = SplitListenAddr(addr)
|
network, host, port, err = SplitNetworkAddress(addr)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
network = "tcp"
|
network = "tcp"
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if network == "unix" {
|
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||||
addrs = []string{host}
|
addrs = []string{host}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -204,14 +204,14 @@ func ParseListenAddr(addr string) (network string, addrs []string, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SplitListenAddr splits a into its network, host, and port components.
|
// SplitNetworkAddress splits a into its network, host, and port components.
|
||||||
// Note that port may be a port range, or omitted for unix sockets.
|
// Note that port may be a port range, or omitted for unix sockets.
|
||||||
func SplitListenAddr(a string) (network, host, port string, err error) {
|
func SplitNetworkAddress(a string) (network, host, port string, err error) {
|
||||||
if idx := strings.Index(a, "/"); idx >= 0 {
|
if idx := strings.Index(a, "/"); idx >= 0 {
|
||||||
network = strings.ToLower(strings.TrimSpace(a[:idx]))
|
network = strings.ToLower(strings.TrimSpace(a[:idx]))
|
||||||
a = a[idx+1:]
|
a = a[idx+1:]
|
||||||
}
|
}
|
||||||
if network == "unix" {
|
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||||
host = a
|
host = a
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -219,11 +219,11 @@ func SplitListenAddr(a string) (network, host, port string, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// JoinListenAddr combines network, host, and port into a single
|
// JoinNetworkAddress combines network, host, and port into a single
|
||||||
// address string of the form "network/host:port". Port may be a
|
// address string of the form "network/host:port". Port may be a
|
||||||
// port range. For unix sockets, the network should be "unix" and
|
// port range. For unix sockets, the network should be "unix" and
|
||||||
// the path to the socket should be given in the host argument.
|
// the path to the socket should be given in the host argument.
|
||||||
func JoinListenAddr(network, host, port string) string {
|
func JoinNetworkAddress(network, host, port string) string {
|
||||||
var a string
|
var a string
|
||||||
if network != "" {
|
if network != "" {
|
||||||
a = network + "/"
|
a = network + "/"
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSplitListenerAddr(t *testing.T) {
|
func TestSplitNetworkAddress(t *testing.T) {
|
||||||
for i, tc := range []struct {
|
for i, tc := range []struct {
|
||||||
input string
|
input string
|
||||||
expectNetwork string
|
expectNetwork string
|
||||||
|
@ -67,8 +67,18 @@ func TestSplitListenerAddr(t *testing.T) {
|
||||||
expectNetwork: "unix",
|
expectNetwork: "unix",
|
||||||
expectHost: "/foo/bar",
|
expectHost: "/foo/bar",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
input: "unixgram//foo/bar",
|
||||||
|
expectNetwork: "unixgram",
|
||||||
|
expectHost: "/foo/bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "unixpacket//foo/bar",
|
||||||
|
expectNetwork: "unixpacket",
|
||||||
|
expectHost: "/foo/bar",
|
||||||
|
},
|
||||||
} {
|
} {
|
||||||
actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input)
|
actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
|
||||||
if tc.expectErr && err == nil {
|
if tc.expectErr && err == nil {
|
||||||
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
||||||
}
|
}
|
||||||
|
@ -87,7 +97,7 @@ func TestSplitListenerAddr(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinListenerAddr(t *testing.T) {
|
func TestJoinNetworkAddress(t *testing.T) {
|
||||||
for i, tc := range []struct {
|
for i, tc := range []struct {
|
||||||
network, host, port string
|
network, host, port string
|
||||||
expect string
|
expect string
|
||||||
|
@ -129,14 +139,14 @@ func TestJoinListenerAddr(t *testing.T) {
|
||||||
expect: "unix//foo/bar",
|
expect: "unix//foo/bar",
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
actual := JoinListenAddr(tc.network, tc.host, tc.port)
|
actual := JoinNetworkAddress(tc.network, tc.host, tc.port)
|
||||||
if actual != tc.expect {
|
if actual != tc.expect {
|
||||||
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
|
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseListenerAddr(t *testing.T) {
|
func TestParseNetworkAddress(t *testing.T) {
|
||||||
for i, tc := range []struct {
|
for i, tc := range []struct {
|
||||||
input string
|
input string
|
||||||
expectNetwork string
|
expectNetwork string
|
||||||
|
@ -194,7 +204,7 @@ func TestParseListenerAddr(t *testing.T) {
|
||||||
expectAddrs: []string{"localhost:0"},
|
expectAddrs: []string{"localhost:0"},
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
actualNetwork, actualAddrs, err := ParseListenAddr(tc.input)
|
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input)
|
||||||
if tc.expectErr && err == nil {
|
if tc.expectErr && err == nil {
|
||||||
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
t.Errorf("Test %d: Expected error but got: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (app *App) Validate() error {
|
||||||
lnAddrs := make(map[string]string)
|
lnAddrs := make(map[string]string)
|
||||||
for srvName, srv := range app.Servers {
|
for srvName, srv := range app.Servers {
|
||||||
for _, addr := range srv.Listen {
|
for _, addr := range srv.Listen {
|
||||||
netw, expanded, err := caddy.ParseListenAddr(addr)
|
netw, expanded, err := caddy.ParseNetworkAddress(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
|
return fmt.Errorf("invalid listener address '%s': %v", addr, err)
|
||||||
}
|
}
|
||||||
|
@ -149,7 +149,7 @@ func (app *App) Start() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, lnAddr := range srv.Listen {
|
for _, lnAddr := range srv.Listen {
|
||||||
network, addrs, err := caddy.ParseListenAddr(lnAddr)
|
network, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
|
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
|
||||||
}
|
}
|
||||||
|
@ -309,7 +309,7 @@ func (app *App) automaticHTTPS() error {
|
||||||
|
|
||||||
// create HTTP->HTTPS redirects
|
// create HTTP->HTTPS redirects
|
||||||
for _, addr := range srv.Listen {
|
for _, addr := range srv.Listen {
|
||||||
netw, host, port, err := caddy.SplitListenAddr(addr)
|
netw, host, port, err := caddy.SplitNetworkAddress(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
|
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
|
||||||
}
|
}
|
||||||
|
@ -318,7 +318,7 @@ func (app *App) automaticHTTPS() error {
|
||||||
if httpPort == 0 {
|
if httpPort == 0 {
|
||||||
httpPort = DefaultHTTPPort
|
httpPort = DefaultHTTPPort
|
||||||
}
|
}
|
||||||
httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort))
|
httpRedirLnAddr := caddy.JoinNetworkAddress(netw, host, strconv.Itoa(httpPort))
|
||||||
lnAddrMap[httpRedirLnAddr] = struct{}{}
|
lnAddrMap[httpRedirLnAddr] = struct{}{}
|
||||||
|
|
||||||
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
|
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
|
||||||
|
@ -361,7 +361,7 @@ func (app *App) automaticHTTPS() error {
|
||||||
var lnAddrs []string
|
var lnAddrs []string
|
||||||
mapLoop:
|
mapLoop:
|
||||||
for addr := range lnAddrMap {
|
for addr := range lnAddrMap {
|
||||||
netw, addrs, err := caddy.ParseListenAddr(addr)
|
netw, addrs, err := caddy.ParseNetworkAddress(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -386,7 +386,7 @@ func (app *App) automaticHTTPS() error {
|
||||||
func (app *App) listenerTaken(network, address string) bool {
|
func (app *App) listenerTaken(network, address string) bool {
|
||||||
for _, srv := range app.Servers {
|
for _, srv := range app.Servers {
|
||||||
for _, addr := range srv.Listen {
|
for _, addr := range srv.Listen {
|
||||||
netw, addrs, err := caddy.ParseListenAddr(addr)
|
netw, addrs, err := caddy.ParseNetworkAddress(addr)
|
||||||
if err != nil || netw != network {
|
if err != nil || netw != network {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
|
||||||
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
"github.com/caddyserver/caddy/v2/modules/caddytls"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2"
|
"github.com/caddyserver/caddy/v2"
|
||||||
|
@ -34,6 +35,7 @@ func init() {
|
||||||
caddy.RegisterModule(Transport{})
|
caddy.RegisterModule(Transport{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transport facilitates FastCGI communication.
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
//////////////////////////////
|
//////////////////////////////
|
||||||
// TODO: taken from v1 Handler type
|
// TODO: taken from v1 Handler type
|
||||||
|
@ -57,32 +59,32 @@ type Transport struct {
|
||||||
|
|
||||||
// Use this directory as the fastcgi root directory. Defaults to the root
|
// Use this directory as the fastcgi root directory. Defaults to the root
|
||||||
// directory of the parent virtual host.
|
// directory of the parent virtual host.
|
||||||
Root string
|
Root string `json:"root,omitempty"`
|
||||||
|
|
||||||
// The path in the URL will be split into two, with the first piece ending
|
// The path in the URL will be split into two, with the first piece ending
|
||||||
// with the value of SplitPath. The first piece will be assumed as the
|
// with the value of SplitPath. The first piece will be assumed as the
|
||||||
// actual resource (CGI script) name, and the second piece will be set to
|
// actual resource (CGI script) name, and the second piece will be set to
|
||||||
// PATH_INFO for the CGI script to use.
|
// PATH_INFO for the CGI script to use.
|
||||||
SplitPath string
|
SplitPath string `json:"split_path,omitempty"`
|
||||||
|
|
||||||
// If the URL ends with '/' (which indicates a directory), these index
|
// If the URL ends with '/' (which indicates a directory), these index
|
||||||
// files will be tried instead.
|
// files will be tried instead.
|
||||||
IndexFiles []string
|
// IndexFiles []string
|
||||||
|
|
||||||
// Environment Variables
|
// Environment Variables
|
||||||
EnvVars [][2]string
|
EnvVars [][2]string `json:"env,omitempty"`
|
||||||
|
|
||||||
// Ignored paths
|
// Ignored paths
|
||||||
IgnoredSubPaths []string
|
// IgnoredSubPaths []string
|
||||||
|
|
||||||
// The duration used to set a deadline when connecting to an upstream.
|
// The duration used to set a deadline when connecting to an upstream.
|
||||||
DialTimeout time.Duration
|
DialTimeout caddy.Duration `json:"dial_timeout,omitempty"`
|
||||||
|
|
||||||
// The duration used to set a deadline when reading from the FastCGI server.
|
// The duration used to set a deadline when reading from the FastCGI server.
|
||||||
ReadTimeout time.Duration
|
ReadTimeout caddy.Duration `json:"read_timeout,omitempty"`
|
||||||
|
|
||||||
// The duration used to set a deadline when sending to the FastCGI server.
|
// The duration used to set a deadline when sending to the FastCGI server.
|
||||||
WriteTimeout time.Duration
|
WriteTimeout caddy.Duration `json:"write_timeout,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CaddyModule returns the Caddy module information.
|
// CaddyModule returns the Caddy module information.
|
||||||
|
@ -93,102 +95,62 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoundTrip implements http.RoundTripper.
|
||||||
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
// Create environment for CGI script
|
|
||||||
env, err := t.buildEnv(r)
|
env, err := t.buildEnv(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("building environment: %v", err)
|
return nil, fmt.Errorf("building environment: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO: doesn't dialer have a Timeout field?
|
||||||
// Connect to FastCGI gateway
|
|
||||||
// address, err := f.Address()
|
|
||||||
// if err != nil {
|
|
||||||
// return http.StatusBadGateway, err
|
|
||||||
// }
|
|
||||||
// network, address := parseAddress(address)
|
|
||||||
network, address := "tcp", r.URL.Host // TODO:
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if t.DialTimeout > 0 {
|
if t.DialTimeout > 0 {
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithTimeout(ctx, t.DialTimeout)
|
ctx, cancel = context.WithTimeout(ctx, time.Duration(t.DialTimeout))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extract dial information from request (this
|
||||||
|
// should embedded by the reverse proxy)
|
||||||
|
network, address := "tcp", r.URL.Host
|
||||||
|
if dialInfoVal := ctx.Value(reverseproxy.DialInfoCtxKey); dialInfoVal != nil {
|
||||||
|
dialInfo := dialInfoVal.(reverseproxy.DialInfo)
|
||||||
|
network = dialInfo.Network
|
||||||
|
address = dialInfo.Address
|
||||||
|
}
|
||||||
|
|
||||||
fcgiBackend, err := DialContext(ctx, network, address)
|
fcgiBackend, err := DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dialing backend: %v", err)
|
return nil, fmt.Errorf("dialing backend: %v", err)
|
||||||
}
|
}
|
||||||
// fcgiBackend is closed when response body is closed (see clientCloser)
|
// fcgiBackend gets closed when response body is closed (see clientCloser)
|
||||||
|
|
||||||
// read/write timeouts
|
// read/write timeouts
|
||||||
if err := fcgiBackend.SetReadTimeout(t.ReadTimeout); err != nil {
|
if err := fcgiBackend.SetReadTimeout(time.Duration(t.ReadTimeout)); err != nil {
|
||||||
return nil, fmt.Errorf("setting read timeout: %v", err)
|
return nil, fmt.Errorf("setting read timeout: %v", err)
|
||||||
}
|
}
|
||||||
if err := fcgiBackend.SetWriteTimeout(t.WriteTimeout); err != nil {
|
if err := fcgiBackend.SetWriteTimeout(time.Duration(t.WriteTimeout)); err != nil {
|
||||||
return nil, fmt.Errorf("setting write timeout: %v", err)
|
return nil, fmt.Errorf("setting write timeout: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp *http.Response
|
contentLength := r.ContentLength
|
||||||
|
if contentLength == 0 {
|
||||||
var contentLength int64
|
|
||||||
// if ContentLength is already set
|
|
||||||
if r.ContentLength > 0 {
|
|
||||||
contentLength = r.ContentLength
|
|
||||||
} else {
|
|
||||||
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
contentLength, _ = strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case "HEAD":
|
case http.MethodHead:
|
||||||
resp, err = fcgiBackend.Head(env)
|
resp, err = fcgiBackend.Head(env)
|
||||||
case "GET":
|
case http.MethodGet:
|
||||||
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
|
resp, err = fcgiBackend.Get(env, r.Body, contentLength)
|
||||||
case "OPTIONS":
|
case http.MethodOptions:
|
||||||
resp, err = fcgiBackend.Options(env)
|
resp, err = fcgiBackend.Options(env)
|
||||||
default:
|
default:
|
||||||
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
|
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
|
||||||
return resp, err
|
return resp, err
|
||||||
|
|
||||||
// Stuff brought over from v1 that might not be necessary here:
|
|
||||||
|
|
||||||
// if resp != nil && resp.Body != nil {
|
|
||||||
// defer resp.Body.Close()
|
|
||||||
// }
|
|
||||||
|
|
||||||
// if err != nil {
|
|
||||||
// if err, ok := err.(net.Error); ok && err.Timeout() {
|
|
||||||
// return http.StatusGatewayTimeout, err
|
|
||||||
// } else if err != io.EOF {
|
|
||||||
// return http.StatusBadGateway, err
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Write response header
|
|
||||||
// writeHeader(w, resp)
|
|
||||||
|
|
||||||
// // Write the response body
|
|
||||||
// _, err = io.Copy(w, resp.Body)
|
|
||||||
// if err != nil {
|
|
||||||
// return http.StatusBadGateway, err
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Log any stderr output from upstream
|
|
||||||
// if fcgiBackend.stderr.Len() != 0 {
|
|
||||||
// // Remove trailing newline, error logger already does this.
|
|
||||||
// err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Normally we would return the status code if it is an error status (>= 400),
|
|
||||||
// // however, upstream FastCGI apps don't know about our contract and have
|
|
||||||
// // probably already written an error page. So we just return 0, indicating
|
|
||||||
// // that the response body is already written. However, we do return any
|
|
||||||
// // error value so it can be logged.
|
|
||||||
// // Note that the proxy middleware works the same way, returning status=0.
|
|
||||||
// return 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildEnv returns a set of CGI environment variables for the request.
|
// buildEnv returns a set of CGI environment variables for the request.
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package reverseproxy
|
package reverseproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -93,15 +94,31 @@ func (h *Handler) activeHealthChecker() {
|
||||||
// health checks for all hosts in the global repository.
|
// health checks for all hosts in the global repository.
|
||||||
func (h *Handler) doActiveHealthChecksForAllHosts() {
|
func (h *Handler) doActiveHealthChecksForAllHosts() {
|
||||||
hosts.Range(func(key, value interface{}) bool {
|
hosts.Range(func(key, value interface{}) bool {
|
||||||
addr := key.(string)
|
networkAddr := key.(string)
|
||||||
host := value.(Host)
|
host := value.(Host)
|
||||||
|
|
||||||
go func(addr string, host Host) {
|
go func(networkAddr string, host Host) {
|
||||||
err := h.doActiveHealthCheck(addr, host)
|
network, addrs, err := caddy.ParseNetworkAddress(networkAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", addr, err)
|
log.Printf("[ERROR] reverse_proxy: active health check for host %s: bad network address: %v", networkAddr, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}(addr, host)
|
if len(addrs) != 1 {
|
||||||
|
log.Printf("[ERROR] reverse_proxy: active health check for host %s: multiple addresses (upstream must map to only one address)", networkAddr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hostAddr := addrs[0]
|
||||||
|
if network == "unix" || network == "unixgram" || network == "unixpacket" {
|
||||||
|
// this will be used as the Host portion of a http.Request URL, and
|
||||||
|
// paths to socket files would produce an error when creating URL,
|
||||||
|
// so use a fake Host value instead
|
||||||
|
hostAddr = network
|
||||||
|
}
|
||||||
|
err = h.doActiveHealthCheck(DialInfo{network, addrs[0]}, hostAddr, host)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[ERROR] reverse_proxy: active health check for host %s: %v", networkAddr, err)
|
||||||
|
}
|
||||||
|
}(networkAddr, host)
|
||||||
|
|
||||||
// continue to iterate all hosts
|
// continue to iterate all hosts
|
||||||
return true
|
return true
|
||||||
|
@ -115,26 +132,39 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
|
||||||
// according to whether it passes the health check. An error is
|
// according to whether it passes the health check. An error is
|
||||||
// returned only if the health check fails to occur or if marking
|
// returned only if the health check fails to occur or if marking
|
||||||
// the host's health status fails.
|
// the host's health status fails.
|
||||||
func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
|
func (h *Handler) doActiveHealthCheck(dialInfo DialInfo, hostAddr string, host Host) error {
|
||||||
// create the URL for the health check
|
// create the URL for the request that acts as a health check
|
||||||
u, err := url.Parse(hostAddr)
|
scheme := "http"
|
||||||
if err != nil {
|
if ht, ok := h.Transport.(*http.Transport); ok && ht.TLSClientConfig != nil {
|
||||||
return err
|
// this is kind of a hacky way to know if we should use HTTPS, but whatever
|
||||||
|
scheme = "https"
|
||||||
}
|
}
|
||||||
if h.HealthChecks.Active.Path != "" {
|
u := &url.URL{
|
||||||
u.Path = h.HealthChecks.Active.Path
|
Scheme: scheme,
|
||||||
|
Host: hostAddr,
|
||||||
|
Path: h.HealthChecks.Active.Path,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// adjust the port, if configured to be different
|
||||||
if h.HealthChecks.Active.Port != 0 {
|
if h.HealthChecks.Active.Port != 0 {
|
||||||
portStr := strconv.Itoa(h.HealthChecks.Active.Port)
|
portStr := strconv.Itoa(h.HealthChecks.Active.Port)
|
||||||
u.Host = net.JoinHostPort(u.Hostname(), portStr)
|
host, _, err := net.SplitHostPort(hostAddr)
|
||||||
|
if err != nil {
|
||||||
|
host = hostAddr
|
||||||
|
}
|
||||||
|
u.Host = net.JoinHostPort(host, portStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
// attach dialing information to this request
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = context.WithValue(ctx, caddy.ReplacerCtxKey, caddy.NewReplacer())
|
||||||
|
ctx = context.WithValue(ctx, DialInfoCtxKey, dialInfo)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("making request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// do the request, careful to tame the response body
|
// do the request, being careful to tame the response body
|
||||||
resp, err := h.HealthChecks.Active.httpClient.Do(req)
|
resp, err := h.HealthChecks.Active.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
|
log.Printf("[INFO] reverse_proxy: active health check: %s is down (HTTP request failed: %v)", hostAddr, err)
|
||||||
|
@ -149,7 +179,7 @@ func (h *Handler) doActiveHealthCheck(hostAddr string, host Host) error {
|
||||||
body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
|
body = io.LimitReader(body, h.HealthChecks.Active.MaxSize)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
// drain any remaining body so connection can be re-used
|
// drain any remaining body so connection could be re-used
|
||||||
io.Copy(ioutil.Discard, body)
|
io.Copy(ioutil.Discard, body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
}()
|
}()
|
||||||
|
@ -225,7 +255,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
|
||||||
err := upstream.Host.CountFail(1)
|
err := upstream.Host.CountFail(1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
|
log.Printf("[ERROR] proxy: upstream %s: counting failure: %v",
|
||||||
upstream.hostURL, err)
|
upstream.dialInfo, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// forget it later
|
// forget it later
|
||||||
|
@ -234,7 +264,7 @@ func (h *Handler) countFailure(upstream *Upstream) {
|
||||||
err := host.CountFail(-1)
|
err := host.CountFail(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
|
log.Printf("[ERROR] proxy: upstream %s: expiring failure: %v",
|
||||||
upstream.hostURL, err)
|
upstream.dialInfo, err)
|
||||||
}
|
}
|
||||||
}(upstream.Host, failDuration)
|
}(upstream.Host, failDuration)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@ package reverseproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/caddyserver/caddy/v2"
|
"github.com/caddyserver/caddy/v2"
|
||||||
|
@ -59,7 +58,7 @@ type UpstreamPool []*Upstream
|
||||||
type Upstream struct {
|
type Upstream struct {
|
||||||
Host `json:"-"`
|
Host `json:"-"`
|
||||||
|
|
||||||
Address string `json:"address,omitempty"`
|
Dial string `json:"dial,omitempty"`
|
||||||
MaxRequests int `json:"max_requests,omitempty"`
|
MaxRequests int `json:"max_requests,omitempty"`
|
||||||
|
|
||||||
// TODO: This could be really useful, to bind requests
|
// TODO: This could be really useful, to bind requests
|
||||||
|
@ -68,8 +67,8 @@ type Upstream struct {
|
||||||
// IPAffinity string
|
// IPAffinity string
|
||||||
|
|
||||||
healthCheckPolicy *PassiveHealthChecks
|
healthCheckPolicy *PassiveHealthChecks
|
||||||
hostURL *url.URL
|
|
||||||
cb CircuitBreaker
|
cb CircuitBreaker
|
||||||
|
dialInfo DialInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Available returns true if the remote host
|
// Available returns true if the remote host
|
||||||
|
@ -101,11 +100,6 @@ func (u *Upstream) Full() bool {
|
||||||
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
|
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
|
||||||
}
|
}
|
||||||
|
|
||||||
// URL returns the upstream host's endpoint URL.
|
|
||||||
func (u *Upstream) URL() *url.URL {
|
|
||||||
return u.hostURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// upstreamHost is the basic, in-memory representation
|
// upstreamHost is the basic, in-memory representation
|
||||||
// of the state of a remote host. It implements the
|
// of the state of a remote host. It implements the
|
||||||
// Host interface.
|
// Host interface.
|
||||||
|
@ -162,6 +156,34 @@ func (uh *upstreamHost) SetHealthy(healthy bool) (bool, error) {
|
||||||
return swapped, nil
|
return swapped, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialInfo contains information needed to dial a
|
||||||
|
// connection to an upstream host. This information
|
||||||
|
// may be different than that which is represented
|
||||||
|
// in a URL (for example, unix sockets don't have
|
||||||
|
// a host that can be represented in a URL, but
|
||||||
|
// they certainly have a network name and address).
|
||||||
|
type DialInfo struct {
|
||||||
|
// The network to use. This should be one of the
|
||||||
|
// values that is accepted by net.Dial:
|
||||||
|
// https://golang.org/pkg/net/#Dial
|
||||||
|
Network string
|
||||||
|
|
||||||
|
// The address to dial. Follows the same
|
||||||
|
// semantics and rules as net.Dial.
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the Caddy network address form
|
||||||
|
// by joining the network and address with a
|
||||||
|
// forward slash.
|
||||||
|
func (di DialInfo) String() string {
|
||||||
|
return di.Network + "/" + di.Address
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialInfoCtxKey is used to store a DialInfo
|
||||||
|
// in a context.Context.
|
||||||
|
const DialInfoCtxKey = caddy.CtxKey("dial_info")
|
||||||
|
|
||||||
// hosts is the global repository for hosts that are
|
// hosts is the global repository for hosts that are
|
||||||
// currently in use by active configuration(s). This
|
// currently in use by active configuration(s). This
|
||||||
// allows the state of remote hosts to be preserved
|
// allows the state of remote hosts to be preserved
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package reverseproxy
|
package reverseproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -63,14 +64,23 @@ func (HTTPTransport) CaddyModule() caddy.ModuleInfo {
|
||||||
|
|
||||||
// Provision sets up h.RoundTripper with a http.Transport
|
// Provision sets up h.RoundTripper with a http.Transport
|
||||||
// that is ready to use.
|
// that is ready to use.
|
||||||
func (h *HTTPTransport) Provision(ctx caddy.Context) error {
|
func (h *HTTPTransport) Provision(_ caddy.Context) error {
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: time.Duration(h.DialTimeout),
|
Timeout: time.Duration(h.DialTimeout),
|
||||||
FallbackDelay: time.Duration(h.FallbackDelay),
|
FallbackDelay: time.Duration(h.FallbackDelay),
|
||||||
// TODO: Resolver
|
// TODO: Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
rt := &http.Transport{
|
rt := &http.Transport{
|
||||||
DialContext: dialer.DialContext,
|
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
// the proper dialing information should be embedded into the request's context
|
||||||
|
if dialInfoVal := ctx.Value(DialInfoCtxKey); dialInfoVal != nil {
|
||||||
|
dialInfo := dialInfoVal.(DialInfo)
|
||||||
|
network = dialInfo.Network
|
||||||
|
address = dialInfo.Address
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, address)
|
||||||
|
},
|
||||||
MaxConnsPerHost: h.MaxConnsPerHost,
|
MaxConnsPerHost: h.MaxConnsPerHost,
|
||||||
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
|
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),
|
||||||
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
|
ExpectContinueTimeout: time.Duration(h.ExpectContinueTimeout),
|
||||||
|
@ -91,7 +101,6 @@ func (h *HTTPTransport) Provision(ctx caddy.Context) error {
|
||||||
|
|
||||||
if h.KeepAlive != nil {
|
if h.KeepAlive != nil {
|
||||||
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
|
dialer.KeepAlive = time.Duration(h.KeepAlive.ProbeInterval)
|
||||||
|
|
||||||
if enabled := h.KeepAlive.Enabled; enabled != nil {
|
if enabled := h.KeepAlive.Enabled; enabled != nil {
|
||||||
rt.DisableKeepAlives = !*enabled
|
rt.DisableKeepAlives = !*enabled
|
||||||
}
|
}
|
||||||
|
@ -191,16 +200,3 @@ type KeepAlive struct {
|
||||||
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
|
MaxIdleConnsPerHost int `json:"max_idle_conns_per_host,omitempty"`
|
||||||
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
|
IdleConnTimeout caddy.Duration `json:"idle_timeout,omitempty"` // how long should connections be kept alive when idle
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
defaultDialer = net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultTransport = &http.Transport{
|
|
||||||
DialContext: defaultDialer.DialContext,
|
|
||||||
TLSHandshakeTimeout: 5 * time.Second,
|
|
||||||
IdleConnTimeout: 2 * time.Minute,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -86,7 +85,18 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.Transport == nil {
|
if h.Transport == nil {
|
||||||
h.Transport = defaultTransport
|
t := &HTTPTransport{
|
||||||
|
KeepAlive: &KeepAlive{
|
||||||
|
ProbeInterval: caddy.Duration(30 * time.Second),
|
||||||
|
IdleConnTimeout: caddy.Duration(2 * time.Minute),
|
||||||
|
},
|
||||||
|
DialTimeout: caddy.Duration(10 * time.Second),
|
||||||
|
}
|
||||||
|
err := t.Provision(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("provisioning default transport: %v", err)
|
||||||
|
}
|
||||||
|
h.Transport = t
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.LoadBalancing == nil {
|
if h.LoadBalancing == nil {
|
||||||
|
@ -133,51 +143,65 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||||
go h.activeHealthChecker()
|
go h.activeHealthChecker()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var allUpstreams []*Upstream
|
||||||
for _, upstream := range h.Upstreams {
|
for _, upstream := range h.Upstreams {
|
||||||
upstream.cb = h.CB
|
// upstreams are allowed to map to only a single host,
|
||||||
|
// but an upstream's address may semantically represent
|
||||||
// url parser requires a scheme
|
// multiple addresses, so make sure to handle each
|
||||||
if !strings.Contains(upstream.Address, "://") {
|
// one in turn based on this one upstream config
|
||||||
upstream.Address = "http://" + upstream.Address
|
network, addresses, err := caddy.ParseNetworkAddress(upstream.Dial)
|
||||||
}
|
|
||||||
u, err := url.Parse(upstream.Address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid upstream address %s: %v", upstream.Address, err)
|
return fmt.Errorf("parsing dial address: %v", err)
|
||||||
}
|
|
||||||
upstream.hostURL = u
|
|
||||||
|
|
||||||
// if host already exists from a current config,
|
|
||||||
// use that instead; otherwise, add it
|
|
||||||
// TODO: make hosts modular, so that their state can be distributed in enterprise for example
|
|
||||||
// TODO: If distributed, the pool should be stored in storage...
|
|
||||||
var host Host = new(upstreamHost)
|
|
||||||
activeHost, loaded := hosts.LoadOrStore(u.String(), host)
|
|
||||||
if loaded {
|
|
||||||
host = activeHost.(Host)
|
|
||||||
}
|
|
||||||
upstream.Host = host
|
|
||||||
|
|
||||||
// if the passive health checker has a non-zero "unhealthy
|
|
||||||
// request count" but the upstream has no MaxRequests set
|
|
||||||
// (they are the same thing, but one is a default value for
|
|
||||||
// for upstreams with a zero MaxRequests), copy the default
|
|
||||||
// value into this upstream, since the value in the upstream
|
|
||||||
// is what is used during availability checks
|
|
||||||
if h.HealthChecks != nil &&
|
|
||||||
h.HealthChecks.Passive != nil &&
|
|
||||||
h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
|
|
||||||
upstream.MaxRequests == 0 {
|
|
||||||
upstream.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.HealthChecks != nil {
|
for _, addr := range addresses {
|
||||||
|
// make a new upstream based on the original
|
||||||
|
// that has a singular dial address
|
||||||
|
upstreamCopy := *upstream
|
||||||
|
upstreamCopy.dialInfo = DialInfo{network, addr}
|
||||||
|
upstreamCopy.Dial = upstreamCopy.dialInfo.String()
|
||||||
|
upstreamCopy.cb = h.CB
|
||||||
|
|
||||||
|
// if host already exists from a current config,
|
||||||
|
// use that instead; otherwise, add it
|
||||||
|
// TODO: make hosts modular, so that their state can be distributed in enterprise for example
|
||||||
|
// TODO: If distributed, the pool should be stored in storage...
|
||||||
|
var host Host = new(upstreamHost)
|
||||||
|
activeHost, loaded := hosts.LoadOrStore(upstreamCopy.Dial, host)
|
||||||
|
if loaded {
|
||||||
|
host = activeHost.(Host)
|
||||||
|
}
|
||||||
|
upstreamCopy.Host = host
|
||||||
|
|
||||||
|
// if the passive health checker has a non-zero "unhealthy
|
||||||
|
// request count" but the upstream has no MaxRequests set
|
||||||
|
// (they are the same thing, but one is a default value for
|
||||||
|
// for upstreams with a zero MaxRequests), copy the default
|
||||||
|
// value into this upstream, since the value in the upstream
|
||||||
|
// (MaxRequests) is what is used during availability checks
|
||||||
|
if h.HealthChecks != nil &&
|
||||||
|
h.HealthChecks.Passive != nil &&
|
||||||
|
h.HealthChecks.Passive.UnhealthyRequestCount > 0 &&
|
||||||
|
upstreamCopy.MaxRequests == 0 {
|
||||||
|
upstreamCopy.MaxRequests = h.HealthChecks.Passive.UnhealthyRequestCount
|
||||||
|
}
|
||||||
|
|
||||||
// upstreams need independent access to the passive
|
// upstreams need independent access to the passive
|
||||||
// health check policy so they can, you know, passively
|
// health check policy because they run outside of the
|
||||||
// do health checks
|
// scope of a request handler
|
||||||
upstream.healthCheckPolicy = h.HealthChecks.Passive
|
if h.HealthChecks != nil {
|
||||||
|
upstreamCopy.healthCheckPolicy = h.HealthChecks.Passive
|
||||||
|
}
|
||||||
|
|
||||||
|
allUpstreams = append(allUpstreams, &upstreamCopy)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replace the unmarshaled upstreams (possible 1:many
|
||||||
|
// address mapping) with our list, which is mapped 1:1,
|
||||||
|
// thus may have expanded the original list
|
||||||
|
h.Upstreams = allUpstreams
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,7 +216,7 @@ func (h *Handler) Cleanup() error {
|
||||||
|
|
||||||
// remove hosts from our config from the pool
|
// remove hosts from our config from the pool
|
||||||
for _, upstream := range h.Upstreams {
|
for _, upstream := range h.Upstreams {
|
||||||
hosts.Delete(upstream.hostURL.String())
|
hosts.Delete(upstream.dialInfo.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -222,6 +246,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// attach to the request information about how to dial the upstream;
|
||||||
|
// this is necessary because the information cannot be sufficiently
|
||||||
|
// or satisfactorily represented in a URL
|
||||||
|
ctx := context.WithValue(r.Context(), DialInfoCtxKey, upstream.dialInfo)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
// proxy the request to that upstream
|
// proxy the request to that upstream
|
||||||
proxyErr = h.reverseProxy(w, r, upstream)
|
proxyErr = h.reverseProxy(w, r, upstream)
|
||||||
if proxyErr == nil || proxyErr == context.Canceled {
|
if proxyErr == nil || proxyErr == context.Canceled {
|
||||||
|
@ -249,6 +279,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
|
||||||
// This assumes that no mutations of the request are performed
|
// This assumes that no mutations of the request are performed
|
||||||
// by h during or after proxying.
|
// by h during or after proxying.
|
||||||
func (h Handler) prepareRequest(req *http.Request) error {
|
func (h Handler) prepareRequest(req *http.Request) error {
|
||||||
|
// as a special (but very common) case, if the transport
|
||||||
|
// is HTTP, then ensure the request has the proper scheme
|
||||||
|
// because incoming requests by default are lacking it
|
||||||
|
if req.URL.Scheme == "" {
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
if ht, ok := h.Transport.(*HTTPTransport); ok && ht.TLS != nil {
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if req.ContentLength == 0 {
|
if req.ContentLength == 0 {
|
||||||
req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
|
req.Body = nil // Issue golang/go#16036: nil Body for http.Transport retries
|
||||||
}
|
}
|
||||||
|
@ -433,14 +473,8 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
|
||||||
// directRequest modifies only req.URL so that it points to the
|
// directRequest modifies only req.URL so that it points to the
|
||||||
// given upstream host. It must modify ONLY the request URL.
|
// given upstream host. It must modify ONLY the request URL.
|
||||||
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
|
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
|
||||||
target := upstream.hostURL
|
if req.URL.Host == "" {
|
||||||
req.URL.Scheme = target.Scheme
|
req.URL.Host = upstream.dialInfo.Address
|
||||||
req.URL.Host = target.Host
|
|
||||||
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) // TODO: This might be a bug (if any part of the path was augmented from a previously-tried upstream; need to start from clean original path of request, same for query string!)
|
|
||||||
if target.RawQuery == "" || req.URL.RawQuery == "" {
|
|
||||||
req.URL.RawQuery = target.RawQuery + req.URL.RawQuery
|
|
||||||
} else {
|
|
||||||
req.URL.RawQuery = target.RawQuery + "&" + req.URL.RawQuery
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -168,7 +168,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
|
||||||
// listeners in s that use a port which is not otherPort.
|
// listeners in s that use a port which is not otherPort.
|
||||||
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
|
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
|
||||||
for _, lnAddr := range s.Listen {
|
for _, lnAddr := range s.Listen {
|
||||||
_, addrs, err := caddy.ParseListenAddr(lnAddr)
|
_, addrs, err := caddy.ParseNetworkAddress(lnAddr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
for _, a := range addrs {
|
for _, a := range addrs {
|
||||||
_, port, err := net.SplitHostPort(a)
|
_, port, err := net.SplitHostPort(a)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user