core: Support Windows absolute paths for UDS proxy upstreams (#5114)

* added some tests for parseUpstreamDialAddress

Test 4 fails because it produces "[[::1]]:80" instead of "[::1]:80"

* support absolute windows path in unix reverse proxy address

* make IsUnixNetwork public, support +h2c and reuse it
* add new tests
This commit is contained in:
Steffen Brüheim 2023-02-08 18:05:09 +01:00 committed by GitHub
parent c77a6bea66
commit 536c28d4dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 255 additions and 18 deletions

View File

@ -34,7 +34,7 @@ import (
// reuseUnixSocket copies and reuses the unix domain socket (UDS) if we already // reuseUnixSocket copies and reuses the unix domain socket (UDS) if we already
// have it open; if not, unlink it so we can have it. No-op if not a unix network. // have it open; if not, unlink it so we can have it. No-op if not a unix network.
func reuseUnixSocket(network, addr string) (any, error) { func reuseUnixSocket(network, addr string) (any, error) {
if !isUnixNetwork(network) { if !IsUnixNetwork(network) {
return nil, nil return nil, nil
} }
@ -103,7 +103,7 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string,
// reusePort sets SO_REUSEPORT. Ineffective for unix sockets. // reusePort sets SO_REUSEPORT. Ineffective for unix sockets.
func reusePort(network, address string, conn syscall.RawConn) error { func reusePort(network, address string, conn syscall.RawConn) error {
if isUnixNetwork(network) { if IsUnixNetwork(network) {
return nil return nil
} }
return conn.Control(func(descriptor uintptr) { return conn.Control(func(descriptor uintptr) {

View File

@ -205,7 +205,7 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net
// IsUnixNetwork returns true if na.Network is // IsUnixNetwork returns true if na.Network is
// unix, unixgram, or unixpacket. // unix, unixgram, or unixpacket.
func (na NetworkAddress) IsUnixNetwork() bool { func (na NetworkAddress) IsUnixNetwork() bool {
return isUnixNetwork(na.Network) return IsUnixNetwork(na.Network)
} }
// JoinHostPort is like net.JoinHostPort, but where the port // JoinHostPort is like net.JoinHostPort, but where the port
@ -289,8 +289,9 @@ func (na NetworkAddress) String() string {
return JoinNetworkAddress(na.Network, na.Host, na.port()) return JoinNetworkAddress(na.Network, na.Host, na.port())
} }
func isUnixNetwork(netw string) bool { // IsUnixNetwork returns true if the netw is a unix network.
return netw == "unix" || netw == "unixgram" || netw == "unixpacket" func IsUnixNetwork(netw string) bool {
return strings.HasPrefix(netw, "unix")
} }
// ParseNetworkAddress parses addr into its individual // ParseNetworkAddress parses addr into its individual
@ -310,7 +311,7 @@ func ParseNetworkAddress(addr string) (NetworkAddress, error) {
if network == "" { if network == "" {
network = "tcp" network = "tcp"
} }
if isUnixNetwork(network) { if IsUnixNetwork(network) {
return NetworkAddress{ return NetworkAddress{
Network: network, Network: network,
Host: host, Host: host,
@ -353,7 +354,7 @@ func SplitNetworkAddress(a string) (network, host, port string, err error) {
network = strings.ToLower(strings.TrimSpace(beforeSlash)) network = strings.ToLower(strings.TrimSpace(beforeSlash))
a = afterSlash a = afterSlash
} }
if isUnixNetwork(network) { if IsUnixNetwork(network) {
host = a host = a
return return
} }
@ -384,7 +385,7 @@ func JoinNetworkAddress(network, host, port string) string {
if network != "" { if network != "" {
a = network + "/" a = network + "/"
} }
if (host != "" && port == "") || isUnixNetwork(network) { if (host != "" && port == "") || IsUnixNetwork(network) {
a += host a += host
} else if port != "" { } else if port != "" {
a += net.JoinHostPort(host, port) a += net.JoinHostPort(host, port)

View File

@ -27,9 +27,6 @@ import (
// the dial address, including support for a scheme in front // the dial address, including support for a scheme in front
// as a shortcut for the port number, and a network type, // as a shortcut for the port number, and a network type,
// for example 'unix' to dial a unix socket. // for example 'unix' to dial a unix socket.
//
// TODO: the logic in this function is kind of sensitive, we
// need to write tests before making any more changes to it
func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) { func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {
var network, scheme, host, port string var network, scheme, host, port string
@ -79,19 +76,14 @@ func parseUpstreamDialAddress(upstreamAddr string) (string, string, error) {
scheme, host, port = toURL.Scheme, toURL.Hostname(), toURL.Port() scheme, host, port = toURL.Scheme, toURL.Hostname(), toURL.Port()
} else { } else {
// extract network manually, since caddy.ParseNetworkAddress() will always add one
if beforeSlash, afterSlash, slashFound := strings.Cut(upstreamAddr, "/"); slashFound {
network = strings.ToLower(strings.TrimSpace(beforeSlash))
upstreamAddr = afterSlash
}
var err error var err error
host, port, err = net.SplitHostPort(upstreamAddr) network, host, port, err = caddy.SplitNetworkAddress(upstreamAddr)
if err != nil { if err != nil {
host = upstreamAddr host = upstreamAddr
} }
// we can assume a port if only a hostname is specified, but use of a // we can assume a port if only a hostname is specified, but use of a
// placeholder without a port likely means a port will be filled in // placeholder without a port likely means a port will be filled in
if port == "" && !strings.Contains(host, "{") { if port == "" && !strings.Contains(host, "{") && !caddy.IsUnixNetwork(network) {
port = "80" port = "80"
} }
} }

View File

@ -0,0 +1,244 @@
// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package reverseproxy
import "testing"
func TestParseUpstreamDialAddress(t *testing.T) {
for i, tc := range []struct {
input string
expectHostPort string
expectScheme string
expectErr bool
}{
{
input: "foo",
expectHostPort: "foo:80",
},
{
input: "foo:1234",
expectHostPort: "foo:1234",
},
{
input: "127.0.0.1",
expectHostPort: "127.0.0.1:80",
},
{
input: "127.0.0.1:1234",
expectHostPort: "127.0.0.1:1234",
},
{
input: "[::1]",
expectHostPort: "[::1]:80",
},
{
input: "[::1]:1234",
expectHostPort: "[::1]:1234",
},
{
input: "{foo}",
expectHostPort: "{foo}",
},
{
input: "{foo}:80",
expectHostPort: "{foo}:80",
},
{
input: "{foo}:{bar}",
expectHostPort: "{foo}:{bar}",
},
{
input: "http://foo",
expectHostPort: "foo:80",
expectScheme: "http",
},
{
input: "http://foo:1234",
expectHostPort: "foo:1234",
expectScheme: "http",
},
{
input: "http://127.0.0.1",
expectHostPort: "127.0.0.1:80",
expectScheme: "http",
},
{
input: "http://127.0.0.1:1234",
expectHostPort: "127.0.0.1:1234",
expectScheme: "http",
},
{
input: "http://[::1]",
expectHostPort: "[::1]:80",
expectScheme: "http",
},
{
input: "http://[::1]:80",
expectHostPort: "[::1]:80",
expectScheme: "http",
},
{
input: "https://foo",
expectHostPort: "foo:443",
expectScheme: "https",
},
{
input: "https://foo:1234",
expectHostPort: "foo:1234",
expectScheme: "https",
},
{
input: "https://127.0.0.1",
expectHostPort: "127.0.0.1:443",
expectScheme: "https",
},
{
input: "https://127.0.0.1:1234",
expectHostPort: "127.0.0.1:1234",
expectScheme: "https",
},
{
input: "https://[::1]",
expectHostPort: "[::1]:443",
expectScheme: "https",
},
{
input: "https://[::1]:1234",
expectHostPort: "[::1]:1234",
expectScheme: "https",
},
{
input: "h2c://foo",
expectHostPort: "foo:80",
expectScheme: "h2c",
},
{
input: "h2c://foo:1234",
expectHostPort: "foo:1234",
expectScheme: "h2c",
},
{
input: "h2c://127.0.0.1",
expectHostPort: "127.0.0.1:80",
expectScheme: "h2c",
},
{
input: "h2c://127.0.0.1:1234",
expectHostPort: "127.0.0.1:1234",
expectScheme: "h2c",
},
{
input: "h2c://[::1]",
expectHostPort: "[::1]:80",
expectScheme: "h2c",
},
{
input: "h2c://[::1]:1234",
expectHostPort: "[::1]:1234",
expectScheme: "h2c",
},
{
input: "unix//var/php.sock",
expectHostPort: "unix//var/php.sock",
},
{
input: "unix+h2c//var/grpc.sock",
expectHostPort: "unix//var/grpc.sock",
expectScheme: "h2c",
},
{
input: "unix/{foo}",
expectHostPort: "unix/{foo}",
},
{
input: "unix+h2c/{foo}",
expectHostPort: "unix/{foo}",
expectScheme: "h2c",
},
{
input: "unix//foo/{foo}/bar",
expectHostPort: "unix//foo/{foo}/bar",
},
{
input: "unix+h2c//foo/{foo}/bar",
expectHostPort: "unix//foo/{foo}/bar",
expectScheme: "h2c",
},
{
input: "http://{foo}",
expectErr: true,
},
{
input: "http:// :80",
expectErr: true,
},
{
input: "http://localhost/path",
expectErr: true,
},
{
input: "http://localhost?key=value",
expectErr: true,
},
{
input: "http://localhost#fragment",
expectErr: true,
},
{
input: "http://foo:443",
expectErr: true,
},
{
input: "https://foo:80",
expectErr: true,
},
{
input: "h2c://foo:443",
expectErr: true,
},
{
input: `unix/c:\absolute\path`,
expectHostPort: `unix/c:\absolute\path`,
},
{
input: `unix+h2c/c:\absolute\path`,
expectHostPort: `unix/c:\absolute\path`,
expectScheme: "h2c",
},
{
input: "unix/c:/absolute/path",
expectHostPort: "unix/c:/absolute/path",
},
{
input: "unix+h2c/c:/absolute/path",
expectHostPort: "unix/c:/absolute/path",
expectScheme: "h2c",
},
} {
actualHostPort, actualScheme, err := parseUpstreamDialAddress(tc.input)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got %v", i, err)
}
if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got %v", i, err)
}
if actualHostPort != tc.expectHostPort {
t.Errorf("Test %d: Expected host and port '%s' but got '%s'", i, tc.expectHostPort, actualHostPort)
}
if actualScheme != tc.expectScheme {
t.Errorf("Test %d: Expected scheme '%s' but got '%s'", i, tc.expectScheme, actualScheme)
}
}
}