core: Use port ranges to avoid OOM with bad inputs (#2859)

* fix OOM issue caught by fuzzing

* use ParsedAddress as the struct name for the result of ParseNetworkAddress

* simplify code using the ParsedAddress type

* minor cleanups
This commit is contained in:
Mohammed Al Sahaf 2019-11-12 01:33:38 +03:00 committed by Matt Holt
parent a19da07b72
commit 93bc1b72e3
8 changed files with 201 additions and 130 deletions

View File

@ -48,23 +48,19 @@ type AdminConfig struct {
// listenAddr extracts a singular listen address from ac.Listen, // listenAddr extracts a singular listen address from ac.Listen,
// returning the network and the address of the listener. // returning the network and the address of the listener.
func (admin AdminConfig) listenAddr() (netw string, addr string, err error) { func (admin AdminConfig) listenAddr() (string, string, error) {
var listenAddrs []string
input := admin.Listen input := admin.Listen
if input == "" { if input == "" {
input = DefaultAdminListen input = DefaultAdminListen
} }
netw, listenAddrs, err = ParseNetworkAddress(input) listenAddr, err := ParseNetworkAddress(input)
if err != nil { if err != nil {
err = fmt.Errorf("parsing admin listener address: %v", err) return "", "", fmt.Errorf("parsing admin listener address: %v", err)
return
} }
if len(listenAddrs) != 1 { if listenAddr.PortRangeSize() != 1 {
err = fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddrs) return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr)
return
} }
addr = listenAddrs[0] return listenAddr.Network, listenAddr.JoinHostPort(0), nil
return
} }
// newAdminHandler reads admin's config and returns an http.Handler suitable // newAdminHandler reads admin's config and returns an http.Handler suitable

View File

@ -257,52 +257,94 @@ type globalListener struct {
pc net.PacketConn pc net.PacketConn
} }
var ( // ParsedAddress contains the individual components
listeners = make(map[string]*globalListener) // for a parsed network address of the form accepted
listenersMu sync.Mutex // by ParseNetworkAddress(). Network should be a
) // network value accepted by Go's net package. Port
// ranges are given by [StartPort, EndPort].
type ParsedAddress struct {
Network string
Host string
StartPort uint
EndPort uint
}
// ParseNetworkAddress parses addr, a string of the form "network/host:port" // JoinHostPort is like net.JoinHostPort, but where the port
// (with any part optional) into its component parts. Because a port can // is StartPort + offset.
// also be a port range, there may be multiple addresses returned. func (l ParsedAddress) JoinHostPort(offset uint) string {
func ParseNetworkAddress(addr string) (network string, addrs []string, err error) { return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset)))
}
// PortRangeSize returns how many ports are in
// pa's port range. Port ranges are inclusive,
// so the size is the difference of start and
// end ports plus one.
func (pa ParsedAddress) PortRangeSize() uint {
return (pa.EndPort - pa.StartPort) + 1
}
// String reconstructs the address string to the form expected
// by ParseNetworkAddress().
func (pa ParsedAddress) String() string {
port := strconv.FormatUint(uint64(pa.StartPort), 10)
if pa.StartPort != pa.EndPort {
port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10)
}
return JoinNetworkAddress(pa.Network, pa.Host, port)
}
// ParseNetworkAddress parses addr into its individual
// components. The input string is expected to be of
// the form "network/host:port-range" where any part is
// optional. The default network, if unspecified, is tcp.
// Port ranges are inclusive.
//
// Network addresses are distinct from URLs and do not
// use URL syntax.
func ParseNetworkAddress(addr string) (ParsedAddress, error) {
var host, port string var host, port string
network, host, port, err = SplitNetworkAddress(addr) network, host, port, err := SplitNetworkAddress(addr)
if network == "" { if network == "" {
network = "tcp" network = "tcp"
} }
if err != nil { if err != nil {
return return ParsedAddress{}, err
} }
if network == "unix" || network == "unixgram" || network == "unixpacket" { if network == "unix" || network == "unixgram" || network == "unixpacket" {
addrs = []string{host} return ParsedAddress{
return Network: network,
Host: host,
}, nil
} }
ports := strings.SplitN(port, "-", 2) ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 { if len(ports) == 1 {
ports = append(ports, ports[0]) ports = append(ports, ports[0])
} }
var start, end int var start, end uint64
start, err = strconv.Atoi(ports[0]) start, err = strconv.ParseUint(ports[0], 10, 16)
if err != nil { if err != nil {
return return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err)
} }
end, err = strconv.Atoi(ports[1]) end, err = strconv.ParseUint(ports[1], 10, 16)
if err != nil { if err != nil {
return return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err)
} }
if end < start { if end < start {
err = fmt.Errorf("end port must be greater than start port") return ParsedAddress{}, fmt.Errorf("end port must not be less than start port")
return
} }
for p := start; p <= end; p++ { if (end - start) > maxPortSpan {
addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p))) return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
} }
return return ParsedAddress{
Network: network,
Host: host,
StartPort: uint(start),
EndPort: uint(end),
}, nil
} }
// SplitNetworkAddress splits a into its network, host, and port components. // SplitNetworkAddress splits a into its network, host, and port components.
// Note that port may be a port range, or omitted for unix sockets. // Note that port may be a port range (:X-Y), or omitted for unix sockets.
func SplitNetworkAddress(a string) (network, host, port string, err error) { func SplitNetworkAddress(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 { if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx])) network = strings.ToLower(strings.TrimSpace(a[:idx]))
@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
} }
// JoinNetworkAddress combines network, host, and port into a single // JoinNetworkAddress combines network, host, and port into a single
// address string of the form "network/host:port". Port may be a // address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network
// port range. For unix sockets, the network should be "unix" and // should be "unix" and the path to the socket should be given as the
// the path to the socket should be given in the host argument. // host parameter.
func JoinNetworkAddress(network, host, port string) string { func JoinNetworkAddress(network, host, port string) string {
var a string var a string
if network != "" { if network != "" {
@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string {
} }
return a return a
} }
var (
listeners = make(map[string]*globalListener)
listenersMu sync.Mutex
)
const maxPortSpan = 65535

View File

@ -18,7 +18,7 @@
package caddy package caddy
func FuzzParseNetworkAddress(data []byte) int { func FuzzParseNetworkAddress(data []byte) int {
_, _, err := ParseNetworkAddress(string(data)) _, err := ParseNetworkAddress(string(data))
if err != nil { if err != nil {
return 0 return 0
} }

View File

@ -152,74 +152,101 @@ func TestJoinNetworkAddress(t *testing.T) {
func TestParseNetworkAddress(t *testing.T) { func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
input string input string
expectNetwork string expectAddr ParsedAddress
expectAddrs []string expectErr bool
expectErr bool
}{ }{
{ {
input: "", input: "",
expectNetwork: "tcp", expectErr: true,
expectErr: true,
}, },
{ {
input: ":", input: ":",
expectNetwork: "tcp", expectErr: true,
expectErr: true,
}, },
{ {
input: ":1234", input: ":1234",
expectNetwork: "tcp", expectAddr: ParsedAddress{
expectAddrs: []string{":1234"}, Network: "tcp",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
}, },
{ {
input: "tcp/:1234", input: "tcp/:1234",
expectNetwork: "tcp", expectAddr: ParsedAddress{
expectAddrs: []string{":1234"}, Network: "tcp",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
}, },
{ {
input: "tcp6/:1234", input: "tcp6/:1234",
expectNetwork: "tcp6", expectAddr: ParsedAddress{
expectAddrs: []string{":1234"}, Network: "tcp6",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
}, },
{ {
input: "tcp4/localhost:1234", input: "tcp4/localhost:1234",
expectNetwork: "tcp4", expectAddr: ParsedAddress{
expectAddrs: []string{"localhost:1234"}, Network: "tcp4",
Host: "localhost",
StartPort: 1234,
EndPort: 1234,
},
}, },
{ {
input: "unix//foo/bar", input: "unix//foo/bar",
expectNetwork: "unix", expectAddr: ParsedAddress{
expectAddrs: []string{"/foo/bar"}, Network: "unix",
Host: "/foo/bar",
},
}, },
{ {
input: "localhost:1234-1234", input: "localhost:1234-1234",
expectNetwork: "tcp", expectAddr: ParsedAddress{
expectAddrs: []string{"localhost:1234"}, Network: "tcp",
Host: "localhost",
StartPort: 1234,
EndPort: 1234,
},
}, },
{ {
input: "localhost:2-1", input: "localhost:2-1",
expectNetwork: "tcp", expectErr: true,
expectErr: true,
}, },
{ {
input: "localhost:0", input: "localhost:0",
expectNetwork: "tcp", expectAddr: ParsedAddress{
expectAddrs: []string{"localhost:0"}, Network: "tcp",
Host: "localhost",
StartPort: 0,
EndPort: 0,
},
},
{
input: "localhost:1-999999999999",
expectErr: true,
}, },
} { } {
actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input) actualAddr, err := ParseNetworkAddress(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }
if !tc.expectErr && err != nil { if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got: %v", i, err) t.Errorf("Test %d: Expected no error but got: %v", i, err)
} }
if actualNetwork != tc.expectNetwork {
t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork) if actualAddr.Network != tc.expectAddr.Network {
t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr)
} }
if !reflect.DeepEqual(tc.expectAddrs, actualAddrs) { if !reflect.DeepEqual(tc.expectAddr, actualAddr) {
t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddrs, actualAddrs) t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr)
} }
} }
} }

View File

@ -135,15 +135,18 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string) lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers { for srvName, srv := range app.Servers {
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, expanded, err := caddy.ParseNetworkAddress(addr) listenAddr, err := caddy.ParseNetworkAddress(addr)
if err != nil { if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err) return fmt.Errorf("invalid listener address '%s': %v", addr, err)
} }
for _, a := range expanded { // check that every address in the port range is unique to this server;
if sn, ok := lnAddrs[netw+a]; ok { // we do not use <= here because PortRangeSize() adds 1 to EndPort for us
return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, a, sn) for i := uint(0); i < listenAddr.PortRangeSize(); i++ {
addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.Itoa(int(listenAddr.StartPort+i)))
if sn, ok := lnAddrs[addr]; ok {
return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, addr, sn)
} }
lnAddrs[netw+a] = srvName lnAddrs[addr] = srvName
} }
} }
} }
@ -176,14 +179,15 @@ func (app *App) Start() error {
} }
for _, lnAddr := range srv.Listen { for _, lnAddr := range srv.Listen {
network, addrs, err := caddy.ParseNetworkAddress(lnAddr) listenAddr, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil { if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
} }
for _, addr := range addrs { for i := uint(0); i <= listenAddr.PortRangeSize(); i++ {
ln, err := caddy.Listen(network, addr) hostport := listenAddr.JoinHostPort(i)
ln, err := caddy.Listen(listenAddr.Network, hostport)
if err != nil { if err != nil {
return fmt.Errorf("%s: listening on %s: %v", network, addr, err) return fmt.Errorf("%s: listening on %s: %v", listenAddr.Network, hostport, err)
} }
// enable HTTP/2 by default // enable HTTP/2 by default
@ -194,11 +198,10 @@ func (app *App) Start() error {
} }
// enable TLS // enable TLS
_, port, _ := net.SplitHostPort(addr) if len(srv.TLSConnPolicies) > 0 && int(i) != app.httpPort() {
if len(srv.TLSConnPolicies) > 0 && port != strconv.Itoa(app.httpPort()) {
tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx) tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx)
if err != nil { if err != nil {
return fmt.Errorf("%s/%s: making TLS configuration: %v", network, addr, err) return fmt.Errorf("%s/%s: making TLS configuration: %v", listenAddr.Network, hostport, err)
} }
ln = tls.NewListener(ln, tlsCfg) ln = tls.NewListener(ln, tlsCfg)
@ -206,15 +209,15 @@ func (app *App) Start() error {
// TODO: HTTP/3 support is experimental for now // TODO: HTTP/3 support is experimental for now
if srv.ExperimentalHTTP3 { if srv.ExperimentalHTTP3 {
app.logger.Info("enabling experimental HTTP/3 listener", app.logger.Info("enabling experimental HTTP/3 listener",
zap.String("addr", addr), zap.String("addr", hostport),
) )
h3ln, err := caddy.ListenPacket("udp", addr) h3ln, err := caddy.ListenPacket("udp", hostport)
if err != nil { if err != nil {
return fmt.Errorf("getting HTTP/3 UDP listener: %v", err) return fmt.Errorf("getting HTTP/3 UDP listener: %v", err)
} }
h3srv := &http3.Server{ h3srv := &http3.Server{
Server: &http.Server{ Server: &http.Server{
Addr: addr, Addr: hostport,
Handler: srv, Handler: srv,
TLSConfig: tlsCfg, TLSConfig: tlsCfg,
}, },

View File

@ -102,7 +102,7 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
host := value.(Host) host := value.(Host)
go func(networkAddr string, host Host) { go func(networkAddr string, host Host) {
network, addrs, err := caddy.ParseNetworkAddress(networkAddr) addr, err := caddy.ParseNetworkAddress(networkAddr)
if err != nil { if err != nil {
h.HealthChecks.Active.logger.Error("bad network address", h.HealthChecks.Active.logger.Error("bad network address",
zap.String("address", networkAddr), zap.String("address", networkAddr),
@ -110,20 +110,20 @@ func (h *Handler) doActiveHealthChecksForAllHosts() {
) )
return return
} }
if len(addrs) != 1 { if addr.PortRangeSize() != 1 {
h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)", h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
zap.String("address", networkAddr), zap.String("address", networkAddr),
) )
return return
} }
hostAddr := addrs[0] hostAddr := addr.JoinHostPort(0)
if network == "unix" || network == "unixgram" || network == "unixpacket" { if addr.Network == "unix" || addr.Network == "unixgram" || addr.Network == "unixpacket" {
// this will be used as the Host portion of a http.Request URL, and // this will be used as the Host portion of a http.Request URL, and
// paths to socket files would produce an error when creating URL, // paths to socket files would produce an error when creating URL,
// so use a fake Host value instead; unix sockets are usually local // so use a fake Host value instead; unix sockets are usually local
hostAddr = "localhost" hostAddr = "localhost"
} }
err = h.doActiveHealthCheck(DialInfo{Network: network, Address: addrs[0]}, hostAddr, host) err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, host)
if err != nil { if err != nil {
h.HealthChecks.Active.logger.Error("active health check failed", h.HealthChecks.Active.logger.Error("active health check failed",
zap.String("address", networkAddr), zap.String("address", networkAddr),

View File

@ -16,8 +16,7 @@ package reverseproxy
import ( import (
"fmt" "fmt"
"net" "strconv"
"strings"
"sync/atomic" "sync/atomic"
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
@ -193,27 +192,20 @@ func (di DialInfo) String() string {
// the given Replacer. Note that the returned value is not a pointer. // the given Replacer. Note that the returned value is not a pointer.
func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) { func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) {
dial := repl.ReplaceAll(upstream.Dial, "") dial := repl.ReplaceAll(upstream.Dial, "")
netw, addrs, err := caddy.ParseNetworkAddress(dial) addr, err := caddy.ParseNetworkAddress(dial)
if err != nil { if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err) return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
} }
if len(addrs) != 1 { if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
upstream.Dial, dial, len(addrs)) upstream.Dial, dial, numPorts)
}
var dialHost, dialPort string
if !strings.Contains(netw, "unix") {
dialHost, dialPort, err = net.SplitHostPort(addrs[0])
if err != nil {
dialHost = addrs[0] // assume there was no port
}
} }
return DialInfo{ return DialInfo{
Upstream: upstream, Upstream: upstream,
Network: netw, Network: addr.Network,
Address: addrs[0], Address: addr.JoinHostPort(0),
Host: dialHost, Host: addr.Host,
Port: dialPort, Port: strconv.Itoa(int(addr.StartPort)),
}, nil }, nil
} }

View File

@ -242,40 +242,44 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
// listeners in s that use a port which is not otherPort. // listeners in s that use a port which is not otherPort.
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
for _, lnAddr := range s.Listen { for _, lnAddr := range s.Listen {
_, addrs, err := caddy.ParseNetworkAddress(lnAddr) laddrs, err := caddy.ParseNetworkAddress(lnAddr)
if err == nil { if err != nil {
for _, a := range addrs { continue
_, port, err := net.SplitHostPort(a) }
if err == nil && port != strconv.Itoa(otherPort) { if uint(otherPort) > laddrs.EndPort || uint(otherPort) < laddrs.StartPort {
return true return true
}
}
} }
} }
return false return false
} }
// hasListenerAddress returns true if s has a listener
// at the given address fullAddr. Currently, fullAddr
// must represent exactly one socket address (port
// ranges are not supported)
func (s *Server) hasListenerAddress(fullAddr string) bool { func (s *Server) hasListenerAddress(fullAddr string) bool {
netw, addrs, err := caddy.ParseNetworkAddress(fullAddr) laddrs, err := caddy.ParseNetworkAddress(fullAddr)
if err != nil { if err != nil {
return false return false
} }
if len(addrs) != 1 { if laddrs.PortRangeSize() != 1 {
return false return false // TODO: support port ranges
} }
addr := addrs[0]
for _, lnAddr := range s.Listen { for _, lnAddr := range s.Listen {
thisNetw, thisAddrs, err := caddy.ParseNetworkAddress(lnAddr) thisAddrs, err := caddy.ParseNetworkAddress(lnAddr)
if err != nil { if err != nil {
continue continue
} }
if thisNetw != netw { if thisAddrs.Network != laddrs.Network {
continue continue
} }
for _, a := range thisAddrs {
if a == addr { // host must be the same and port must fall within port range
return true if (thisAddrs.Host == laddrs.Host) &&
} (laddrs.StartPort <= thisAddrs.EndPort) &&
(laddrs.StartPort >= thisAddrs.StartPort) {
return true
} }
} }
return false return false