diff --git a/CHANGELOG.md b/CHANGELOG.md
index 91aed9ef..d9818217 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -74,6 +74,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
 - Make sure integration tests cover postgres for all scenarios
 - CLI commands (all except `serve`) only requires minimal configuration, no more errors or warnings from unset settings [#2109](https://github.com/juanfont/headscale/pull/2109)
 - CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109)
+- Fix issue where shutting down headscale would hang [#2113](https://github.com/juanfont/headscale/pull/2113)
 
 ## 0.22.3 (2023-05-12)
 
diff --git a/cmd/headscale/cli/serve.go b/cmd/headscale/cli/serve.go
index 9f0fa35e..91597400 100644
--- a/cmd/headscale/cli/serve.go
+++ b/cmd/headscale/cli/serve.go
@@ -1,6 +1,9 @@
 package cli
 
 import (
+	"errors"
+	"net/http"
+
 	"github.com/rs/zerolog/log"
 	"github.com/spf13/cobra"
 )
@@ -22,8 +25,8 @@ var serveCmd = &cobra.Command{
 		}
 
 		err = app.Serve()
-		if err != nil {
-			log.Fatal().Caller().Err(err).Msg("Error starting server")
+		if err != nil && !errors.Is(err, http.ErrServerClosed) {
+			log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.")
 		}
 	},
 }
diff --git a/hscontrol/app.go b/hscontrol/app.go
index 1732135a..4a5b4679 100644
--- a/hscontrol/app.go
+++ b/hscontrol/app.go
@@ -770,7 +770,7 @@ func (h *Headscale) Serve() error {
 				})
 			}
 		default:
-			trace := log.Trace().Msgf
+			info := func(msg string) { log.Info().Msg(msg) }
 			log.Info().
 				Str("signal", sig.String()).
 				Msg("Received signal to stop, shutting down gracefully")
@@ -778,55 +778,55 @@ func (h *Headscale) Serve() error {
 			expireNodeCancel()
 			h.ephemeralGC.Close()
 
-			trace("waiting for netmap stream to close")
-			h.pollNetMapStreamWG.Wait()
-
 			// Gracefully shut down servers
 			ctx, cancel := context.WithTimeout(
 				context.Background(),
 				types.HTTPShutdownTimeout,
 			)
-			trace("shutting down debug http server")
+			info("shutting down debug http server")
 			if err := debugHTTPServer.Shutdown(ctx); err != nil {
-				log.Error().Err(err).Msg("Failed to shutdown prometheus http")
+				log.Error().Err(err).Msg("failed to shutdown prometheus http")
 			}
-			trace("shutting down main http server")
+			info("shutting down main http server")
 			if err := httpServer.Shutdown(ctx); err != nil {
-				log.Error().Err(err).Msg("Failed to shutdown http")
+				log.Error().Err(err).Msg("failed to shutdown http")
 			}
 
-			trace("shutting down grpc server (socket)")
+			info("closing node notifier")
+			h.nodeNotifier.Close()
+
+			info("waiting for netmap stream to close")
+			h.pollNetMapStreamWG.Wait()
+
+			info("shutting down grpc server (socket)")
 			grpcSocket.GracefulStop()
 
 			if grpcServer != nil {
-				trace("shutting down grpc server (external)")
+				info("shutting down grpc server (external)")
 				grpcServer.GracefulStop()
 				grpcListener.Close()
 			}
 
 			if tailsqlContext != nil {
-				trace("shutting down tailsql")
+				info("shutting down tailsql")
 				tailsqlContext.Done()
 			}
 
-			trace("closing node notifier")
-			h.nodeNotifier.Close()
-
 			// Close network listeners
-			trace("closing network listeners")
+			info("closing network listeners")
 			debugHTTPListener.Close()
 			httpListener.Close()
 			grpcGatewayConn.Close()
 
 			// Stop listening (and unlink the socket if unix type):
-			trace("closing socket listener")
+			info("closing socket listener")
 			socketListener.Close()
 
 			// Close db connections
-			trace("closing database connection")
+			info("closing database connection")
 			err = h.db.Close()
 			if err != nil {
-				log.Error().Err(err).Msg("Failed to close db")
+				log.Error().Err(err).Msg("failed to close db")
 			}
 
 			log.Info().
diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go
index 0b663776..ceede6ba 100644
--- a/hscontrol/notifier/notifier.go
+++ b/hscontrol/notifier/notifier.go
@@ -36,6 +36,7 @@ type Notifier struct {
 	connected *xsync.MapOf[types.NodeID, bool]
 	b         *batcher
 	cfg       *types.Config
+	closed    bool
 }
 
 func NewNotifier(cfg *types.Config) *Notifier {
@@ -43,6 +44,7 @@ func NewNotifier(cfg *types.Config) *Notifier {
 		nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
 		connected: xsync.NewMapOf[types.NodeID, bool](),
 		cfg:       cfg,
+		closed:    false,
 	}
 	b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
 	n.b = b
@@ -51,9 +53,19 @@ func NewNotifier(cfg *types.Config) *Notifier {
 	return n
 }
 
-// Close stops the batcher inside the notifier.
+// Close stops the batcher and closes all channels.
 func (n *Notifier) Close() {
+	notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
+	n.l.Lock()
+	defer n.l.Unlock()
+	notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
+
+	n.closed = true
 	n.b.close()
+
+	for _, c := range n.nodes {
+		close(c)
+	}
 }
 
 func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
@@ -70,6 +82,10 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 	notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
 	notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
 
+	if n.closed {
+		return
+	}
+
 	// If a channel exists, it means the node has opened a new
 	// connection. Close the old channel and replace it.
 	if curr, ok := n.nodes[nodeID]; ok {
@@ -96,6 +112,10 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) b
 	notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
 	notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
 
+	if n.closed {
+		return true
+	}
+
 	if len(n.nodes) == 0 {
 		return true
 	}
@@ -154,6 +174,10 @@ func (n *Notifier) NotifyWithIgnore(
 	update types.StateUpdate,
 	ignoreNodeIDs ...types.NodeID,
 ) {
+	if n.closed {
+		return
+	}
+
 	notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
 	n.b.addOrPassthrough(update)
 }
@@ -170,6 +194,10 @@ func (n *Notifier) NotifyByNodeID(
 	notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
 	notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
 
+	if n.closed {
+		return
+	}
+
 	if c, ok := n.nodes[nodeID]; ok {
 		select {
 		case <-ctx.Done():
@@ -205,6 +233,10 @@ func (n *Notifier) sendAll(update types.StateUpdate) {
 	notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
 	notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
 
+	if n.closed {
+		return
+	}
+
 	for id, c := range n.nodes {
 		// Whenever an update is sent to all nodes, there is a chance that the node
 		// has disconnected and the goroutine that was supposed to consume the update
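
The hang fixed here comes down to shutdown ordering: the netmap long-poll handlers tracked by `pollNetMapStreamWG` only return once their per-node notifier channels are closed, so waiting on the WaitGroup before calling `h.nodeNotifier.Close()` could block forever. The sketch below is not headscale's code; it is a minimal, self-contained illustration of that ordering, with made-up names (`update`, `streamWG`, `ch`):

```go
package main

import (
	"fmt"
	"sync"
)

type update struct{ reason string }

func main() {
	var streamWG sync.WaitGroup
	ch := make(chan update)

	// Stands in for a netmap long-poll handler: it keeps running
	// until its update channel is closed.
	streamWG.Add(1)
	go func() {
		defer streamWG.Done()
		for u := range ch {
			fmt.Println("would push a new netmap:", u.reason)
		}
		fmt.Println("update channel closed, handler returns")
	}()

	ch <- update{reason: "peer changed"}

	// Old order: calling streamWG.Wait() here would hang, because the
	// handler is still blocked in `range ch`.
	// New order (as in the diff): close the channels first, then wait.
	close(ch)       // what Notifier.Close() now does for every node channel
	streamWG.Wait() // corresponds to h.pollNetMapStreamWG.Wait()
	fmt.Println("shutdown complete")
}
```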
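The new `closed` flag guards the other half of the problem: once `Close()` has closed every per-node channel, a late `AddNode` or notify call must not re-register or send on one of those channels, since a send on a closed channel panics. A rough sketch of that guard pattern, using a simplified, hypothetical `registry` type rather than headscale's real `Notifier`:

```go
// Package notifysketch is illustrative only; its names are not from headscale.
package notifysketch

import "sync"

type registry struct {
	mu     sync.Mutex
	nodes  map[int]chan<- string
	closed bool
}

// add registers a receiver channel, unless shutdown has already started.
func (r *registry) add(id int, c chan<- string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed { // refuse late registrations after close()
		return
	}
	r.nodes[id] = c
}

// notify delivers a message to one receiver; it becomes a no-op after
// close(), avoiding a "send on closed channel" panic.
func (r *registry) notify(id int, msg string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return
	}
	if c, ok := r.nodes[id]; ok {
		select {
		case c <- msg:
		default: // simplification; the diff's NotifyByNodeID also selects on ctx.Done()
		}
	}
}

// close marks the registry closed and closes every channel, unblocking
// consumers that range over them.
func (r *registry) close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.closed = true
	for _, c := range r.nodes {
		close(c)
	}
}
```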