diff --git a/listeners.go b/listeners.go index 768d97711..67c519670 100644 --- a/listeners.go +++ b/listeners.go @@ -470,38 +470,90 @@ func ListenPacket(network, addr string) (net.PacketConn, error) { // unixgram will be used; otherwise, udp will be used). // // NOTE: This API is EXPERIMENTAL and may be changed or removed. -// -// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API. -func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) { - lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String()) +func (na NetworkAddress) ListenQUIC(ctx context.Context, portOffset uint, config net.ListenConfig, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) { + lnKey := listenerKey("quic"+na.Network, na.JoinHostPort(portOffset)) sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - sqtc := newSharedQUICTLSConfig(tlsConf) + lnAny, err := na.Listen(ctx, portOffset, config) + if err != nil { + return nil, err + } + + ln := lnAny.(net.PacketConn) + + h3ln := ln + for { + // retrieve the underlying socket, so quic-go can optimize. + if unwrapper, ok := h3ln.(interface{ Unwrap() net.PacketConn }); ok { + h3ln = unwrapper.Unwrap() + } else { + break + } + } + + sqs := newSharedQUICState(tlsConf, activeRequests) // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well //nolint:gosec - quicTlsConfig := &tls.Config{GetConfigForClient: sqtc.getConfigForClient} - earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{ + quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient} + earlyLn, err := quic.ListenEarly(h3ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{ Allow0RTT: true, RequireAddressValidation: func(clientAddr net.Addr) bool { - var highLoad bool - if activeRequests != nil { - highLoad = atomic.LoadInt64(activeRequests) > 1000 // TODO: make tunable? - } - return highLoad + // TODO: make tunable? + return sqs.getActiveRequests() > 1000 }, }) if err != nil { return nil, err } - return &sharedQuicListener{EarlyListener: earlyLn, sqtc: sqtc, key: lnKey}, nil + // using the original net.PacketConn to close them properly + return &sharedQuicListener{EarlyListener: earlyLn, packetConn: ln, sqs: sqs, key: lnKey}, nil }) if err != nil { return nil, err } sql := sharedEarlyListener.(*sharedQuicListener) - // add current tls.Config to sqtc, so GetConfigForClient will always return the latest tls.Config in case of context cancellation - ctx, cancel := sql.sqtc.addTLSConfig(tlsConf) + // add current tls.Config to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation, + // and the request counter will reflect current http server + ctx, cancel := sql.sqs.addState(tlsConf, activeRequests) + + return &fakeCloseQuicListener{ + sharedQuicListener: sql, + context: ctx, + contextCancel: cancel, + }, nil +} + +// DEPRECATED: Use NetworkAddress.ListenQUIC instead. This function will likely be changed or removed in the future. +// TODO: See if we can find a more elegant solution closer to the new NetworkAddress.Listen API. +func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) (http3.QUICEarlyListener, error) { + lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String()) + + sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { + sqs := newSharedQUICState(tlsConf, activeRequests) + // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well + //nolint:gosec + quicTlsConfig := &tls.Config{GetConfigForClient: sqs.getConfigForClient} + earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{ + Allow0RTT: true, + RequireAddressValidation: func(clientAddr net.Addr) bool { + // TODO: make tunable? + return sqs.getActiveRequests() > 1000 + }, + }) + if err != nil { + return nil, err + } + return &sharedQuicListener{EarlyListener: earlyLn, sqs: sqs, key: lnKey}, nil + }) + if err != nil { + return nil, err + } + + sql := sharedEarlyListener.(*sharedQuicListener) + // add current tls.Config and request counter to sqs, so GetConfigForClient will always return the latest tls.Config in case of context cancellation, + // and the request counter will reflect current http server + ctx, cancel := sql.sqs.addState(tlsConf, activeRequests) // TODO: to serve QUIC over a unix socket, currently we need to hold onto // the underlying net.PacketConn (which we wrap as unixConn to keep count @@ -534,38 +586,50 @@ type contextAndCancelFunc struct { context.CancelFunc } -// sharedQUICTLSConfig manages GetConfigForClient +// sharedQUICState manages GetConfigForClient and current number of active requests // see issue: https://github.com/caddyserver/caddy/pull/4849 -type sharedQUICTLSConfig struct { - rmu sync.RWMutex - tlsConfs map[*tls.Config]contextAndCancelFunc - activeTlsConf *tls.Config +type sharedQUICState struct { + rmu sync.RWMutex + tlsConfs map[*tls.Config]contextAndCancelFunc + requestCounters map[*tls.Config]*int64 + activeTlsConf *tls.Config + activeRequestsCounter *int64 } -// newSharedQUICTLSConfig creates a new sharedQUICTLSConfig -func newSharedQUICTLSConfig(tlsConfig *tls.Config) *sharedQUICTLSConfig { - sqtc := &sharedQUICTLSConfig{ - tlsConfs: make(map[*tls.Config]contextAndCancelFunc), - activeTlsConf: tlsConfig, +// newSharedQUICState creates a new sharedQUICState +func newSharedQUICState(tlsConfig *tls.Config, activeRequests *int64) *sharedQUICState { + sqtc := &sharedQUICState{ + tlsConfs: make(map[*tls.Config]contextAndCancelFunc), + requestCounters: make(map[*tls.Config]*int64), + activeTlsConf: tlsConfig, + activeRequestsCounter: activeRequests, } - sqtc.addTLSConfig(tlsConfig) + sqtc.addState(tlsConfig, activeRequests) return sqtc } // getConfigForClient is used as tls.Config's GetConfigForClient field -func (sqtc *sharedQUICTLSConfig) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) { - sqtc.rmu.RLock() - defer sqtc.rmu.RUnlock() - return sqtc.activeTlsConf.GetConfigForClient(ch) +func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) { + sqs.rmu.RLock() + defer sqs.rmu.RUnlock() + return sqs.activeTlsConf.GetConfigForClient(ch) } -// addTLSConfig adds tls.Config to the map if not present and returns the corresponding context and its cancelFunc -// so that when cancelled, the active tls.Config will change -func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Context, context.CancelFunc) { - sqtc.rmu.Lock() - defer sqtc.rmu.Unlock() +// getActiveRequests returns the number of active requests +func (sqs *sharedQUICState) getActiveRequests() int64 { + // Prevent a race when a context is cancelled and active request counter is being changed + sqs.rmu.RLock() + defer sqs.rmu.RUnlock() + return atomic.LoadInt64(sqs.activeRequestsCounter) +} - if cacc, ok := sqtc.tlsConfs[tlsConfig]; ok { +// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc +// so that when cancelled, the active tls.Config and request counter will change +func (sqs *sharedQUICState) addState(tlsConfig *tls.Config, activeRequests *int64) (context.Context, context.CancelFunc) { + sqs.rmu.Lock() + defer sqs.rmu.Unlock() + + if cacc, ok := sqs.tlsConfs[tlsConfig]; ok { return cacc.Context, cacc.CancelFunc } @@ -573,23 +637,26 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co wrappedCancel := func() { cancel() - sqtc.rmu.Lock() - defer sqtc.rmu.Unlock() + sqs.rmu.Lock() + defer sqs.rmu.Unlock() - delete(sqtc.tlsConfs, tlsConfig) - if sqtc.activeTlsConf == tlsConfig { - // select another tls.Config, if there is none, + delete(sqs.tlsConfs, tlsConfig) + delete(sqs.requestCounters, tlsConfig) + if sqs.activeTlsConf == tlsConfig { + // select another tls.Config and request counter, if there is none, // related sharedQuicListener will be destroyed anyway - for tc := range sqtc.tlsConfs { - sqtc.activeTlsConf = tc + for tc, counter := range sqs.requestCounters { + sqs.activeTlsConf = tc + sqs.activeRequestsCounter = counter break } } } - sqtc.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel} + sqs.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel} + sqs.requestCounters[tlsConfig] = activeRequests // there should be at most 2 tls.Configs - if len(sqtc.tlsConfs) > 2 { - Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqtc.tlsConfs))) + if len(sqs.tlsConfs) > 2 { + Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqs.tlsConfs))) } return ctx, wrappedCancel } @@ -597,13 +664,17 @@ func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Co // sharedQuicListener is like sharedListener, but for quic.EarlyListeners. type sharedQuicListener struct { *quic.EarlyListener - sqtc *sharedQUICTLSConfig - key string + packetConn net.PacketConn // we have to hold these because quic-go won't close listeners it didn't create + sqs *sharedQUICState + key string } -// Destruct closes the underlying QUIC listener. +// Destruct closes the underlying QUIC listener and its associated net.PacketConn. func (sql *sharedQuicListener) Destruct() error { - return sql.EarlyListener.Close() + // close EarlyListener first to stop any operations being done to the net.PacketConn + _ = sql.EarlyListener.Close() + // then close the net.PacketConn + return sql.packetConn.Close() } // sharedPacketConn is like sharedListener, but for net.PacketConns. @@ -652,6 +723,11 @@ var _ quic.OOBCapablePacketConn = (*fakeClosePacketConn)(nil) // but doesn't actually use these methods, the only methods needed are `ReadMsgUDP` and `SyscallConn`. var _ net.Conn = (*fakeClosePacketConn)(nil) +// Unwrap returns the underlying net.UDPConn for quic-go optimization +func (fcpc *fakeClosePacketConn) Unwrap() any { + return fcpc.UDPConn +} + // Close won't close the underlying socket unless there is no more reference, then listenerPool will close it. func (fcpc *fakeClosePacketConn) Close() error { if atomic.CompareAndSwapInt32(&fcpc.closed, 0, 1) { diff --git a/modules/caddyhttp/app.go b/modules/caddyhttp/app.go index 457a5f4d3..69cd73b96 100644 --- a/modules/caddyhttp/app.go +++ b/modules/caddyhttp/app.go @@ -617,17 +617,6 @@ func (app *App) Stop() error { zap.Error(err), zap.Strings("addresses", server.Listen)) } - - // TODO: we have to manually close our listeners because quic-go won't - // close listeners it didn't create along with the server itself... - // see https://github.com/quic-go/quic-go/issues/3560 - for _, el := range server.h3listeners { - if err := el.Close(); err != nil { - app.logger.Error("HTTP/3 listener close", - zap.Error(err), - zap.String("address", el.LocalAddr().String())) - } - } } stopH2Listener := func(server *Server) { defer finishedShutdown.Done() diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index cf1760947..d060738f1 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -228,7 +228,6 @@ type Server struct { server *http.Server h3server *http3.Server - h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create h2listeners []*http2Listener addresses []caddy.NetworkAddress @@ -555,13 +554,7 @@ func (s *Server) findLastRouteWithHostMatcher() int { // the listener, with Server s as the handler. func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error { addr.Network = getHTTP3Network(addr.Network) - lnAny, err := addr.Listen(s.ctx, 0, net.ListenConfig{}) - if err != nil { - return err - } - ln := lnAny.(net.PacketConn) - - h3ln, err := caddy.ListenQUIC(ln, tlsCfg, &s.activeRequests) + h3ln, err := addr.ListenQUIC(s.ctx, 0, net.ListenConfig{}, tlsCfg, &s.activeRequests) if err != nil { return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err) } @@ -579,8 +572,6 @@ func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error } } - s.h3listeners = append(s.h3listeners, ln) - //nolint:errcheck go s.h3server.ServeListener(h3ln)