httpcaddyfile: Detect ambiguous site definitions (fix #4635)

Previously, our "duplicate key in server block" logic was flawed because
it did not account for the site's bind address. We defer this check to
when the listener addresses have been assigned, but before we commit
a server block to its listener.

Also refined how network address parsing and joining works, which was
necessary for a less convoluted fix.
This commit is contained in:
Matthew Holt 2022-07-25 17:28:20 -06:00
parent 0bebea0d4c
commit 1e18afb5c8
No known key found for this signature in database
GPG Key ID: 2A349DD577D586A5
5 changed files with 99 additions and 45 deletions

View File

@ -183,6 +183,8 @@ func (st *ServerType) consolidateAddrMappings(addrToServerBlocks map[string][]se
return sbaddrs return sbaddrs
} }
// listenerAddrsForServerBlockKey essentially converts the Caddyfile
// site addresses to Caddy listener addresses for each server block.
func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key string, func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key string,
options map[string]interface{}) ([]string, error) { options map[string]interface{}) ([]string, error) {
addr, err := ParseAddress(key) addr, err := ParseAddress(key)
@ -232,12 +234,14 @@ func (st *ServerType) listenerAddrsForServerBlockKey(sblock serverBlock, key str
// use a map to prevent duplication // use a map to prevent duplication
listeners := make(map[string]struct{}) listeners := make(map[string]struct{})
for _, host := range lnHosts { for _, host := range lnHosts {
addr, err := caddy.ParseNetworkAddress(host) // host can have network + host (e.g. "tcp6/localhost") but
if err == nil && addr.IsUnixNetwork() { // will/should not have port information because this usually
listeners[host] = struct{}{} // comes from the bind directive, so we append the port
} else { addr, err := caddy.ParseNetworkAddress(host + ":" + lnPort)
listeners[host+":"+lnPort] = struct{}{} if err != nil {
return nil, fmt.Errorf("parsing network address: %v", err)
} }
listeners[addr.String()] = struct{}{}
} }
// now turn map into list // now turn map into list

View File

@ -58,22 +58,13 @@ func (st ServerType) Setup(inputServerBlocks []caddyfile.ServerBlock,
gc := counter{new(int)} gc := counter{new(int)}
state := make(map[string]interface{}) state := make(map[string]interface{})
// load all the server blocks and associate them with a "pile" // load all the server blocks and associate them with a "pile" of config values
// of config values; also prohibit duplicate keys because they
// can make a config confusing if more than one server block is
// chosen to handle a request - we actually will make each
// server block's route terminal so that only one will run
sbKeys := make(map[string]struct{})
originalServerBlocks := make([]serverBlock, 0, len(inputServerBlocks)) originalServerBlocks := make([]serverBlock, 0, len(inputServerBlocks))
for i, sblock := range inputServerBlocks { for _, sblock := range inputServerBlocks {
for j, k := range sblock.Keys { for j, k := range sblock.Keys {
if j == 0 && strings.HasPrefix(k, "@") { if j == 0 && strings.HasPrefix(k, "@") {
return nil, warnings, fmt.Errorf("cannot define a matcher outside of a site block: '%s'", k) return nil, warnings, fmt.Errorf("cannot define a matcher outside of a site block: '%s'", k)
} }
if _, ok := sbKeys[k]; ok {
return nil, warnings, fmt.Errorf("duplicate site address not allowed: '%s' in %v (site block %d, key %d)", k, sblock.Keys, i, j)
}
sbKeys[k] = struct{}{}
} }
originalServerBlocks = append(originalServerBlocks, serverBlock{ originalServerBlocks = append(originalServerBlocks, serverBlock{
block: sblock, block: sblock,
@ -420,6 +411,23 @@ func (st *ServerType) serversFromPairings(
} }
for i, p := range pairings { for i, p := range pairings {
// detect ambiguous site definitions: server blocks which
// have the same host bound to the same interface (listener
// address), otherwise their routes will improperly be added
// to the same server (see issue #4635)
for j, sblock1 := range p.serverBlocks {
for _, key := range sblock1.block.Keys {
for k, sblock2 := range p.serverBlocks {
if k == j {
continue
}
if sliceContains(sblock2.block.Keys, key) {
return nil, fmt.Errorf("ambiguous site definition: %s", key)
}
}
}
}
srv := &caddyhttp.Server{ srv := &caddyhttp.Server{
Listen: p.addresses, Listen: p.addresses,
} }

View File

@ -68,7 +68,7 @@ func TestDuplicateHosts(t *testing.T) {
} }
`, `,
"caddyfile", "caddyfile",
"duplicate site address not allowed") "ambiguous site definition")
} }
func TestReadCookie(t *testing.T) { func TestReadCookie(t *testing.T) {

View File

@ -391,10 +391,13 @@ func (na NetworkAddress) port() string {
return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort) return fmt.Sprintf("%d-%d", na.StartPort, na.EndPort)
} }
// String reconstructs the address string to the form expected // String reconstructs the address string for human display.
// by ParseNetworkAddress(). If the address is a unix socket, // The output can be parsed by ParseNetworkAddress(). If the
// any non-zero port will be dropped. // address is a unix socket, any non-zero port will be dropped.
func (na NetworkAddress) String() string { func (na NetworkAddress) String() string {
if na.Network == "tcp" && (na.Host != "" || na.port() != "") {
na.Network = "" // omit default network value for brevity
}
return JoinNetworkAddress(na.Network, na.Host, na.port()) return JoinNetworkAddress(na.Network, na.Host, na.port())
} }
@ -427,36 +430,38 @@ func isListenBindAddressAlreadyInUseError(err error) bool {
func ParseNetworkAddress(addr string) (NetworkAddress, error) { func ParseNetworkAddress(addr string) (NetworkAddress, error) {
var host, port string var host, port string
network, host, port, err := SplitNetworkAddress(addr) network, host, port, err := SplitNetworkAddress(addr)
if network == "" {
network = "tcp"
}
if err != nil { if err != nil {
return NetworkAddress{}, err return NetworkAddress{}, err
} }
if network == "" {
network = "tcp"
}
if isUnixNetwork(network) { if isUnixNetwork(network) {
return NetworkAddress{ return NetworkAddress{
Network: network, Network: network,
Host: host, Host: host,
}, nil }, nil
} }
ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 {
ports = append(ports, ports[0])
}
var start, end uint64 var start, end uint64
start, err = strconv.ParseUint(ports[0], 10, 16) if port != "" {
if err != nil { ports := strings.SplitN(port, "-", 2)
return NetworkAddress{}, fmt.Errorf("invalid start port: %v", err) if len(ports) == 1 {
} ports = append(ports, ports[0])
end, err = strconv.ParseUint(ports[1], 10, 16) }
if err != nil { start, err = strconv.ParseUint(ports[0], 10, 16)
return NetworkAddress{}, fmt.Errorf("invalid end port: %v", err) if err != nil {
} return NetworkAddress{}, fmt.Errorf("invalid start port: %v", err)
if end < start { }
return NetworkAddress{}, fmt.Errorf("end port must not be less than start port") end, err = strconv.ParseUint(ports[1], 10, 16)
} if err != nil {
if (end - start) > maxPortSpan { return NetworkAddress{}, fmt.Errorf("invalid end port: %v", err)
return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan) }
if end < start {
return NetworkAddress{}, fmt.Errorf("end port must not be less than start port")
}
if (end - start) > maxPortSpan {
return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
}
} }
return NetworkAddress{ return NetworkAddress{
Network: network, Network: network,
@ -478,6 +483,19 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
return return
} }
host, port, err = net.SplitHostPort(a) host, port, err = net.SplitHostPort(a)
if err == nil || a == "" {
return
}
// in general, if there was an error, it was likely "missing port",
// so try adding a bogus port to take advantage of standard library's
// robust parser, then strip the artificial port before returning
// (don't overwrite original error though; might still be relevant)
var err2 error
host, port, err2 = net.SplitHostPort(a + ":0")
if err2 == nil {
err = nil
port = ""
}
return return
} }

View File

@ -32,9 +32,24 @@ func TestSplitNetworkAddress(t *testing.T) {
expectErr: true, expectErr: true,
}, },
{ {
input: "foo", input: "foo",
expectHost: "foo",
},
{
input: ":", // empty host & empty port
},
{
input: "::",
expectErr: true, expectErr: true,
}, },
{
input: "[::]",
expectHost: "::",
},
{
input: ":1234",
expectPort: "1234",
},
{ {
input: "foo:1234", input: "foo:1234",
expectHost: "foo", expectHost: "foo",
@ -80,10 +95,10 @@ func TestSplitNetworkAddress(t *testing.T) {
} { } {
actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input) actualNetwork, actualHost, actualPort, err := SplitNetworkAddress(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got %v", i, err)
} }
if !tc.expectErr && err != nil { if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got: %v", i, err) t.Errorf("Test %d: Expected no error but got %v", i, err)
} }
if actualNetwork != tc.expectNetwork { if actualNetwork != tc.expectNetwork {
t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork) t.Errorf("Test %d: Expected network '%s' but got '%s'", i, tc.expectNetwork, actualNetwork)
@ -169,8 +184,17 @@ func TestParseNetworkAddress(t *testing.T) {
expectErr: true, expectErr: true,
}, },
{ {
input: ":", input: ":",
expectErr: true, expectAddr: NetworkAddress{
Network: "tcp",
},
},
{
input: "[::]",
expectAddr: NetworkAddress{
Network: "tcp",
Host: "::",
},
}, },
{ {
input: ":1234", input: ":1234",