diff --git a/hscontrol/app.go b/hscontrol/app.go index bdb5c1d9..7eb15495 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -503,7 +503,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } -// Serve launches a GIN server with the Headscale API. +// Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { @@ -532,7 +532,7 @@ func (h *Headscale) Serve() error { region, err := h.DERPServer.GenerateRegion() if err != nil { - return err + return fmt.Errorf("generating DERP region for embedded server: %w", err) } if h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { @@ -607,14 +607,14 @@ func (h *Headscale) Serve() error { }..., ) if err != nil { - return err + return fmt.Errorf("setting up gRPC gateway via socket: %w", err) } // Connect to the gRPC server over localhost to skip // the authentication. err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn) if err != nil { - return err + return fmt.Errorf("registering Headscale API service to gRPC: %w", err) } // Start the local gRPC server without TLS and without authentication @@ -635,9 +635,7 @@ func (h *Headscale) Serve() error { tlsConfig, err := h.getTLSSettings() if err != nil { - log.Error().Err(err).Msg("Failed to set up TLS configuration") - - return err + return fmt.Errorf("configuring TLS settings: %w", err) } // @@ -702,15 +700,11 @@ func (h *Headscale) Serve() error { httpServer := &http.Server{ Addr: h.cfg.Addr, Handler: router, - ReadTimeout: types.HTTPReadTimeout, - // Go does not handle timeouts in HTTP very well, and there is - // no good way to handle streaming timeouts, therefore we need to - // keep this at unlimited and be careful to clean up connections - // https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming - // TODO(kradalby): this timeout can now be set per handler with http.ResponseController: - // https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type - // replace this so only the longpoller has no timeout. - WriteTimeout: 0, + ReadTimeout: types.HTTPTimeout, + + // Long polling should not have any timeout, this is overriden + // further down the chain + WriteTimeout: types.HTTPTimeout, } var httpListener net.Listener @@ -729,27 +723,46 @@ func (h *Headscale) Serve() error { log.Info(). Msgf("listening and serving HTTP on: %s", h.cfg.Addr) - promMux := http.NewServeMux() - promMux.Handle("/metrics", promhttp.Handler()) + debugMux := http.NewServeMux() + debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(h.nodeNotifier.String())) - promHTTPServer := &http.Server{ + return + }) + debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) { + h.mapSessionMu.Lock() + defer h.mapSessionMu.Unlock() + + var b strings.Builder + b.WriteString("mapresponders:\n") + for k, v := range h.mapSessions { + fmt.Fprintf(&b, "\t%d: %p\n", k, v) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(b.String())) + + return + }) + debugMux.Handle("/metrics", promhttp.Handler()) + + debugHTTPServer := &http.Server{ Addr: h.cfg.MetricsAddr, - Handler: promMux, - ReadTimeout: types.HTTPReadTimeout, + Handler: debugMux, + ReadTimeout: types.HTTPTimeout, WriteTimeout: 0, } - var promHTTPListener net.Listener - promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr) - + debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr) if err != nil { return fmt.Errorf("failed to bind to TCP address: %w", err) } - errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) }) + errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) }) log.Info(). - Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) + Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr) var tailsqlContext context.Context if tailsqlEnabled { @@ -815,7 +828,7 @@ func (h *Headscale) Serve() error { context.Background(), types.HTTPShutdownTimeout, ) - if err := promHTTPServer.Shutdown(ctx); err != nil { + if err := debugHTTPServer.Shutdown(ctx); err != nil { log.Error().Err(err).Msg("Failed to shutdown prometheus http") } if err := httpServer.Shutdown(ctx); err != nil { @@ -833,7 +846,7 @@ func (h *Headscale) Serve() error { } // Close network listeners - promHTTPListener.Close() + debugHTTPListener.Close() httpListener.Close() grpcGatewayConn.Close() @@ -898,7 +911,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { server := &http.Server{ Addr: h.cfg.TLS.LetsEncrypt.Listen, Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)), - ReadTimeout: types.HTTPReadTimeout, + ReadTimeout: types.HTTPTimeout, } go func() { diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 6d778775..9cc78af8 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -4,11 +4,13 @@ import ( "errors" "fmt" "net/netip" + "sort" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "gorm.io/gorm" + "tailscale.com/util/set" ) var ErrRouteIsNotAvailable = errors.New("route is not available") @@ -402,11 +404,10 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { return sendUpdate, nil } -// FailoverRouteIfAvailable takes a node and checks if the node's route -// currently have a functioning host that exposes the network. -// If it does not, it is failed over to another suitable route if there -// is one. -func FailoverRouteIfAvailable( +// FailoverNodeRoutesIfNeccessary takes a node and checks if the node's route +// need to be failed over to another host. +// If needed, the failover will be attempted. +func FailoverNodeRoutesIfNeccessary( tx *gorm.DB, isConnected types.NodeConnectedMap, node *types.Node, @@ -416,8 +417,12 @@ func FailoverRouteIfAvailable( return nil, nil } - var changedNodes []types.NodeID + log.Trace().Msgf("NODE ROUTES: %d", len(nodeRoutes)) + changedNodes := make(set.Set[types.NodeID]) + +nodeRouteLoop: for _, nodeRoute := range nodeRoutes { + log.Trace().Msgf("NODE ROUTE: %d", nodeRoute.ID) routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) if err != nil { return nil, fmt.Errorf("getting routes by prefix: %w", err) @@ -427,29 +432,37 @@ func FailoverRouteIfAvailable( if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - if isConnected[route.Node.ID] { - return nil, nil + if conn, ok := isConnected[route.Node.ID]; conn && ok { + continue nodeRouteLoop } // if not, we need to failover the route failover := failoverRoute(isConnected, &route, routes) if failover != nil { - failover.save(tx) + err := failover.save(tx) if err != nil { return nil, fmt.Errorf("saving failover routes: %w", err) } - changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID) + changedNodes.Add(failover.old.Node.ID) + changedNodes.Add(failover.new.Node.ID) + + continue nodeRouteLoop } } } } + chng := changedNodes.Slice() + sort.SliceStable(chng, func(i, j int) bool { + return chng[i] < chng[j] + }) + if len(changedNodes) != 0 { return &types.StateUpdate{ Type: types.StatePeerChanged, - ChangeNodes: changedNodes, - Message: "called from db.FailoverRouteIfAvailable", + ChangeNodes: chng, + Message: "called from db.FailoverNodeRoutesIfNeccessary", }, nil } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 390cf700..453a7503 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,9 +7,9 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/stretchr/testify/assert" "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -270,6 +270,370 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } +var n = func(nid types.NodeID) types.Node { + return types.Node{ID: nid} +} +var np = func(nid types.NodeID) *types.Node { + no := n(nid) + return &no +} +var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route { + return types.Route{ + Model: gorm.Model{ + ID: id, + }, + Node: n(nid), + Prefix: prefix, + Enabled: enabled, + IsPrimary: primary, + } +} +var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route { + ro := r(id, nid, prefix, enabled, primary) + return &ro +} + +func dbForTest(t *testing.T, testName string) *HSDatabase { + t.Helper() + + tmpDir, err := os.MkdirTemp("", testName) + if err != nil { + t.Fatalf("creating tempdir: %s", err) + } + + dbPath := tmpDir + "/headscale_test.db" + + db, err = NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + "", + ) + if err != nil { + t.Fatalf("setting up database: %s", err) + } + + t.Logf("database set up at: %s", dbPath) + + return db +} + +func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { + su := func(nids ...types.NodeID) *types.StateUpdate { + return &types.StateUpdate{ + ChangeNodes: nids, + } + } + tests := []struct { + name string + nodes types.Nodes + routes types.Routes + isConnected []types.NodeConnectedMap + want []*types.StateUpdate + wantErr bool + }{ + { + name: "n1-down-n2-down-n1-up", + nodes: types.Nodes{ + np(1), + np(2), + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: true, + }, + // n2 goes down + { + 1: false, + 2: false, + }, + // n1 comes up + { + 1: true, + 2: false, + }, + }, + want: []*types.StateUpdate{ + // route changes from 1 -> 2 + su(1, 2), + // both down, no change + nil, + // route changes from 2 -> 1 + su(1, 2), + }, + }, + { + name: "n1-recon-n2-down-n1-recon-n2-up", + nodes: types.Nodes{ + np(1), + np(2), + np(1), + np(2), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 up recon = noop + { + 1: true, + 2: true, + }, + // n2 goes down + { + 1: true, + 2: false, + }, + // n1 up recon = noop + { + 1: true, + 2: false, + }, + // n2 comes back up + { + 1: true, + 2: false, + }, + }, + want: []*types.StateUpdate{ + nil, + nil, + nil, + nil, + }, + }, + { + name: "n1-recon-n2-down-n1-recon-n2-up", + nodes: types.Nodes{ + np(1), + np(1), + np(3), + np(3), + np(2), + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + r(3, 3, ipp("10.0.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: false, + 3: true, + }, + // n1 comes up + { + 1: true, + 2: false, + 3: true, + }, + // n3 goes down + { + 1: true, + 2: false, + 3: false, + }, + // n3 comes up + { + 1: true, + 2: false, + 3: true, + }, + // n2 comes up + { + 1: true, + 2: true, + 3: true, + }, + // n1 goes down + { + 1: false, + 2: true, + 3: true, + }, + }, + want: []*types.StateUpdate{ + su(1, 3), // n1 -> n3 + nil, + su(1, 3), // n3 -> n1 + nil, + nil, + su(1, 2), // n1 -> n2 + }, + }, + { + name: "n1-recon-n2-dis-n3-take", + nodes: types.Nodes{ + np(1), + np(3), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), false, false), + r(3, 3, ipp("10.0.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: true, + 3: true, + }, + // n3 goes down + { + 1: false, + 2: true, + 3: false, + }, + }, + want: []*types.StateUpdate{ + su(1, 3), // n1 -> n3 + nil, + }, + }, + { + name: "multi-n1-oneforeach-n2-n3", + nodes: types.Nodes{ + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(4, 1, ipp("10.1.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + r(3, 3, ipp("10.1.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: true, + 3: true, + }, + }, + want: []*types.StateUpdate{ + su(1, 2, 3), // n1 -> n2,n3 + }, + }, + { + name: "multi-n1-onefor-n2-disabled-n3", + nodes: types.Nodes{ + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(4, 1, ipp("10.1.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + r(3, 3, ipp("10.1.0.0/24"), false, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: true, + 3: true, + }, + }, + want: []*types.StateUpdate{ + su(1, 2), // n1 -> n2, n3 is not enabled + }, + }, + { + name: "multi-n1-onefor-n2-offline-n3", + nodes: types.Nodes{ + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, true), + r(4, 1, ipp("10.1.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, false), + r(3, 3, ipp("10.1.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: false, + 2: true, + 3: false, + }, + }, + want: []*types.StateUpdate{ + su(1, 2), // n1 -> n2, n3 is offline + }, + }, + { + name: "multi-n2-back-to-multi-n1", + nodes: types.Nodes{ + np(1), + }, + routes: types.Routes{ + r(1, 1, ipp("10.0.0.0/24"), true, false), + r(4, 1, ipp("10.1.0.0/24"), true, true), + r(2, 2, ipp("10.0.0.0/24"), true, true), + r(3, 3, ipp("10.1.0.0/24"), true, false), + }, + isConnected: []types.NodeConnectedMap{ + // n1 goes down + { + 1: true, + 2: false, + 3: true, + }, + }, + want: []*types.StateUpdate{ + su(1, 2), // n2 -> n1 + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if (len(tt.isConnected) != len(tt.want)) && len(tt.want) != len(tt.nodes) { + t.Fatalf("nodes (%d), isConnected updates (%d), wants (%d) must be equal", len(tt.nodes), len(tt.isConnected), len(tt.want)) + } + + db := dbForTest(t, tt.name) + + for _, route := range tt.routes { + if err := db.DB.Save(&route).Error; err != nil { + t.Fatalf("failed to create route: %s", err) + } + } + + for step := range len(tt.isConnected) { + node := tt.nodes[step] + isConnected := tt.isConnected[step] + want := tt.want[step] + + got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return FailoverNodeRoutesIfNeccessary(tx, isConnected, node) + }) + + if (err != nil) != tt.wantErr { + t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(types.StateUpdate{}, "Type", "Message")); diff != "" { + t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff) + } + } + }) + } +} func TestFailoverRouteTx(t *testing.T) { tests := []struct { @@ -637,19 +1001,7 @@ func TestFailoverRouteTx(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "failover-db-test") - assert.NoError(t, err) - - db, err = NewHeadscaleDatabase( - types.DatabaseConfig{ - Type: "sqlite3", - Sqlite: types.SqliteConfig{ - Path: tmpDir + "/headscale_test.db", - }, - }, - "", - ) - assert.NoError(t, err) + db := dbForTest(t, tt.name) for _, route := range tt.routes { if err := db.DB.Save(&route).Error; err != nil { diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 83c200a2..80ec520d 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -31,7 +31,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) { } func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { - ctx, cancel := context.WithTimeout(context.Background(), types.HTTPReadTimeout) + ctx, cancel := context.WithTimeout(context.Background(), types.HTTPTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil) @@ -40,7 +40,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { } client := http.Client{ - Timeout: types.HTTPReadTimeout, + Timeout: types.HTTPTimeout, } resp, err := client.Do(req) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 3debd378..92a89d0f 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -3,7 +3,6 @@ package hscontrol import ( "encoding/binary" "encoding/json" - "errors" "io" "net/http" @@ -12,7 +11,6 @@ import ( "github.com/rs/zerolog/log" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp" "tailscale.com/tailcfg" @@ -103,12 +101,12 @@ func (h *Headscale) NoiseUpgradeHandler( router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) server := http.Server{ - ReadTimeout: types.HTTPReadTimeout, + ReadTimeout: types.HTTPTimeout, } noiseServer.httpBaseConfig = &http.Server{ Handler: router, - ReadHeaderTimeout: types.HTTPReadTimeout, + ReadHeaderTimeout: types.HTTPTimeout, } noiseServer.http2Server = &http2.Server{} @@ -225,15 +223,6 @@ func (ns *noiseServer) NoisePollNetMapHandler( key.NodePublic{}, ) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "NoisePollNetMap"). - Uint64("node.id", node.ID.Uint64()). - Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String()) - http.Error(writer, "Internal error", http.StatusNotFound) - - return - } log.Error(). Str("handler", "NoisePollNetMap"). Uint64("node.id", node.ID.Uint64()). @@ -242,58 +231,59 @@ func (ns *noiseServer) NoisePollNetMapHandler( return } - log.Debug(). - Str("handler", "NoisePollNetMap"). - Str("node", node.Hostname). - Int("cap_ver", int(mapRequest.Version)). - Uint64("node.id", node.ID.Uint64()). - Msg("A node sending a MapRequest with Noise protocol") + sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node) - session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node) + sess.tracef("a node sending a MapRequest with Noise protocol") // If a streaming mapSession exists for this node, close it // and start a new one. - if session.isStreaming() { - log.Debug(). - Caller(). - Uint64("node.id", node.ID.Uint64()). - Int("cap_ver", int(mapRequest.Version)). - Msg("Aquiring lock to check stream") + if sess.isStreaming() { + sess.tracef("aquiring lock to check stream") + ns.headscale.mapSessionMu.Lock() - if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok { - log.Info(). - Caller(). - Uint64("node.id", node.ID.Uint64()). - Msg("Node has an open streaming session, replacing") - oldSession.close() + if _, ok := ns.headscale.mapSessions[node.ID]; ok { + // NOTE/TODO(kradalby): From how I understand the protocol, when + // a client connects with stream=true, and already has a streaming + // connection open, the correct way is to close the current channel + // and replace it. However, I cannot manage to get that working with + // some sort of lock/block happening on the cancelCh in the streaming + // session. + // Not closing the channel and replacing it puts us in a weird state + // which keeps a ghost stream open, receiving keep alives, but no updates. + // + // Typically a new connection is opened when one exists as a client which + // is already authenticated reconnects (e.g. down, then up). The client will + // start auth and streaming at the same time, and then cancel the streaming + // when the auth has finished successfully, opening a new connection. + // + // As a work-around to not replacing, abusing the clients "resilience" + // by reject the new connection which will cause the client to immediately + // reconnect and "fix" the issue, as the other connection typically has been + // closed, meaning there is nothing to replace. + // + // sess.infof("node has an open stream(%p), replacing with %p", oldSession, sess) + // oldSession.close() + + defer ns.headscale.mapSessionMu.Unlock() + + sess.infof("node has an open stream(%p), rejecting new stream", sess) + return } - ns.headscale.mapSessions[node.ID] = session + ns.headscale.mapSessions[node.ID] = sess ns.headscale.mapSessionMu.Unlock() - log.Debug(). - Caller(). - Uint64("node.id", node.ID.Uint64()). - Int("cap_ver", int(mapRequest.Version)). - Msg("Releasing lock to check stream") + sess.tracef("releasing lock to check stream") } - session.serve() + sess.serve() - if session.isStreaming() { - log.Debug(). - Caller(). - Uint64("node.id", node.ID.Uint64()). - Int("cap_ver", int(mapRequest.Version)). - Msg("Aquiring lock to remove stream") + if sess.isStreaming() { + sess.tracef("aquiring lock to remove stream") ns.headscale.mapSessionMu.Lock() + defer ns.headscale.mapSessionMu.Unlock() delete(ns.headscale.mapSessions, node.ID) - ns.headscale.mapSessionMu.Unlock() - log.Debug(). - Caller(). - Uint64("node.id", node.ID.Uint64()). - Int("cap_ver", int(mapRequest.Version)). - Msg("Releasing lock to remove stream") + sess.tracef("releasing lock to remove stream") } } diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 4ead615b..6c34af57 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -172,11 +172,19 @@ func (n *Notifier) String() string { n.l.RLock() defer n.l.RUnlock() - str := []string{"Notifier, in map:\n"} + var b strings.Builder + b.WriteString("chans:\n") for k, v := range n.nodes { - str = append(str, fmt.Sprintf("\t%d: %v\n", k, v)) + fmt.Fprintf(&b, "\t%d: %p\n", k, v) } - return strings.Join(str, "") + b.WriteString("\n") + b.WriteString("connected:\n") + + for k, v := range n.connected { + fmt.Fprintf(&b, "\t%d: %t\n", k, v) + } + + return b.String() } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 2b65f6d9..7b554f30 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -48,6 +48,8 @@ type mapSession struct { ch chan types.StateUpdate cancelCh chan struct{} + keepAliveTicker *time.Ticker + node *types.Node w http.ResponseWriter @@ -85,6 +87,8 @@ func (h *Headscale) newMapSession( ch: updateChan, cancelCh: make(chan struct{}), + keepAliveTicker: time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)), + // Loggers warnf: warnf, infof: infof, @@ -100,10 +104,9 @@ func (m *mapSession) close() { return } - select { - case m.cancelCh <- struct{}{}: - default: - } + m.tracef("mapSession (%p) sending message on cancel chan") + m.cancelCh <- struct{}{} + m.tracef("mapSession (%p) sent message on cancel chan") } func (m *mapSession) isStreaming() bool { @@ -118,13 +121,6 @@ func (m *mapSession) isReadOnlyUpdate() bool { return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly } -func (m *mapSession) flush200() { - m.w.WriteHeader(http.StatusOK) - if f, ok := m.w.(http.Flusher); ok { - f.Flush() - } -} - // handlePoll ensures the node gets the appropriate updates from either // polling or immediate responses. // @@ -211,7 +207,12 @@ func (m *mapSession) serve() { m.pollFailoverRoutes("node connected", m.node) - keepAliveTicker := time.NewTicker(keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)) + // Upgrade the writer to a ResponseController + rc := http.NewResponseController(m.w) + + // Longpolling will break if there is a write timeout, + // so it needs to be disabled. + rc.SetWriteDeadline(time.Time{}) ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel() @@ -324,18 +325,16 @@ func (m *mapSession) serve() { startWrite := time.Now() _, err = m.w.Write(data) if err != nil { - m.errf(err, "Could not write the map response, for mapSession: %p, stream: %t", m, m.isStreaming()) - + m.errf(err, "Could not write the map response, for mapSession: %p", m) return } - if flusher, ok := m.w.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - + err = rc.Flush() + if err != nil { + m.errf(err, "flushing the map response to client, for mapSession: %p", m) return } + log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") m.infof("update sent") @@ -402,7 +401,7 @@ func (m *mapSession) serve() { derp = true } - case <-keepAliveTicker.C: + case <-m.keepAliveTicker.C: data, err := m.mapper.KeepAliveResponse(m.req, m.node) if err != nil { m.errf(err, "Error generating the keep alive msg") @@ -415,11 +414,9 @@ func (m *mapSession) serve() { return } - if flusher, ok := m.w.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - + err = rc.Flush() + if err != nil { + m.errf(err, "flushing keep alive to client, for mapSession: %p", m) return } } @@ -428,7 +425,7 @@ func (m *mapSession) serve() { func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return db.FailoverRouteIfAvailable(tx, m.h.nodeNotifier.ConnectedMap(), node) + return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node) }) if err != nil { m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) @@ -565,7 +562,7 @@ func (m *mapSession) handleEndpointUpdate() { }, m.node.ID) - m.flush200() + m.w.WriteHeader(http.StatusOK) return } @@ -654,7 +651,9 @@ func (m *mapSession) handleReadOnlyRequest() { m.errf(err, "Failed to write response") } - m.flush200() + m.w.WriteHeader(http.StatusOK) + + return } func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { diff --git a/hscontrol/types/const.go b/hscontrol/types/const.go index e718eb2e..019c14b6 100644 --- a/hscontrol/types/const.go +++ b/hscontrol/types/const.go @@ -3,7 +3,7 @@ package types import "time" const ( - HTTPReadTimeout = 30 * time.Second + HTTPTimeout = 30 * time.Second HTTPShutdownTimeout = 3 * time.Second TLSALPN01ChallengeType = "TLS-ALPN-01" HTTP01ChallengeType = "HTTP-01" diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 606718c7..64e6e6eb 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -124,7 +124,7 @@ func DefaultConfigEnv() map[string]string { "HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key", "HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key", "HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080", - "HEADSCALE_METRICS_LISTEN_ADDR": "127.0.0.1:9090", + "HEADSCALE_METRICS_LISTEN_ADDR": "0.0.0.0:9090", "HEADSCALE_SERVER_URL": "http://headscale:8080", "HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default", "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false", diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index b61827ac..f867a5e2 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -260,7 +260,7 @@ func New( runOptions := &dockertest.RunOptions{ Name: hsic.hostname, - ExposedPorts: append([]string{portProto}, hsic.extraPorts...), + ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...), Networks: []*dockertest.Network{network}, // Cmd: []string{"headscale", "serve"}, // TODO(kradalby): Get rid of this hack, we currently need to give us some diff --git a/integration/route_test.go b/integration/route_test.go index d185acff..150dbd27 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) { scenario, err := NewScenario() assertNoErrf(t, "failed to create scenario: %s", err) - defer scenario.Shutdown() + // defer scenario.Shutdown() spec := map[string]int{ user: 3,