diff --git a/listen.go b/listen.go index e0d67c6ab..0cd3fabb7 100644 --- a/listen.go +++ b/listen.go @@ -30,18 +30,34 @@ func reuseUnixSocket(network, addr string) (any, error) { return nil, nil } -func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (net.Listener, error) { - sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - ln, err := config.Listen(ctx, network, address) +func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) { + switch network { + case "udp", "udp4", "udp6", "unixgram": + sharedPc, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { + pc, err := config.ListenPacket(ctx, network, address) + if err != nil { + return nil, err + } + return &sharedPacketConn{PacketConn: pc, key: lnKey}, nil + }) if err != nil { return nil, err } - return &sharedListener{Listener: ln, key: lnKey}, nil - }) - if err != nil { - return nil, err + return &fakeClosePacketConn{sharedPacketConn: sharedPc.(*sharedPacketConn)}, nil + + default: + sharedLn, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { + ln, err := config.Listen(ctx, network, address) + if err != nil { + return nil, err + } + return &sharedListener{Listener: ln, key: lnKey}, nil + }) + if err != nil { + return nil, err + } + return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil } - return &fakeCloseListener{sharedListener: sharedLn.(*sharedListener), keepAlivePeriod: config.KeepAlive}, nil } // fakeCloseListener is a private wrapper over a listener that @@ -98,7 +114,7 @@ func (fcl *fakeCloseListener) Accept() (net.Conn, error) { // so that it's clear in the code that side-effects are shared with other // users of this listener, not just our own reference to it; we also don't // do anything with the error because all we could do is log it, but we - // expliclty assign it to nothing so we don't forget it's there if needed + // explicitly assign it to nothing so we don't forget it's there if needed _ = fcl.sharedListener.clearDeadline() if netErr, ok := err.(net.Error); ok && netErr.Timeout() { @@ -172,3 +188,75 @@ func (sl *sharedListener) setDeadline() error { func (sl *sharedListener) Destruct() error { return sl.Listener.Close() } + +// fakeClosePacketConn is like fakeCloseListener, but for PacketConns, +// or more specifically, *net.UDPConn +type fakeClosePacketConn struct { + closed int32 // accessed atomically; belongs to this struct only + *sharedPacketConn // embedded, so we also become a net.PacketConn; its key is used in Close +} + +func (fcpc *fakeClosePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + // if the listener is already "closed", return error + if atomic.LoadInt32(&fcpc.closed) == 1 { + return 0, nil, &net.OpError{ + Op: "readfrom", + Net: fcpc.LocalAddr().Network(), + Addr: fcpc.LocalAddr(), + Err: errFakeClosed, + } + } + + // call underlying readfrom + n, addr, err = fcpc.sharedPacketConn.ReadFrom(p) + if err != nil { + // this server was stopped, so clear the deadline and let + // any new server continue reading; but we will exit + if atomic.LoadInt32(&fcpc.closed) == 1 { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + if err = fcpc.SetReadDeadline(time.Time{}); err != nil { + return + } + } + } + return + } + + return +} + +// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it. +func (fcpc *fakeClosePacketConn) Close() error { + if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) { + _ = fcpc.SetReadDeadline(time.Now()) // unblock ReadFrom() calls to kick old servers out of their loops + _, _ = listenerPool.Delete(fcpc.sharedPacketConn.key) + } + return nil +} + +func (fcpc *fakeClosePacketConn) Unwrap() net.PacketConn { + return fcpc.sharedPacketConn.PacketConn +} + +// sharedPacketConn is like sharedListener, but for net.PacketConns. +type sharedPacketConn struct { + net.PacketConn + key string +} + +// Destruct closes the underlying socket. +func (spc *sharedPacketConn) Destruct() error { + return spc.PacketConn.Close() +} + +// Unwrap returns the underlying socket +func (spc *sharedPacketConn) Unwrap() net.PacketConn { + return spc.PacketConn +} + +// Interface guards (see https://github.com/caddyserver/caddy/issues/3998) +var ( + _ (interface { + Unwrap() net.PacketConn + }) = (*fakeClosePacketConn)(nil) +) diff --git a/listen_unix.go b/listen_unix.go index 8870da5e9..34cd76c5c 100644 --- a/listen_unix.go +++ b/listen_unix.go @@ -22,8 +22,10 @@ package caddy import ( "context" "errors" + "io" "io/fs" "net" + "os" "sync/atomic" "syscall" @@ -87,7 +89,7 @@ func reuseUnixSocket(network, addr string) (any, error) { return nil, nil } -func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (net.Listener, error) { +func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) { // wrap any Control function set by the user so we can also add our reusePort control without clobbering theirs oldControl := config.Control config.Control = func(network, address string, c syscall.RawConn) error { @@ -103,7 +105,14 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, // we still put it in the listenerPool so we can count how many // configs are using this socket; necessary to ensure we can know // whether to enforce shutdown delays, for example (see #5393). - ln, err := config.Listen(ctx, network, address) + var ln io.Closer + var err error + switch network { + case "udp", "udp4", "udp6", "unixgram": + ln, err = config.ListenPacket(ctx, network, address) + default: + ln, err = config.Listen(ctx, network, address) + } if err == nil { listenerPool.LoadOrStore(lnKey, nil) } @@ -117,9 +126,23 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, unixSockets[lnKey] = ln.(*unixListener) } + // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener in listen_unix.go, so... + if unix, ok := ln.(*net.UnixConn); ok { + ln = &unixConn{unix, address, lnKey, &one} + unixSockets[lnKey] = ln.(*unixConn) + } + // lightly wrap the listener so that when it is closed, // we can decrement the usage pool counter - return deleteListener{ln, lnKey}, err + switch specificLn := ln.(type) { + case net.Listener: + return deleteListener{specificLn, lnKey}, err + case net.PacketConn: + return deletePacketConn{specificLn, lnKey}, err + } + + // other types, I guess we just return them directly + return ln, err } // reusePort sets SO_REUSEPORT. Ineffective for unix sockets. @@ -158,6 +181,36 @@ func (uln *unixListener) Close() error { return uln.UnixListener.Close() } +type unixConn struct { + *net.UnixConn + filename string + mapKey string + count *int32 // accessed atomically +} + +func (uc *unixConn) Close() error { + newCount := atomic.AddInt32(uc.count, -1) + if newCount == 0 { + defer func() { + unixSocketsMu.Lock() + delete(unixSockets, uc.mapKey) + unixSocketsMu.Unlock() + _ = syscall.Unlink(uc.filename) + }() + } + return uc.UnixConn.Close() +} + +func (uc *unixConn) Unwrap() net.PacketConn { + return uc.UnixConn +} + +// unixSockets keeps track of the currently-active unix sockets +// so we can transfer their FDs gracefully during reloads. +var unixSockets = make(map[string]interface { + File() (*os.File, error) +}) + // deleteListener is a type that simply deletes itself // from the listenerPool when it closes. It is used // solely for the purpose of reference counting (i.e. @@ -171,3 +224,19 @@ func (dl deleteListener) Close() error { _, _ = listenerPool.Delete(dl.lnKey) return dl.Listener.Close() } + +// deletePacketConn is like deleteListener, but +// for net.PacketConns. +type deletePacketConn struct { + net.PacketConn + lnKey string +} + +func (dl deletePacketConn) Close() error { + _, _ = listenerPool.Delete(dl.lnKey) + return dl.PacketConn.Close() +} + +func (dl deletePacketConn) Unwrap() net.PacketConn { + return dl.PacketConn +} diff --git a/listeners.go b/listeners.go index 67c519670..84a32e45a 100644 --- a/listeners.go +++ b/listeners.go @@ -28,7 +28,6 @@ import ( "strings" "sync" "sync/atomic" - "syscall" "time" "github.com/quic-go/quic-go" @@ -149,11 +148,13 @@ func (na NetworkAddress) Listen(ctx context.Context, portOffset uint, config net } func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) { - var ln any - var err error - var address string - var unixFileMode fs.FileMode - var isAbtractUnixSocket bool + var ( + ln any + err error + address string + unixFileMode fs.FileMode + isAbtractUnixSocket bool + ) // split unix socket addr early so lnKey // is independent of permissions bits @@ -181,27 +182,10 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net lnKey := listenerKey(na.Network, address) - switch na.Network { - case "tcp", "tcp4", "tcp6", "unix", "unixpacket": - ln, err = listenTCPOrUnix(ctx, lnKey, na.Network, address, config) - case "unixgram": - ln, err = config.ListenPacket(ctx, na.Network, address) - case "udp", "udp4", "udp6": - sharedPc, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - pc, err := config.ListenPacket(ctx, na.Network, address) - if err != nil { - return nil, err - } - return &sharedPacketConn{PacketConn: pc, key: lnKey}, nil - }) - if err != nil { - return nil, err - } - spc := sharedPc.(*sharedPacketConn) - ln = &fakeClosePacketConn{spc: spc, UDPConn: spc.PacketConn.(*net.UDPConn)} - } if strings.HasPrefix(na.Network, "ip") { ln, err = config.ListenPacket(ctx, na.Network, address) + } else { + ln, err = listenReusable(ctx, lnKey, na.Network, address, config) } if err != nil { return nil, err @@ -210,13 +194,6 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net return nil, fmt.Errorf("unsupported network type: %s", na.Network) } - // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener in listen_unix.go, so... - if unix, ok := ln.(*net.UnixConn); ok { - one := int32(1) - ln = &unixConn{unix, address, lnKey, &one} - unixSockets[lnKey] = unix - } - if IsUnixNetwork(na.Network) { if !isAbtractUnixSocket { if err := os.Chmod(address, unixFileMode); err != nil { @@ -555,20 +532,8 @@ func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) ( // and the request counter will reflect current http server ctx, cancel := sql.sqs.addState(tlsConf, activeRequests) - // TODO: to serve QUIC over a unix socket, currently we need to hold onto - // the underlying net.PacketConn (which we wrap as unixConn to keep count - // of closes) because closing the quic.EarlyListener doesn't actually close - // the underlying PacketConn, but we need to for unix sockets since we dup - // the file descriptor and thus need to close the original; track issue: - // https://github.com/quic-go/quic-go/issues/3560#issuecomment-1258959608 - var unix *unixConn - if uc, ok := ln.(*unixConn); ok { - unix = uc - } - return &fakeCloseQuicListener{ sharedQuicListener: sql, - uc: unix, context: ctx, contextCancel: cancel, }, nil @@ -677,17 +642,6 @@ func (sql *sharedQuicListener) Destruct() error { return sql.packetConn.Close() } -// sharedPacketConn is like sharedListener, but for net.PacketConns. -type sharedPacketConn struct { - net.PacketConn - key string -} - -// Destruct closes the underlying socket. -func (spc *sharedPacketConn) Destruct() error { - return spc.PacketConn.Close() -} - // fakeClosedErr returns an error value that is not temporary // nor a timeout, suitable for making the caller think the // listener is actually closed @@ -707,39 +661,9 @@ func fakeClosedErr(l interface{ Addr() net.Addr }) error { // socket is actually left open. var errFakeClosed = fmt.Errorf("listener 'closed' 😉") -// fakeClosePacketConn is like fakeCloseListener, but for PacketConns, -// or more specifically, *net.UDPConn -type fakeClosePacketConn struct { - closed int32 // accessed atomically; belongs to this struct only - spc *sharedPacketConn // its key is used in Close - *net.UDPConn // embedded, so we also become a net.PacketConn and enable several other optimizations done by quic-go -} - -// interface guard for extra optimizations -// needed by QUIC implementation: https://github.com/caddyserver/caddy/issues/3998, https://github.com/caddyserver/caddy/issues/5605 -var _ quic.OOBCapablePacketConn = (*fakeClosePacketConn)(nil) - -// https://pkg.go.dev/golang.org/x/net/ipv4#NewPacketConn is used by quic-go and requires a net.PacketConn type assertable to a net.Conn, -// but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`. -var _ net.Conn = (*fakeClosePacketConn)(nil) - -// Unwrap returns the underlying net.UDPConn for quic-go optimization -func (fcpc *fakeClosePacketConn) Unwrap() any { - return fcpc.UDPConn -} - -// Close won't close the underlying socket unless there is no more reference, then listenerPool will close it. -func (fcpc *fakeClosePacketConn) Close() error { - if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) { - _, _ = listenerPool.Delete(fcpc.spc.key) - } - return nil -} - type fakeCloseQuicListener struct { - closed int32 // accessed atomically; belongs to this struct only - *sharedQuicListener // embedded, so we also become a quic.EarlyListener - uc *unixConn // underlying unix socket, if UDS + closed int32 // accessed atomically; belongs to this struct only + *sharedQuicListener // embedded, so we also become a quic.EarlyListener context context.Context contextCancel context.CancelFunc } @@ -766,11 +690,6 @@ func (fcql *fakeCloseQuicListener) Close() error { if atomic.CompareAndSwapInt32(&fcql.closed, 0, 1) { fcql.contextCancel() _, _ = listenerPool.Delete(fcql.sharedQuicListener.key) - if fcql.uc != nil { - // unix sockets need to be closed ourselves because we dup() the file - // descriptor when we reuse them, so this avoids a resource leak - fcql.uc.Close() - } } return nil } @@ -796,34 +715,7 @@ func RegisterNetwork(network string, getListener ListenerFunc) { networkTypes[network] = getListener } -type unixConn struct { - *net.UnixConn - filename string - mapKey string - count *int32 // accessed atomically -} - -func (uc *unixConn) Close() error { - newCount := atomic.AddInt32(uc.count, -1) - if newCount == 0 { - defer func() { - unixSocketsMu.Lock() - delete(unixSockets, uc.mapKey) - unixSocketsMu.Unlock() - _ = syscall.Unlink(uc.filename) - }() - } - return uc.UnixConn.Close() -} - -// unixSockets keeps track of the currently-active unix sockets -// so we can transfer their FDs gracefully during reloads. -var ( - unixSockets = make(map[string]interface { - File() (*os.File, error) - }) - unixSocketsMu sync.Mutex -) +var unixSocketsMu sync.Mutex // getListenerFromPlugin returns a listener on the given network and address // if a plugin has registered the network name. It may return (nil, nil) if @@ -867,11 +759,3 @@ type ListenerWrapper interface { var listenerPool = NewUsagePool() const maxPortSpan = 65535 - -// Interface guards (see https://github.com/caddyserver/caddy/issues/3998) -var ( - _ (interface{ SetReadBuffer(int) error }) = (*fakeClosePacketConn)(nil) - _ (interface { - SyscallConn() (syscall.RawConn, error) - }) = (*fakeClosePacketConn)(nil) -)