diff --git a/listeners.go b/listeners.go index 7c5401a65..14a5c4991 100644 --- a/listeners.go +++ b/listeners.go @@ -303,13 +303,19 @@ func IsUnixNetwork(netw string) bool { // Network addresses are distinct from URLs and do not // use URL syntax. func ParseNetworkAddress(addr string) (NetworkAddress, error) { + return ParseNetworkAddressWithDefaults(addr, "tcp", 0) +} + +// ParseNetworkAddressWithDefaults is like ParseNetworkAddress but allows +// the default network and port to be specified. +func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort uint) (NetworkAddress, error) { var host, port string network, host, port, err := SplitNetworkAddress(addr) if err != nil { return NetworkAddress{}, err } if network == "" { - network = "tcp" + network = defaultNetwork } if IsUnixNetwork(network) { return NetworkAddress{ @@ -318,7 +324,10 @@ func ParseNetworkAddress(addr string) (NetworkAddress, error) { }, nil } var start, end uint64 - if port != "" { + if port == "" { + start = uint64(defaultPort) + end = uint64(defaultPort) + } else { before, after, found := strings.Cut(port, "-") if !found { after = before diff --git a/listeners_test.go b/listeners_test.go index c5aa5273a..5508a9f08 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -175,47 +175,57 @@ func TestJoinNetworkAddress(t *testing.T) { func TestParseNetworkAddress(t *testing.T) { for i, tc := range []struct { - input string - expectAddr NetworkAddress - expectErr bool + input string + defaultNetwork string + defaultPort uint + expectAddr NetworkAddress + expectErr bool }{ { input: "", expectErr: true, }, { - input: ":", + input: ":", + defaultNetwork: "udp", expectAddr: NetworkAddress{ - Network: "tcp", + Network: "udp", }, }, { - input: "[::]", + input: "[::]", + defaultNetwork: "udp", + defaultPort: 53, expectAddr: NetworkAddress{ - Network: "tcp", - Host: "::", + Network: "udp", + Host: "::", + StartPort: 53, + EndPort: 53, }, }, { - input: ":1234", + input: ":1234", + defaultNetwork: "udp", expectAddr: NetworkAddress{ - Network: "tcp", + Network: "udp", Host: "", StartPort: 1234, EndPort: 1234, }, }, { - input: "tcp/:1234", + input: "udp/:1234", + defaultNetwork: "udp", expectAddr: NetworkAddress{ - Network: "tcp", + Network: "udp", Host: "", StartPort: 1234, EndPort: 1234, }, }, { - input: "tcp6/:1234", + input: "tcp6/:1234", + defaultNetwork: "tcp", expectAddr: NetworkAddress{ Network: "tcp6", Host: "", @@ -224,7 +234,8 @@ func TestParseNetworkAddress(t *testing.T) { }, }, { - input: "tcp4/localhost:1234", + input: "tcp4/localhost:1234", + defaultNetwork: "tcp", expectAddr: NetworkAddress{ Network: "tcp4", Host: "localhost", @@ -233,14 +244,16 @@ func TestParseNetworkAddress(t *testing.T) { }, }, { - input: "unix//foo/bar", + input: "unix//foo/bar", + defaultNetwork: "tcp", expectAddr: NetworkAddress{ Network: "unix", Host: "/foo/bar", }, }, { - input: "localhost:1234-1234", + input: "localhost:1234-1234", + defaultNetwork: "tcp", expectAddr: NetworkAddress{ Network: "tcp", Host: "localhost", @@ -249,11 +262,13 @@ func TestParseNetworkAddress(t *testing.T) { }, }, { - input: "localhost:2-1", - expectErr: true, + input: "localhost:2-1", + defaultNetwork: "tcp", + expectErr: true, }, { - input: "localhost:0", + input: "localhost:0", + defaultNetwork: "tcp", expectAddr: NetworkAddress{ Network: "tcp", Host: "localhost", @@ -262,11 +277,138 @@ func TestParseNetworkAddress(t *testing.T) { }, }, { - input: "localhost:1-999999999999", - expectErr: true, + input: "localhost:1-999999999999", + defaultNetwork: "tcp", + expectErr: true, }, } { - actualAddr, err := ParseNetworkAddress(tc.input) + actualAddr, err := ParseNetworkAddressWithDefaults(tc.input, tc.defaultNetwork, tc.defaultPort) + if tc.expectErr && err == nil { + t.Errorf("Test %d: Expected error but got: %v", i, err) + } + if !tc.expectErr && err != nil { + t.Errorf("Test %d: Expected no error but got: %v", i, err) + } + + if actualAddr.Network != tc.expectAddr.Network { + t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr) + } + if !reflect.DeepEqual(tc.expectAddr, actualAddr) { + t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr) + } + } +} + +func TestParseNetworkAddressWithDefaults(t *testing.T) { + for i, tc := range []struct { + input string + defaultNetwork string + defaultPort uint + expectAddr NetworkAddress + expectErr bool + }{ + { + input: "", + expectErr: true, + }, + { + input: ":", + defaultNetwork: "udp", + expectAddr: NetworkAddress{ + Network: "udp", + }, + }, + { + input: "[::]", + defaultNetwork: "udp", + defaultPort: 53, + expectAddr: NetworkAddress{ + Network: "udp", + Host: "::", + StartPort: 53, + EndPort: 53, + }, + }, + { + input: ":1234", + defaultNetwork: "udp", + expectAddr: NetworkAddress{ + Network: "udp", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, + }, + { + input: "udp/:1234", + defaultNetwork: "udp", + expectAddr: NetworkAddress{ + Network: "udp", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, + }, + { + input: "tcp6/:1234", + defaultNetwork: "tcp", + expectAddr: NetworkAddress{ + Network: "tcp6", + Host: "", + StartPort: 1234, + EndPort: 1234, + }, + }, + { + input: "tcp4/localhost:1234", + defaultNetwork: "tcp", + expectAddr: NetworkAddress{ + Network: "tcp4", + Host: "localhost", + StartPort: 1234, + EndPort: 1234, + }, + }, + { + input: "unix//foo/bar", + defaultNetwork: "tcp", + expectAddr: NetworkAddress{ + Network: "unix", + Host: "/foo/bar", + }, + }, + { + input: "localhost:1234-1234", + defaultNetwork: "tcp", + expectAddr: NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 1234, + EndPort: 1234, + }, + }, + { + input: "localhost:2-1", + defaultNetwork: "tcp", + expectErr: true, + }, + { + input: "localhost:0", + defaultNetwork: "tcp", + expectAddr: NetworkAddress{ + Network: "tcp", + Host: "localhost", + StartPort: 0, + EndPort: 0, + }, + }, + { + input: "localhost:1-999999999999", + defaultNetwork: "tcp", + expectErr: true, + }, + } { + actualAddr, err := ParseNetworkAddressWithDefaults(tc.input, tc.defaultNetwork, tc.defaultPort) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go index 30bd7b575..42fedb651 100644 --- a/modules/caddyhttp/reverseproxy/upstreams.go +++ b/modules/caddyhttp/reverseproxy/upstreams.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "strconv" - "strings" "sync" "time" @@ -471,16 +470,9 @@ type UpstreamResolver struct { // and ensures they're ready to be used. func (u *UpstreamResolver) ParseAddresses() error { for _, v := range u.Addresses { - addr, err := caddy.ParseNetworkAddress(v) + addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53) if err != nil { - // If a port wasn't specified for the resolver, - // try defaulting to 53 and parse again - if strings.Contains(err.Error(), "missing port in address") { - addr, err = caddy.ParseNetworkAddress(v + ":53") - } - if err != nil { - return err - } + return err } if addr.PortRangeSize() != 1 { return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) diff --git a/modules/caddypki/acmeserver/acmeserver.go b/modules/caddypki/acmeserver/acmeserver.go index 6ecdfdc66..0f739ec31 100644 --- a/modules/caddypki/acmeserver/acmeserver.go +++ b/modules/caddypki/acmeserver/acmeserver.go @@ -15,7 +15,10 @@ package acmeserver import ( + "context" "fmt" + weakrand "math/rand" + "net" "net/http" "os" "path/filepath" @@ -28,7 +31,7 @@ import ( "github.com/caddyserver/caddy/v2/modules/caddypki" "github.com/go-chi/chi" "github.com/smallstep/certificates/acme" - acmeAPI "github.com/smallstep/certificates/acme/api" + "github.com/smallstep/certificates/acme/api" acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" @@ -76,8 +79,26 @@ type Handler struct { // changed or removed in the future. SignWithRoot bool `json:"sign_with_root,omitempty"` + // The addresses of DNS resolvers to use when looking up + // the TXT records for solving DNS challenges. + // It accepts [network addresses](/docs/conventions#network-addresses) + // with port range of only 1. If the host is an IP address, + // it will be dialed directly to resolve the upstream server. + // If the host is not an IP address, the addresses are resolved + // using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution) + // of the Go standard library. If the array contains more + // than 1 resolver address, one is chosen at random. + Resolvers []string `json:"resolvers,omitempty"` + + logger *zap.Logger + resolvers []caddy.NetworkAddress + ctx caddy.Context + + acmeDB acme.DB + acmeAuth *authority.Authority + acmeClient acme.Client + acmeLinker acme.Linker acmeEndpoints http.Handler - logger *zap.Logger } // CaddyModule returns the Caddy module information. @@ -90,7 +111,9 @@ func (Handler) CaddyModule() caddy.ModuleInfo { // Provision sets up the ACME server handler. func (ash *Handler) Provision(ctx caddy.Context) error { + ash.ctx = ctx ash.logger = ctx.Logger() + // set some defaults if ash.CA == "" { ash.CA = caddypki.DefaultCAID @@ -142,31 +165,30 @@ func (ash *Handler) Provision(ctx caddy.Context) error { DB: database, } - auth, err := ca.NewAuthority(authorityConfig) + ash.acmeAuth, err = ca.NewAuthority(authorityConfig) if err != nil { return err } - var acmeDB acme.DB - if authorityConfig.DB != nil { - acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) - if err != nil { - return fmt.Errorf("configuring ACME DB: %v", err) - } + ash.acmeDB, err = acmeNoSQL.New(ash.acmeAuth.GetDatabase().(nosql.DB)) + if err != nil { + return fmt.Errorf("configuring ACME DB: %v", err) } - // create the router for the ACME endpoints - acmeRouterHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ - CA: auth, - DB: acmeDB, // stores all the server state - DNS: ash.Host, // used for directory links - Prefix: strings.Trim(ash.PathPrefix, "/"), // used for directory links - }) + ash.acmeClient, err = ash.makeClient() + if err != nil { + return err + } + + ash.acmeLinker = acme.NewLinker( + ash.Host, + strings.Trim(ash.PathPrefix, "/"), + ) // extract its http.Handler so we can use it directly r := chi.NewRouter() r.Route(ash.PathPrefix, func(r chi.Router) { - acmeRouterHandler.Route(r) + api.Route(r) }) ash.acmeEndpoints = r @@ -175,6 +197,16 @@ func (ash *Handler) Provision(ctx caddy.Context) error { func (ash Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { if strings.HasPrefix(r.URL.Path, ash.PathPrefix) { + acmeCtx := acme.NewContext( + r.Context(), + ash.acmeDB, + ash.acmeClient, + ash.acmeLinker, + nil, + ) + acmeCtx = authority.NewContext(acmeCtx, ash.acmeAuth) + r = r.WithContext(acmeCtx) + ash.acmeEndpoints.ServeHTTP(w, r) return nil } @@ -227,6 +259,55 @@ func (ash Handler) openDatabase() (*db.AuthDB, error) { return database.(databaseCloser).DB, err } +// makeClient creates an ACME client which will use a custom +// resolver instead of net.DefaultResolver. +func (ash Handler) makeClient() (acme.Client, error) { + for _, v := range ash.Resolvers { + addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53) + if err != nil { + return nil, err + } + if addr.PortRangeSize() != 1 { + return nil, fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) + } + ash.resolvers = append(ash.resolvers, addr) + } + + var resolver *net.Resolver + if len(ash.resolvers) != 0 { + dialer := &net.Dialer{ + Timeout: 2 * time.Second, + } + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + //nolint:gosec + addr := ash.resolvers[weakrand.Intn(len(ash.resolvers))] + return dialer.DialContext(ctx, addr.Network, addr.JoinHostPort(0)) + }, + } + } else { + resolver = net.DefaultResolver + } + + return resolverClient{ + Client: acme.NewClient(), + resolver: resolver, + ctx: ash.ctx, + }, nil +} + +type resolverClient struct { + acme.Client + + resolver *net.Resolver + ctx context.Context +} + +func (c resolverClient) LookupTxt(name string) ([]string, error) { + return c.resolver.LookupTXT(c.ctx, name) +} + const defaultPathPrefix = "/acme/" var keyCleaner = regexp.MustCompile(`[^\w.-_]`) diff --git a/modules/caddypki/acmeserver/caddyfile.go b/modules/caddypki/acmeserver/caddyfile.go index ae2d8ef11..3b52113b5 100644 --- a/modules/caddypki/acmeserver/caddyfile.go +++ b/modules/caddypki/acmeserver/caddyfile.go @@ -29,8 +29,9 @@ func init() { // parseACMEServer sets up an ACME server handler from Caddyfile tokens. // // acme_server [] { -// ca -// lifetime +// ca +// lifetime +// resolvers // } func parseACMEServer(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) { if !h.Next() { @@ -74,6 +75,12 @@ func parseACMEServer(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error } acmeServer.Lifetime = caddy.Duration(dur) + + case "resolvers": + acmeServer.Resolvers = h.RemainingArgs() + if len(acmeServer.Resolvers) == 0 { + return nil, h.Errf("must specify at least one resolver address") + } } } }