From 223cbe3d0b50487117c785f0755bb80a9ee65010 Mon Sep 17 00:00:00 2001
From: Francis Lavoie <lavofr@gmail.com>
Date: Tue, 10 Jan 2023 00:08:23 -0500
Subject: [PATCH] caddyhttp: Add server-level `trusted_proxies` config (#5103)

---
 caddyconfig/httpcaddyfile/serveroptions.go    | 11 +++
 .../global_server_options_single.txt          |  9 +++
 modules/caddyhttp/app.go                      | 20 +++++
 modules/caddyhttp/matchers.go                 |  9 +--
 modules/caddyhttp/replacer.go                 |  3 +-
 modules/caddyhttp/reverseproxy/caddyfile.go   |  9 +--
 .../caddyhttp/reverseproxy/reverseproxy.go    |  6 +-
 modules/caddyhttp/server.go                   | 73 ++++++++++++++++++-
 8 files changed, 117 insertions(+), 23 deletions(-)

diff --git a/caddyconfig/httpcaddyfile/serveroptions.go b/caddyconfig/httpcaddyfile/serveroptions.go
index 3e206c873..7aa0a2a8c 100644
--- a/caddyconfig/httpcaddyfile/serveroptions.go
+++ b/caddyconfig/httpcaddyfile/serveroptions.go
@@ -42,6 +42,7 @@ type serverOptions struct {
 	MaxHeaderBytes       int
 	Protocols            []string
 	StrictSNIHost        *bool
+	TrustedProxies       []string
 	ShouldLogCredentials bool
 	Metrics              *caddyhttp.Metrics
 }
@@ -176,6 +177,15 @@ func unmarshalCaddyfileServerOptions(d *caddyfile.Dispenser) (any, error) {
 				}
 				serverOpts.StrictSNIHost = &boolVal
 
+			case "trusted_proxies":
+				for d.NextArg() {
+					if d.Val() == "private_ranges" {
+						serverOpts.TrustedProxies = append(serverOpts.TrustedProxies, caddyhttp.PrivateRangesCIDR()...)
+						continue
+					}
+					serverOpts.TrustedProxies = append(serverOpts.TrustedProxies, d.Val())
+				}
+
 			case "metrics":
 				if d.NextArg() {
 					return nil, d.ArgErr()
@@ -269,6 +279,7 @@ func applyServerOptions(
 		server.MaxHeaderBytes = opts.MaxHeaderBytes
 		server.Protocols = opts.Protocols
 		server.StrictSNIHost = opts.StrictSNIHost
+		server.TrustedProxies = opts.TrustedProxies
 		server.Metrics = opts.Metrics
 		if opts.ShouldLogCredentials {
 			if server.Logs == nil {
diff --git a/caddytest/integration/caddyfile_adapt/global_server_options_single.txt b/caddytest/integration/caddyfile_adapt/global_server_options_single.txt
index 5fb673929..f767ea742 100644
--- a/caddytest/integration/caddyfile_adapt/global_server_options_single.txt
+++ b/caddytest/integration/caddyfile_adapt/global_server_options_single.txt
@@ -14,6 +14,7 @@
 		log_credentials
 		protocols h1 h2 h2c h3
 		strict_sni_host
+		trusted_proxies private_ranges
 	}
 }
 
@@ -55,6 +56,14 @@ foo.com {
 						}
 					],
 					"strict_sni_host": true,
+					"trusted_proxies": [
+						"192.168.0.0/16",
+						"172.16.0.0/12",
+						"10.0.0.0/8",
+						"127.0.0.1/8",
+						"fd00::/8",
+						"::1"
+					],
 					"logs": {
 						"should_log_credentials": true
 					},
diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go
index 0943b32df..d790284c8 100644
--- a/modules/caddyhttp/app.go
+++ b/modules/caddyhttp/app.go
@@ -20,7 +20,9 @@ import (
 	"fmt"
 	"net"
 	"net/http"
+	"net/netip"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -222,6 +224,24 @@ func (app *App) Provision(ctx caddy.Context) error {
 			srv.StrictSNIHost = &trueBool
 		}
 
+		// parse trusted proxy CIDRs ahead of time
+		for _, str := range srv.TrustedProxies {
+			if strings.Contains(str, "/") {
+				ipNet, err := netip.ParsePrefix(str)
+				if err != nil {
+					return fmt.Errorf("parsing CIDR expression: '%s': %v", str, err)
+				}
+				srv.trustedProxies = append(srv.trustedProxies, ipNet)
+			} else {
+				ipAddr, err := netip.ParseAddr(str)
+				if err != nil {
+					return fmt.Errorf("invalid IP address: '%s': %v", str, err)
+				}
+				ipNew := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
+				srv.trustedProxies = append(srv.trustedProxies, ipNew)
+			}
+		}
+
 		// process each listener address
 		for i := range srv.Listen {
 			lnOut, err := repl.ReplaceOrErr(srv.Listen[i], true, true)
diff --git a/modules/caddyhttp/matchers.go b/modules/caddyhttp/matchers.go
index 400c1543e..3064300bb 100644
--- a/modules/caddyhttp/matchers.go
+++ b/modules/caddyhttp/matchers.go
@@ -1281,14 +1281,7 @@ func (m *MatchRemoteIP) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 				continue
 			}
 			if d.Val() == "private_ranges" {
-				m.Ranges = append(m.Ranges, []string{
-					"192.168.0.0/16",
-					"172.16.0.0/12",
-					"10.0.0.0/8",
-					"127.0.0.1/8",
-					"fd00::/8",
-					"::1",
-				}...)
+				m.Ranges = append(m.Ranges, PrivateRangesCIDR()...)
 				continue
 			}
 			m.Ranges = append(m.Ranges, d.Val())
diff --git a/modules/caddyhttp/replacer.go b/modules/caddyhttp/replacer.go
index e89c502b6..c58b56ed3 100644
--- a/modules/caddyhttp/replacer.go
+++ b/modules/caddyhttp/replacer.go
@@ -290,8 +290,7 @@ func addHTTPVarsToReplacer(repl *caddy.Replacer, req *http.Request, w http.Respo
 			// middleware variables
 			if strings.HasPrefix(key, varsReplPrefix) {
 				varName := key[len(varsReplPrefix):]
-				tbl := req.Context().Value(VarsCtxKey).(map[string]any)
-				raw := tbl[varName]
+				raw := GetVar(req.Context(), varName)
 				// variables can be dynamic, so always return true
 				// even when it may not be set; treat as empty then
 				return raw, true
diff --git a/modules/caddyhttp/reverseproxy/caddyfile.go b/modules/caddyhttp/reverseproxy/caddyfile.go
index f1f10d193..cd9b77cd3 100644
--- a/modules/caddyhttp/reverseproxy/caddyfile.go
+++ b/modules/caddyhttp/reverseproxy/caddyfile.go
@@ -549,14 +549,7 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 			case "trusted_proxies":
 				for d.NextArg() {
 					if d.Val() == "private_ranges" {
-						h.TrustedProxies = append(h.TrustedProxies, []string{
-							"192.168.0.0/16",
-							"172.16.0.0/12",
-							"10.0.0.0/8",
-							"127.0.0.1/8",
-							"fd00::/8",
-							"::1",
-						}...)
+						h.TrustedProxies = append(h.TrustedProxies, caddyhttp.PrivateRangesCIDR()...)
 						continue
 					}
 					h.TrustedProxies = append(h.TrustedProxies, d.Val())
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 3adec3d9c..88d98e827 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -701,16 +701,14 @@ func (h Handler) addForwardedHeaders(req *http.Request) error {
 
 	// Client IP may contain a zone if IPv6, so we need
 	// to pull that out before parsing the IP
-	if before, _, found := strings.Cut(clientIP, "%"); found {
-		clientIP = before
-	}
+	clientIP, _, _ = strings.Cut(clientIP, "%")
 	ipAddr, err := netip.ParseAddr(clientIP)
 	if err != nil {
 		return fmt.Errorf("invalid IP address: '%s': %v", clientIP, err)
 	}
 
 	// Check if the client is a trusted proxy
-	trusted := false
+	trusted := caddyhttp.GetVar(req.Context(), caddyhttp.TrustedProxyVarKey).(bool)
 	for _, ipRange := range h.trustedProxies {
 		if ipRange.Contains(ipAddr) {
 			trusted = true
diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go
index 7a244d2c7..f50abf0e3 100644
--- a/modules/caddyhttp/server.go
+++ b/modules/caddyhttp/server.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 	"net"
 	"net/http"
+	"net/netip"
 	"net/url"
 	"runtime"
 	"strings"
@@ -117,6 +118,18 @@ type Server struct {
 	// client authentication.
 	StrictSNIHost *bool `json:"strict_sni_host,omitempty"`
 
+	// A list of IP ranges (supports CIDR notation) from which
+	// requests should be trusted. By default, no proxies are
+	// trusted.
+	//
+	// On its own, this configuration will not do anything,
+	// but it can be used as a default set of ranges for
+	// handlers or matchers in routes to pick up, instead
+	// of needing to configure each of them. See the
+	// `reverse_proxy` handler for example, which uses this
+	// to trust sensitive incoming `X-Forwarded-*` headers.
+	TrustedProxies []string `json:"trusted_proxies,omitempty"`
+
 	// Enables access logging and configures how access logs are handled
 	// in this server. To minimally enable access logs, simply set this
 	// to a non-null, empty struct.
@@ -175,6 +188,9 @@ type Server struct {
 	h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create
 	addresses   []caddy.NetworkAddress
 
+	// Holds the parsed CIDR ranges from TrustedProxies
+	trustedProxies []netip.Prefix
+
 	shutdownAt   time.Time
 	shutdownAtMu *sync.RWMutex
 
@@ -675,7 +691,9 @@ func PrepareRequest(r *http.Request, repl *caddy.Replacer, w http.ResponseWriter
 	// set up the context for the request
 	ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
 	ctx = context.WithValue(ctx, ServerCtxKey, s)
-	ctx = context.WithValue(ctx, VarsCtxKey, make(map[string]any))
+	ctx = context.WithValue(ctx, VarsCtxKey, map[string]any{
+		TrustedProxyVarKey: determineTrustedProxy(r, s),
+	})
 	ctx = context.WithValue(ctx, routeGroupCtxKey, make(map[string]struct{}))
 	var url2 url.URL // avoid letting this escape to the heap
 	ctx = context.WithValue(ctx, OriginalRequestCtxKey, originalRequest(r, &url2))
@@ -705,6 +723,43 @@ func originalRequest(req *http.Request, urlCopy *url.URL) http.Request {
 	}
 }
 
+// determineTrustedProxy parses the remote IP address of
+// the request, and determines (if the server configured it)
+// if the client is a trusted proxy.
+func determineTrustedProxy(r *http.Request, s *Server) bool {
+	// If there's no server, then we can't check anything
+	if s == nil {
+		return false
+	}
+
+	// Parse the remote IP, ignore the error as non-fatal,
+	// but the remote IP is required to continue, so we
+	// just return early. This should probably never happen
+	// though, unless some other module manipulated the request's
+	// remote address and used an invalid value.
+	clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
+	if err != nil {
+		return false
+	}
+
+	// Client IP may contain a zone if IPv6, so we need
+	// to pull that out before parsing the IP
+	clientIP, _, _ = strings.Cut(clientIP, "%")
+	ipAddr, err := netip.ParseAddr(clientIP)
+	if err != nil {
+		return false
+	}
+
+	// Check if the client is a trusted proxy
+	for _, ipRange := range s.trustedProxies {
+		if ipRange.Contains(ipAddr) {
+			return true
+		}
+	}
+
+	return false
+}
+
 // cloneURL makes a copy of r.URL and returns a
 // new value that doesn't reference the original.
 func cloneURL(from, to *url.URL) {
@@ -716,6 +771,19 @@ func cloneURL(from, to *url.URL) {
 	}
 }
 
+// PrivateRangesCIDR returns a list of private CIDR range
+// strings, which can be used as a configuration shortcut.
+func PrivateRangesCIDR() []string {
+	return []string{
+		"192.168.0.0/16",
+		"172.16.0.0/12",
+		"10.0.0.0/8",
+		"127.0.0.1/8",
+		"fd00::/8",
+		"::1",
+	}
+}
+
 // Context keys for HTTP request context values.
 const (
 	// For referencing the server instance
@@ -727,4 +795,7 @@ const (
 	// For a partial copy of the unmodified request that
 	// originally came into the server's entry handler
 	OriginalRequestCtxKey caddy.CtxKey = "original_request"
+
+	// For tracking whether the client is a trusted proxy
+	TrustedProxyVarKey string = "trusted_proxy"
 )