diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml
index 65324f77..6edec248 100644
--- a/.github/workflows/test-integration.yaml
+++ b/.github/workflows/test-integration.yaml
@@ -60,6 +60,7 @@ jobs:
           - TestEnableDisableAutoApprovedRoute
           - TestAutoApprovedSubRoute2068
           - TestSubnetRouteACL
+          - TestHASubnetRouterFailoverWhenNodeDisconnects2129
           - TestHeadscale
           - TestCreateTailscale
           - TestTailscaleNodesJoiningHeadcale
diff --git a/hscontrol/poll.go b/hscontrol/poll.go
index a8ae01f4..870d5c5a 100644
--- a/hscontrol/poll.go
+++ b/hscontrol/poll.go
@@ -191,6 +191,7 @@ func (m *mapSession) serve() {
 //
 //nolint:gocyclo
 func (m *mapSession) serveLongPoll() {
+	start := time.Now()
 	m.beforeServeLongPoll()
 
 	// Clean up the session when the client disconnects
@@ -220,16 +221,6 @@ func (m *mapSession) serveLongPoll() {
 
 	m.pollFailoverRoutes("node connected", m.node)
 
-	// Upgrade the writer to a ResponseController
-	rc := http.NewResponseController(m.w)
-
-	// Longpolling will break if there is a write timeout,
-	// so it needs to be disabled.
-	rc.SetWriteDeadline(time.Time{})
-
-	ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
-	defer cancel()
-
 	m.keepAliveTicker = time.NewTicker(m.keepAlive)
 
 	m.h.nodeNotifier.AddNode(m.node.ID, m.ch)
@@ -243,12 +234,12 @@ func (m *mapSession) serveLongPoll() {
 		// consume channels with update, keep alives or "batch" blocking signals
 		select {
 		case <-m.cancelCh:
-			m.tracef("poll cancelled received")
+			m.tracef("poll cancelled received (%s)", time.Since(start).String())
 			mapResponseEnded.WithLabelValues("cancelled").Inc()
 			return
 
-		case <-ctx.Done():
-			m.tracef("poll context done")
+		case <-m.ctx.Done():
+			m.tracef("poll context done (%s): %s", time.Since(start).String(), m.ctx.Err().Error())
 			mapResponseEnded.WithLabelValues("done").Inc()
 			return
 
@@ -339,14 +330,7 @@ func (m *mapSession) serveLongPoll() {
 				m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
 				return
 			}
-
-			err = rc.Flush()
-			if err != nil {
-				mapResponseSent.WithLabelValues("error", updateType).Inc()
-				m.errf(err, "flushing the map response to client, for mapSession: %p", m)
-				return
-			}
-
+			m.w.(http.Flusher).Flush()
 			log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
 
 			if debugHighCardinalityMetrics {
@@ -360,22 +344,17 @@ func (m *mapSession) serveLongPoll() {
 		case <-m.keepAliveTicker.C:
 			data, err := m.mapper.KeepAliveResponse(m.req, m.node)
 			if err != nil {
-				m.errf(err, "Error generating the keep alive msg")
+				m.errf(err, "Error generating the keepalive msg")
 				mapResponseSent.WithLabelValues("error", "keepalive").Inc()
 				return
 			}
 			_, err = m.w.Write(data)
 			if err != nil {
-				m.errf(err, "Cannot write keep alive message")
-				mapResponseSent.WithLabelValues("error", "keepalive").Inc()
-				return
-			}
-			err = rc.Flush()
-			if err != nil {
-				m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
+				m.errf(err, "Cannot write keepalive message")
 				mapResponseSent.WithLabelValues("error", "keepalive").Inc()
 				return
 			}
+			m.w.(http.Flusher).Flush()
 
 			if debugHighCardinalityMetrics {
 				mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
diff --git a/integration/route_test.go b/integration/route_test.go
index f163fa14..ec520e57 100644
--- a/integration/route_test.go
+++ b/integration/route_test.go
@@ -13,6 +13,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" @@ -1316,3 +1317,252 @@ func TestSubnetRouteACL(t *testing.T) { t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) } } + +func TestHASubnetRouterFailoverWhenNodeDisconnects2129(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + user := "enable-routing" + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErrf(t, "failed to create scenario: %s", err) + // defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + user: 3, + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{}, + hsic.WithTestName("clientdisc"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + expectedRoutes := map[string]string{ + "1": "10.0.0.0/24", + "2": "10.0.0.0/24", + } + + // Sort nodes by ID + sort.SliceStable(allClients, func(i, j int) bool { + statusI, err := allClients[i].Status() + if err != nil { + return false + } + + statusJ, err := allClients[j].Status() + if err != nil { + return false + } + + return statusI.Self.ID < statusJ.Self.ID + }) + + subRouter1 := allClients[0] + subRouter2 := allClients[1] + + t.Logf("Advertise route from r1 (%s) and r2 (%s), making it HA, n1 is primary", subRouter1.Hostname(), subRouter2.Hostname()) + // advertise HA route on node 1 and 2 + // ID 1 will be primary + // ID 2 will be secondary + for _, client := range allClients[:2] { + status, err := client.Status() + assertNoErr(t, err) + + if route, ok := expectedRoutes[string(status.Self.ID)]; ok { + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + route, + } + _, _, err = client.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + } else { + t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + + assertNoErr(t, err) + assert.Len(t, routes, 2) + + t.Logf("initial routes %#v", routes) + + for _, route := range routes { + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) + } + + // Verify that no routes has been sent to the client, + // they are not yet enabled. 
+	for _, client := range allClients {
+		status, err := client.Status()
+		assertNoErr(t, err)
+
+		for _, peerKey := range status.Peers() {
+			peerStatus := status.Peer[peerKey]
+
+			assert.Nil(t, peerStatus.PrimaryRoutes)
+		}
+	}
+
+	// Enable all routes
+	for _, route := range routes {
+		_, err = headscale.Execute(
+			[]string{
+				"headscale",
+				"routes",
+				"enable",
+				"--route",
+				strconv.Itoa(int(route.GetId())),
+			})
+		assertNoErr(t, err)
+
+		time.Sleep(time.Second)
+	}
+
+	var enablingRoutes []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&enablingRoutes,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, enablingRoutes, 2)
+
+	// Node 1 is primary
+	assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
+	assert.Equal(t, true, enablingRoutes[0].GetEnabled())
+	assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
+
+	// Node 2 is not primary
+	assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
+	assert.Equal(t, true, enablingRoutes[1].GetEnabled())
+	assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
+
+	var nodeList []v1.Node
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"nodes",
+			"list",
+			"--output",
+			"json",
+		},
+		&nodeList,
+	)
+	assert.Nil(t, err)
+	assert.Len(t, nodeList, 3)
+	assert.True(t, nodeList[0].Online)
+	assert.True(t, nodeList[1].Online)
+	assert.True(t, nodeList[2].Online)
+
+	// Disconnect the primary subnet router (r1) from the Docker network
+	// to simulate the node disconnecting.
+	err = scenario.DisconnectContainersFromScenario(subRouter1.Hostname())
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var nodeListAfterDisconnect []v1.Node
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"nodes",
+			"list",
+			"--output",
+			"json",
+		},
+		&nodeListAfterDisconnect,
+	)
+	assert.Nil(t, err)
+	assert.Len(t, nodeListAfterDisconnect, 3)
+	assert.False(t, nodeListAfterDisconnect[0].Online)
+	assert.True(t, nodeListAfterDisconnect[1].Online)
+	assert.True(t, nodeListAfterDisconnect[2].Online)
+
+	var routesAfterDisconnect []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfterDisconnect,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfterDisconnect, 2)
+
+	// Node 1 is no longer primary
+	assert.Equal(t, true, routesAfterDisconnect[0].GetAdvertised())
+	assert.Equal(t, true, routesAfterDisconnect[0].GetEnabled())
+	assert.Equal(t, false, routesAfterDisconnect[0].GetIsPrimary(), "r1 is disconnected, expected r1 to be non-primary")
+
+	// Node 2 is now primary
+	assert.Equal(t, true, routesAfterDisconnect[1].GetAdvertised())
+	assert.Equal(t, true, routesAfterDisconnect[1].GetEnabled())
+	assert.Equal(t, true, routesAfterDisconnect[1].GetIsPrimary(), "r1 is disconnected, expected r2 to be primary")
+
+	// // Ensure the node can reconnect as expected
+	// err = scenario.ConnectContainersToScenario(subRouter1.Hostname())
+	// assertNoErr(t, err)
+
+	// time.Sleep(5 * time.Second)
+
+	// var nodeListAfterReconnect []v1.Node
+	// err = executeAndUnmarshal(
+	// 	headscale,
+	// 	[]string{
+	// 		"headscale",
+	// 		"nodes",
+	// 		"list",
+	// 		"--output",
+	// 		"json",
+	// 	},
+	// 	&nodeListAfterReconnect,
+	// )
+	// assert.Nil(t, err)
+	// assert.Len(t, nodeListAfterReconnect, 3)
+	// assert.True(t, nodeListAfterReconnect[0].Online)
+	// assert.True(t, nodeListAfterReconnect[1].Online)
+	// assert.True(t, nodeListAfterReconnect[2].Online)
+}
diff --git a/integration/scenario.go b/integration/scenario.go
index b45c5fe7..1da771ab 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -651,3 +651,31 @@ func (s *Scenario) WaitForTailscaleLogout() error {
 
 	return nil
 }
+
+// DisconnectContainersFromScenario disconnects a list of containers from the network.
+func (s *Scenario) DisconnectContainersFromScenario(containers ...string) error {
+	for _, container := range containers {
+		if ctr, ok := s.pool.ContainerByName(container); ok {
+			err := ctr.DisconnectFromNetwork(s.network)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// ConnectContainersToScenario connects a list of containers to the network.
+func (s *Scenario) ConnectContainersToScenario(containers ...string) error {
+	for _, container := range containers {
+		if ctr, ok := s.pool.ContainerByName(container); ok {
+			err := ctr.ConnectToNetwork(s.network)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}