diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 82a5295f..46ab162b 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -191,6 +191,7 @@ func (m *mapSession) serve() { // //nolint:gocyclo func (m *mapSession) serveLongPoll() { + start := time.Now() m.beforeServeLongPoll() // Clean up the session when the client disconnects @@ -235,16 +236,6 @@ func (m *mapSession) serveLongPoll() { m.pollFailoverRoutes("node connected", m.node) - // Upgrade the writer to a ResponseController - rc := http.NewResponseController(m.w) - - // Longpolling will break if there is a write timeout, - // so it needs to be disabled. - rc.SetWriteDeadline(time.Time{}) - - ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) - defer cancel() - m.keepAliveTicker = time.NewTicker(m.keepAlive) m.h.nodeNotifier.AddNode(m.node.ID, m.ch) @@ -258,12 +249,12 @@ func (m *mapSession) serveLongPoll() { // consume channels with update, keep alives or "batch" blocking signals select { case <-m.cancelCh: - m.tracef("poll cancelled received") + m.tracef("poll cancelled received (%s)", time.Since(start).String()) mapResponseEnded.WithLabelValues("cancelled").Inc() return - case <-ctx.Done(): - m.tracef("poll context done") + case <-m.ctx.Done(): + m.tracef("poll context done (%s): %s", time.Since(start).String(), m.ctx.Err().Error()) mapResponseEnded.WithLabelValues("done").Inc() return @@ -354,14 +345,7 @@ func (m *mapSession) serveLongPoll() { m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m) return } - - err = rc.Flush() - if err != nil { - mapResponseSent.WithLabelValues("error", updateType).Inc() - m.errf(err, "flushing the map response to client, for mapSession: %p", m) - return - } - + m.w.(http.Flusher).Flush() log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") if debugHighCardinalityMetrics { @@ -375,22 +359,17 @@ func (m *mapSession) serveLongPoll() { case <-m.keepAliveTicker.C: data, err := m.mapper.KeepAliveResponse(m.req, m.node) if err != nil { - m.errf(err, "Error generating the keep alive msg") + m.errf(err, "Error generating the keepalive msg") mapResponseSent.WithLabelValues("error", "keepalive").Inc() return } _, err = m.w.Write(data) if err != nil { - m.errf(err, "Cannot write keep alive message") - mapResponseSent.WithLabelValues("error", "keepalive").Inc() - return - } - err = rc.Flush() - if err != nil { - m.errf(err, "flushing keep alive to client, for mapSession: %p", m) + m.errf(err, "Cannot write keepalive message") mapResponseSent.WithLabelValues("error", "keepalive").Inc() return } + m.w.(http.Flusher).Flush() if debugHighCardinalityMetrics { mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) diff --git a/integration/route_test.go b/integration/route_test.go index cfed09cf..8a91ef4f 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1492,7 +1492,7 @@ func TestHASubnetRouterFailoverWhenNodeDisconnects2129(t *testing.T) { assert.True(t, nodeList[2].Online) // Kill off one of the docker containers to simulate a disconnect - err = scenario.DisconnectContainers(subRouter1.Hostname()) + err = scenario.DisconnectContainersFromScenario(subRouter1.Hostname()) assertNoErr(t, err) time.Sleep(5 * time.Second) @@ -1514,4 +1514,53 @@ func TestHASubnetRouterFailoverWhenNodeDisconnects2129(t *testing.T) { assert.False(t, nodeListAfterDisconnect[0].Online) assert.True(t, nodeListAfterDisconnect[1].Online) assert.True(t, nodeListAfterDisconnect[2].Online) + + var routesAfterDisconnect []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterDisconnect, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterDisconnect, 2) + + // Node 1 is primary + assert.Equal(t, true, routesAfterDisconnect[0].GetAdvertised()) + assert.Equal(t, true, routesAfterDisconnect[0].GetEnabled()) + assert.Equal(t, false, routesAfterDisconnect[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be non-primary") + + // Node 2 is not primary + assert.Equal(t, true, routesAfterDisconnect[1].GetAdvertised()) + assert.Equal(t, true, routesAfterDisconnect[1].GetEnabled()) + assert.Equal(t, true, routesAfterDisconnect[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be primary") + + // // Ensure the node can reconncet as expected + // err = scenario.ConnectContainersToScenario(subRouter1.Hostname()) + // assertNoErr(t, err) + + // time.Sleep(5 * time.Second) + + // var nodeListAfterReconnect []v1.Node + // err = executeAndUnmarshal( + // headscale, + // []string{ + // "headscale", + // "nodes", + // "list", + // "--output", + // "json", + // }, + // &nodeListAfterReconnect, + // ) + // assert.Nil(t, err) + // assert.Len(t, nodeListAfterReconnect, 3) + // assert.True(t, nodeListAfterReconnect[0].Online) + // assert.True(t, nodeListAfterReconnect[1].Online) + // assert.True(t, nodeListAfterReconnect[2].Online) } diff --git a/integration/scenario.go b/integration/scenario.go index 988156af..94361021 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -650,8 +650,8 @@ func (s *Scenario) WaitForTailscaleLogout() error { return nil } -// DisconnectContainers disconnects a list of containers from the network. -func (s *Scenario) DisconnectContainers(containers ...string) error { +// DisconnectContainersFromScenario disconnects a list of containers from the network. +func (s *Scenario) DisconnectContainersFromScenario(containers ...string) error { for _, container := range containers { if ctr, ok := s.pool.ContainerByName(container); ok { err := ctr.DisconnectFromNetwork(s.network) @@ -663,3 +663,17 @@ func (s *Scenario) DisconnectContainers(containers ...string) error { return nil } + +// ConnectContainersToScenario disconnects a list of containers from the network. +func (s *Scenario) ConnectContainersToScenario(containers ...string) error { + for _, container := range containers { + if ctr, ok := s.pool.ContainerByName(container); ok { + err := ctr.ConnectToNetwork(s.network) + if err != nil { + return err + } + } + } + + return nil +}