diff --git a/app.go b/app.go index 11c8d685..de6ef669 100644 --- a/app.go +++ b/app.go @@ -95,6 +95,7 @@ type Headscale struct { ipAllocationMutex sync.Mutex shutdownChan chan struct{} + wg sync.WaitGroup } // Look up the TLS constant relative to user-supplied TLS client @@ -153,6 +154,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { privateKey: privKey, aclRules: tailcfg.FilterAllowAll, // default allowall registrationCache: registrationCache, + wg: sync.WaitGroup{}, } err = app.initDB() @@ -567,6 +569,8 @@ func (h *Headscale) Serve() error { // https://github.com/soheilhy/cmux/issues/68 // https://github.com/soheilhy/cmux/issues/91 + var grpcServer *grpc.Server + var grpcListener net.Listener if tlsConfig != nil || h.cfg.GRPCAllowInsecure { log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) @@ -587,12 +591,12 @@ func (h *Headscale) Serve() error { log.Warn().Msg("gRPC is running without security") } - grpcServer := grpc.NewServer(grpcOptions...) + grpcServer = grpc.NewServer(grpcOptions...) v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h)) reflection.Register(grpcServer) - grpcListener, err := net.Listen("tcp", h.cfg.GRPCAddr) + grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr) if err != nil { return fmt.Errorf("failed to bind to TCP address: %w", err) } @@ -668,7 +672,7 @@ func (h *Headscale) Serve() error { syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP) - go func(c chan os.Signal) { + sig_func := func(c chan os.Signal) { // Wait for a SIGINT or SIGKILL: for { sig := <-c @@ -678,7 +682,7 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP + // TODO(kradalby): Reload config on SIGHUP if h.cfg.ACL.PolicyPath != "" { aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) @@ -698,7 +702,8 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") - h.shutdownChan <- struct{}{} + close(h.shutdownChan) + h.wg.Wait() // Gracefully shut down servers ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) @@ -710,6 +715,11 @@ func (h *Headscale) Serve() error { } grpcSocket.GracefulStop() + if grpcServer != nil { + grpcServer.GracefulStop() + grpcListener.Close() + } + // Close network listeners promHTTPListener.Close() httpListener.Close() @@ -736,7 +746,8 @@ func (h *Headscale) Serve() error { os.Exit(0) } } - }(sigc) + } + errorGroup.Go(func() error { sig_func(sigc); return nil }) return errorGroup.Wait() } diff --git a/poll.go b/poll.go index 6628a179..94941aa3 100644 --- a/poll.go +++ b/poll.go @@ -290,6 +290,9 @@ func (h *Headscale) PollNetMapStream( keepAliveChan chan []byte, updateChan chan struct{}, ) { + h.wg.Add(1) + defer h.wg.Done() + ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname) ctx, cancel := context.WithCancel(ctx) @@ -353,9 +356,9 @@ func (h *Headscale) PollNetMapStream( Str("channel", "pollData"). Int("bytes", len(data)). Msg("Data from pollData channel written successfully") - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. + // TODO(kradalby): Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. err = h.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). @@ -431,9 +434,9 @@ func (h *Headscale) PollNetMapStream( Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Keep alive sent successfully") - // TODO(kradalby): Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. + // TODO(kradalby): Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. err = h.UpdateMachineFromDatabase(machine) if err != nil { log.Error(). @@ -588,9 +591,9 @@ func (h *Headscale) PollNetMapStream( Str("handler", "PollNetMapStream"). Str("machine", machine.Hostname). Msg("The client has closed the connection") - // TODO: Abstract away all the database calls, this can cause race conditions - // when an outdated machine object is kept alive, e.g. db is update from - // command line, but then overwritten. + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. err := h.UpdateMachineFromDatabase(machine) if err != nil { log.Error().