From 93bc1b72e3cd566e6447ad7a1f832474aad5dfcc Mon Sep 17 00:00:00 2001 From: Mohammed Al Sahaf Date: Tue, 12 Nov 2019 01:33:38 +0300 Subject: [PATCH] core: Use port ranges to avoid OOM with bad inputs (#2859) * fix OOM issue caught by fuzzing * use ParsedAddress as the struct name for the result of ParseNetworkAddress * simplify code using the ParsedAddress type * minor cleanups --- admin.go | 16 +-- listeners.go | 101 ++++++++++++----- listeners_fuzz.go | 2 +- listeners_test.go | 105 +++++++++++------- modules/caddyhttp/caddyhttp.go | 33 +++--- .../caddyhttp/reverseproxy/healthchecks.go | 10 +- modules/caddyhttp/reverseproxy/hosts.go | 24 ++-- modules/caddyhttp/server.go | 40 ++++--- 8 files changed, 201 insertions(+), 130 deletions(-) diff --git a/admin.go b/admin.go index 502a968c5..b1ced18e7 100644 --- a/admin.go +++ b/admin.go @@ -48,23 +48,19 @@ type AdminConfig struct { // listenAddr extracts a singular listen address from ac.Listen, // returning the network and the address of the listener. -func (admin AdminConfig) listenAddr() (netw string, addr string, err error) { - var listenAddrs []string +func (admin AdminConfig) listenAddr() (string, string, error) { input := admin.Listen if input == "" { input = DefaultAdminListen } - netw, listenAddrs, err = ParseNetworkAddress(input) + listenAddr, err := ParseNetworkAddress(input) if err != nil { - err = fmt.Errorf("parsing admin listener address: %v", err) - return + return "", "", fmt.Errorf("parsing admin listener address: %v", err) } - if len(listenAddrs) != 1 { - err = fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddrs) - return + if listenAddr.PortRangeSize() != 1 { + return "", "", fmt.Errorf("admin endpoint must have exactly one address; cannot listen on %v", listenAddr) } - addr = listenAddrs[0] - return + return listenAddr.Network, listenAddr.JoinHostPort(0), nil } // newAdminHandler reads admin's config and returns an http.Handler suitable diff --git a/listeners.go b/listeners.go index 4464b7873..37b4c299f 100644 --- a/listeners.go +++ b/listeners.go @@ -257,52 +257,94 @@ type globalListener struct { pc net.PacketConn } -var ( - listeners = make(map[string]*globalListener) - listenersMu sync.Mutex -) +// ParsedAddress contains the individual components +// for a parsed network address of the form accepted +// by ParseNetworkAddress(). Network should be a +// network value accepted by Go's net package. Port +// ranges are given by [StartPort, EndPort]. +type ParsedAddress struct { + Network string + Host string + StartPort uint + EndPort uint +} -// ParseNetworkAddress parses addr, a string of the form "network/host:port" -// (with any part optional) into its component parts. Because a port can -// also be a port range, there may be multiple addresses returned. -func ParseNetworkAddress(addr string) (network string, addrs []string, err error) { +// JoinHostPort is like net.JoinHostPort, but where the port +// is StartPort + offset. +func (l ParsedAddress) JoinHostPort(offset uint) string { + return net.JoinHostPort(l.Host, strconv.Itoa(int(l.StartPort+offset))) +} + +// PortRangeSize returns how many ports are in +// pa's port range. Port ranges are inclusive, +// so the size is the difference of start and +// end ports plus one. +func (pa ParsedAddress) PortRangeSize() uint { + return (pa.EndPort - pa.StartPort) + 1 +} + +// String reconstructs the address string to the form expected +// by ParseNetworkAddress(). +func (pa ParsedAddress) String() string { + port := strconv.FormatUint(uint64(pa.StartPort), 10) + if pa.StartPort != pa.EndPort { + port += "-" + strconv.FormatUint(uint64(pa.EndPort), 10) + } + return JoinNetworkAddress(pa.Network, pa.Host, port) +} + +// ParseNetworkAddress parses addr into its individual +// components. The input string is expected to be of +// the form "network/host:port-range" where any part is +// optional. The default network, if unspecified, is tcp. +// Port ranges are inclusive. +// +// Network addresses are distinct from URLs and do not +// use URL syntax. +func ParseNetworkAddress(addr string) (ParsedAddress, error) { var host, port string - network, host, port, err = SplitNetworkAddress(addr) + network, host, port, err := SplitNetworkAddress(addr) if network == "" { network = "tcp" } if err != nil { - return + return ParsedAddress{}, err } if network == "unix" || network == "unixgram" || network == "unixpacket" { - addrs = []string{host} - return + return ParsedAddress{ + Network: network, + Host: host, + }, nil } ports := strings.SplitN(port, "-", 2) if len(ports) == 1 { ports = append(ports, ports[0]) } - var start, end int - start, err = strconv.Atoi(ports[0]) + var start, end uint64 + start, err = strconv.ParseUint(ports[0], 10, 16) if err != nil { - return + return ParsedAddress{}, fmt.Errorf("invalid start port: %v", err) } - end, err = strconv.Atoi(ports[1]) + end, err = strconv.ParseUint(ports[1], 10, 16) if err != nil { - return + return ParsedAddress{}, fmt.Errorf("invalid end port: %v", err) } if end < start { - err = fmt.Errorf("end port must be greater than start port") - return + return ParsedAddress{}, fmt.Errorf("end port must not be less than start port") } - for p := start; p <= end; p++ { - addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p))) + if (end - start) > maxPortSpan { + return ParsedAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) } - return + return ParsedAddress{ + Network: network, + Host: host, + StartPort: uint(start), + EndPort: uint(end), + }, nil } // SplitNetworkAddress splits a into its network, host, and port components. -// Note that port may be a port range, or omitted for unix sockets. +// Note that port may be a port range (:X-Y), or omitted for unix sockets. func SplitNetworkAddress(a string) (network, host, port string, err error) { if idx := strings.Index(a, "/"); idx >= 0 { network = strings.ToLower(strings.TrimSpace(a[:idx])) @@ -317,9 +359,9 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) { } // JoinNetworkAddress combines network, host, and port into a single -// address string of the form "network/host:port". Port may be a -// port range. For unix sockets, the network should be "unix" and -// the path to the socket should be given in the host argument. +// address string of the form accepted by ParseNetworkAddress(). For unix sockets, the network +// should be "unix" and the path to the socket should be given as the +// host parameter. func JoinNetworkAddress(network, host, port string) string { var a string if network != "" { @@ -332,3 +374,10 @@ func JoinNetworkAddress(network, host, port string) string { } return a } + +var ( + listeners = make(map[string]*globalListener) + listenersMu sync.Mutex +) + +const maxPortSpan = 65535 diff --git a/listeners_fuzz.go b/listeners_fuzz.go index 98465fd2d..826c57e82 100644 --- a/listeners_fuzz.go +++ b/listeners_fuzz.go @@ -18,7 +18,7 @@ package caddy func FuzzParseNetworkAddress(data []byte) int { - _, _, err := ParseNetworkAddress(string(data)) + _, err := ParseNetworkAddress(string(data)) if err != nil { return 0 } diff --git a/listeners_test.go b/listeners_test.go index bdddf3221..076b36558 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -152,74 +152,101 @@ func TestJoinNetworkAddress(t *testing.T) { func TestParseNetworkAddress(t *testing.T) { for i, tc := range []struct { - input string - expectNetwork string - expectAddrs []string - expectErr bool + input string + expectAddr ParsedAddress + expectErr bool }{ { - input: "", - expectNetwork: "tcp", - expectErr: true, + input: "", + expectErr: true, }, { - input: ":", - expectNetwork: "tcp", - expectErr: true, + input: ":", + expectErr: true, }, { - input: ":1234", - expectNetwork: "tcp", - expectAddrs: []string{":1234"}, + input: ":1234", + expectAddr: ParsedAddress{ + Network: "tcp", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, }, { - input: "tcp/:1234", - expectNetwork: "tcp", - expectAddrs: []string{":1234"}, + input: "tcp/:1234", + expectAddr: ParsedAddress{ + Network: "tcp", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, }, { - input: "tcp6/:1234", - expectNetwork: "tcp6", - expectAddrs: []string{":1234"}, + input: "tcp6/:1234", + expectAddr: ParsedAddress{ + Network: "tcp6", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, }, { - input: "tcp4/localhost:1234", - expectNetwork: "tcp4", - expectAddrs: []string{"localhost:1234"}, + input: "tcp4/localhost:1234", + expectAddr: ParsedAddress{ + Network: "tcp4", + Host: "localhost", + StartPort: 1234, + EndPort: 1234, + }, }, { - input: "unix//foo/bar", - expectNetwork: "unix", - expectAddrs: []string{"/foo/bar"}, + input: "unix//foo/bar", + expectAddr: ParsedAddress{ + Network: "unix", + Host: "/foo/bar", + }, }, { - input: "localhost:1234-1234", - expectNetwork: "tcp", - expectAddrs: []string{"localhost:1234"}, + input: "localhost:1234-1234", + expectAddr: ParsedAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 1234, + EndPort: 1234, + }, }, { - input: "localhost:2-1", - expectNetwork: "tcp", - expectErr: true, + input: "localhost:2-1", + expectErr: true, }, { - input: "localhost:0", - expectNetwork: "tcp", - expectAddrs: []string{"localhost:0"}, + input: "localhost:0", + expectAddr: ParsedAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 0, + EndPort: 0, + }, + }, + { + input: "localhost:1-999999999999", + expectErr: true, }, } { - actualNetwork, actualAddrs, err := ParseNetworkAddress(tc.input) + actualAddr, err := ParseNetworkAddress(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } if !tc.expectErr && err != nil { t.Errorf("Test %d: Expected no error but got: %v", i, err) } - if actualNetwork != tc.expectNetwork { - t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork) + + if actualAddr.Network != tc.expectAddr.Network { + t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr) } - if !reflect.DeepEqual(tc.expectAddrs, actualAddrs) { - t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddrs, actualAddrs) + if !reflect.DeepEqual(tc.expectAddr, actualAddr) { + t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr) } } } diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 99a64c3b2..36d81542f 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -135,15 +135,18 @@ func (app *App) Validate() error { lnAddrs := make(map[string]string) for srvName, srv := range app.Servers { for _, addr := range srv.Listen { - netw, expanded, err := caddy.ParseNetworkAddress(addr) + listenAddr, err := caddy.ParseNetworkAddress(addr) if err != nil { return fmt.Errorf("invalid listener address '%s': %v", addr, err) } - for _, a := range expanded { - if sn, ok := lnAddrs[netw+a]; ok { - return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, a, sn) + // check that every address in the port range is unique to this server; + // we do not use <= here because PortRangeSize() adds 1 to EndPort for us + for i := uint(0); i < listenAddr.PortRangeSize(); i++ { + addr := caddy.JoinNetworkAddress(listenAddr.Network, listenAddr.Host, strconv.Itoa(int(listenAddr.StartPort+i))) + if sn, ok := lnAddrs[addr]; ok { + return fmt.Errorf("server %s: listener address repeated: %s (already claimed by server '%s')", srvName, addr, sn) } - lnAddrs[netw+a] = srvName + lnAddrs[addr] = srvName } } } @@ -176,14 +179,15 @@ func (app *App) Start() error { } for _, lnAddr := range srv.Listen { - network, addrs, err := caddy.ParseNetworkAddress(lnAddr) + listenAddr, err := caddy.ParseNetworkAddress(lnAddr) if err != nil { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } - for _, addr := range addrs { - ln, err := caddy.Listen(network, addr) + for i := uint(0); i <= listenAddr.PortRangeSize(); i++ { + hostport := listenAddr.JoinHostPort(i) + ln, err := caddy.Listen(listenAddr.Network, hostport) if err != nil { - return fmt.Errorf("%s: listening on %s: %v", network, addr, err) + return fmt.Errorf("%s: listening on %s: %v", listenAddr.Network, hostport, err) } // enable HTTP/2 by default @@ -194,11 +198,10 @@ func (app *App) Start() error { } // enable TLS - _, port, _ := net.SplitHostPort(addr) - if len(srv.TLSConnPolicies) > 0 && port != strconv.Itoa(app.httpPort()) { + if len(srv.TLSConnPolicies) > 0 && int(i) != app.httpPort() { tlsCfg, err := srv.TLSConnPolicies.TLSConfig(app.ctx) if err != nil { - return fmt.Errorf("%s/%s: making TLS configuration: %v", network, addr, err) + return fmt.Errorf("%s/%s: making TLS configuration: %v", listenAddr.Network, hostport, err) } ln = tls.NewListener(ln, tlsCfg) @@ -206,15 +209,15 @@ func (app *App) Start() error { // TODO: HTTP/3 support is experimental for now if srv.ExperimentalHTTP3 { app.logger.Info("enabling experimental HTTP/3 listener", - zap.String("addr", addr), + zap.String("addr", hostport), ) - h3ln, err := caddy.ListenPacket("udp", addr) + h3ln, err := caddy.ListenPacket("udp", hostport) if err != nil { return fmt.Errorf("getting HTTP/3 UDP listener: %v", err) } h3srv := &http3.Server{ Server: &http.Server{ - Addr: addr, + Addr: hostport, Handler: srv, TLSConfig: tlsCfg, }, diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 56e97bc0f..92b354787 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -102,7 +102,7 @@ func (h *Handler) doActiveHealthChecksForAllHosts() { host := value.(Host) go func(networkAddr string, host Host) { - network, addrs, err := caddy.ParseNetworkAddress(networkAddr) + addr, err := caddy.ParseNetworkAddress(networkAddr) if err != nil { h.HealthChecks.Active.logger.Error("bad network address", zap.String("address", networkAddr), @@ -110,20 +110,20 @@ func (h *Handler) doActiveHealthChecksForAllHosts() { ) return } - if len(addrs) != 1 { + if addr.PortRangeSize() != 1 { h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)", zap.String("address", networkAddr), ) return } - hostAddr := addrs[0] - if network == "unix" || network == "unixgram" || network == "unixpacket" { + hostAddr := addr.JoinHostPort(0) + if addr.Network == "unix" || addr.Network == "unixgram" || addr.Network == "unixpacket" { // this will be used as the Host portion of a http.Request URL, and // paths to socket files would produce an error when creating URL, // so use a fake Host value instead; unix sockets are usually local hostAddr = "localhost" } - err = h.doActiveHealthCheck(DialInfo{Network: network, Address: addrs[0]}, hostAddr, host) + err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, host) if err != nil { h.HealthChecks.Active.logger.Error("active health check failed", zap.String("address", networkAddr), diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index a16bed0a0..8bad7c20a 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -16,8 +16,7 @@ package reverseproxy import ( "fmt" - "net" - "strings" + "strconv" "sync/atomic" "github.com/caddyserver/caddy/v2" @@ -193,27 +192,20 @@ func (di DialInfo) String() string { // the given Replacer. Note that the returned value is not a pointer. func fillDialInfo(upstream *Upstream, repl caddy.Replacer) (DialInfo, error) { dial := repl.ReplaceAll(upstream.Dial, "") - netw, addrs, err := caddy.ParseNetworkAddress(dial) + addr, err := caddy.ParseNetworkAddress(dial) if err != nil { return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err) } - if len(addrs) != 1 { + if numPorts := addr.PortRangeSize(); numPorts != 1 { return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d", - upstream.Dial, dial, len(addrs)) - } - var dialHost, dialPort string - if !strings.Contains(netw, "unix") { - dialHost, dialPort, err = net.SplitHostPort(addrs[0]) - if err != nil { - dialHost = addrs[0] // assume there was no port - } + upstream.Dial, dial, numPorts) } return DialInfo{ Upstream: upstream, - Network: netw, - Address: addrs[0], - Host: dialHost, - Port: dialPort, + Network: addr.Network, + Address: addr.JoinHostPort(0), + Host: addr.Host, + Port: strconv.Itoa(int(addr.StartPort)), }, nil } diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index e119c2d34..17860ed9d 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -242,40 +242,44 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next // listeners in s that use a port which is not otherPort. func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { for _, lnAddr := range s.Listen { - _, addrs, err := caddy.ParseNetworkAddress(lnAddr) - if err == nil { - for _, a := range addrs { - _, port, err := net.SplitHostPort(a) - if err == nil && port != strconv.Itoa(otherPort) { - return true - } - } + laddrs, err := caddy.ParseNetworkAddress(lnAddr) + if err != nil { + continue + } + if uint(otherPort) > laddrs.EndPort || uint(otherPort) < laddrs.StartPort { + return true } } return false } +// hasListenerAddress returns true if s has a listener +// at the given address fullAddr. Currently, fullAddr +// must represent exactly one socket address (port +// ranges are not supported) func (s *Server) hasListenerAddress(fullAddr string) bool { - netw, addrs, err := caddy.ParseNetworkAddress(fullAddr) + laddrs, err := caddy.ParseNetworkAddress(fullAddr) if err != nil { return false } - if len(addrs) != 1 { - return false + if laddrs.PortRangeSize() != 1 { + return false // TODO: support port ranges } - addr := addrs[0] + for _, lnAddr := range s.Listen { - thisNetw, thisAddrs, err := caddy.ParseNetworkAddress(lnAddr) + thisAddrs, err := caddy.ParseNetworkAddress(lnAddr) if err != nil { continue } - if thisNetw != netw { + if thisAddrs.Network != laddrs.Network { continue } - for _, a := range thisAddrs { - if a == addr { - return true - } + + // host must be the same and port must fall within port range + if (thisAddrs.Host == laddrs.Host) && + (laddrs.StartPort <= thisAddrs.EndPort) && + (laddrs.StartPort >= thisAddrs.StartPort) { + return true } } return false