From 6718ff71d3ad0821de910ef18051d02ce6705aeb Mon Sep 17 00:00:00 2001 From: Juan Font Date: Thu, 24 Nov 2022 22:41:11 +0000 Subject: [PATCH] Added helper methods for subnet failover + unit tests Added method to perform subnet failover Added tests for subnet failover --- machine.go | 17 ++++-- routes.go | 123 +++++++++++++++++++++++++++++++++++++++++ routes_test.go | 147 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 281 insertions(+), 6 deletions(-) diff --git a/machine.go b/machine.go index cc15248f..0ac56d88 100644 --- a/machine.go +++ b/machine.go @@ -138,6 +138,17 @@ func (machine Machine) isExpired() bool { return time.Now().UTC().After(*machine.Expiry) } +// isOnline returns if the machine is connected to Headscale. +// This is really a naive implementation, as we don't really see +// if there is a working connection between the client and the server. +func (machine *Machine) isOnline() bool { + if machine.LastSeen == nil { + return false + } + + return machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) +} + func containsAddresses(inputs []string, addrs []string) bool { for _, addr := range addrs { if contains(inputs, addr) { @@ -708,9 +719,7 @@ func (h *Headscale) toNode( hostInfo := machine.GetHostInfo() - // A node is Online if it is connected to the control server, - // and we now we update LastSeen every keepAliveInterval duration at least. - online := machine.LastSeen.After(time.Now().Add(-keepAliveInterval)) + online := machine.isOnline() node := tailcfg.Node{ ID: tailcfg.NodeID(machine.ID), // this is the actual ID @@ -1027,7 +1036,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) - if prefix != ExitRouteV4 && prefix != ExitRouteV6 { + if !route.isExitRoute() { route.IsPrimary = h.isUniquePrefix(route) } diff --git a/routes.go b/routes.go index 36d67a90..221db60f 100644 --- a/routes.go +++ b/routes.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" + "github.com/rs/zerolog/log" "gorm.io/gorm" ) @@ -34,6 +35,10 @@ func (r *Route) String() string { return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) } +func (r *Route) isExitRoute() bool { + return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 +} + func (rs Routes) toPrefixes() []netip.Prefix { prefixes := make([]netip.Prefix, len(rs)) for i, r := range rs { @@ -54,6 +59,23 @@ func (h *Headscale) isUniquePrefix(route Route) bool { return count == 0 } +func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { + var route Route + err := h.db. + Preload("Machine"). + Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). + First(&route).Error + if err != nil && err != gorm.ErrRecordNotFound { + return nil, err + } + + if err == gorm.ErrRecordNotFound { + return nil, gorm.ErrRecordNotFound + } + + return &route, nil +} + // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { @@ -120,3 +142,104 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { return nil } + +func (h *Headscale) handlePrimarySubnetFailover() error { + // first, get all the enabled routes + var routes []Route + err := h.db. + Preload("Machine"). + Where("advertised = ? AND enabled = ?", true, true). + Find(&routes).Error + if err != nil && err != gorm.ErrRecordNotFound { + log.Error().Err(err).Msg("error getting routes") + } + + for _, route := range routes { + if route.isExitRoute() { + continue + } + + if !route.IsPrimary { + _, err := h.getPrimaryRoute(netip.Prefix(route.Prefix)) + if h.isUniquePrefix(route) || err == gorm.ErrRecordNotFound { + route.IsPrimary = true + err := h.db.Save(&route).Error + if err != nil { + log.Error().Err(err).Msg("error marking route as primary") + + return err + } + continue + } + } + + if route.IsPrimary { + if route.Machine.isOnline() { + continue + } + + // machine offline, find a new primary + log.Info(). + Str("machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Msgf("machine offline, finding a new primary subnet") + + // find a new primary route + var newPrimaryRoutes []Route + err := h.db. + Preload("Machine"). + Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", + route.Prefix, + route.MachineID, + true, true). + Find(&newPrimaryRoutes).Error + if err != nil && err != gorm.ErrRecordNotFound { + log.Error().Err(err).Msg("error finding new primary route") + + return err + } + + var newPrimaryRoute *Route + for _, r := range newPrimaryRoutes { + if r.Machine.isOnline() { + newPrimaryRoute = &r + break + } + } + + if newPrimaryRoute == nil { + log.Warn(). + Str("machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Msgf("no alternative primary route found") + continue + } + + log.Info(). + Str("old_machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Str("new_machine", newPrimaryRoute.Machine.Hostname). + Msgf("found new primary route") + + // disable the old primary route + route.IsPrimary = false + err = h.db.Save(&route).Error + if err != nil { + log.Error().Err(err).Msg("error disabling old primary route") + + return err + } + + // enable the new primary route + newPrimaryRoute.IsPrimary = true + err = h.db.Save(&newPrimaryRoute).Error + if err != nil { + log.Error().Err(err).Msg("error enabling new primary route") + + return err + } + } + } + + return nil +} diff --git a/routes_test.go b/routes_test.go index 2560898e..8d98cebc 100644 --- a/routes_test.go +++ b/routes_test.go @@ -2,6 +2,7 @@ package headscale import ( "net/netip" + "time" "gopkg.in/check.v1" "tailscale.com/tailcfg" @@ -150,7 +151,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { RoutableIPs: []netip.Prefix{route, route2}, } machine1 := Machine{ - ID: 0, + ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", @@ -175,7 +176,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { RoutableIPs: []netip.Prefix{route2}, } machine2 := Machine{ - ID: 0, + ID: 2, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", @@ -209,3 +210,145 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(routes), check.Equals, 0) } + +func (s *Suite) TestSubnetFailover(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachine("test", "test_enable_route_machine") + c.Assert(err, check.NotNil) + + prefix, err := netip.ParsePrefix( + "10.0.0.0/24", + ) + c.Assert(err, check.IsNil) + + prefix2, err := netip.ParsePrefix( + "150.0.10.0/25", + ) + c.Assert(err, check.IsNil) + + hostInfo1 := tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{prefix, prefix2}, + } + + now := time.Now() + machine1 := Machine{ + ID: 1, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "test_enable_route_machine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: HostInfo(hostInfo1), + LastSeen: &now, + } + app.db.Save(&machine1) + + err = app.processMachineRoutes(&machine1) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine1, prefix.String()) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine1, prefix2.String()) + c.Assert(err, check.IsNil) + + err = app.handlePrimarySubnetFailover() + c.Assert(err, check.IsNil) + + enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes1), check.Equals, 2) + + route, err := app.getPrimaryRoute(prefix) + c.Assert(err, check.IsNil) + c.Assert(route.MachineID, check.Equals, machine1.ID) + + hostInfo2 := tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{prefix2}, + } + machine2 := Machine{ + ID: 2, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "test_enable_route_machine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: HostInfo(hostInfo2), + LastSeen: &now, + } + app.db.Save(&machine2) + + err = app.processMachineRoutes(&machine2) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine2, prefix2.String()) + c.Assert(err, check.IsNil) + + err = app.handlePrimarySubnetFailover() + c.Assert(err, check.IsNil) + + enabledRoutes1, err = app.GetEnabledRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes1), check.Equals, 2) + + enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes2), check.Equals, 1) + + routes, err := app.getMachinePrimaryRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 2) + + routes, err = app.getMachinePrimaryRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 0) + + // lets make machine1 lastseen 10 mins ago + before := now.Add(-10 * time.Minute) + machine1.LastSeen = &before + err = app.db.Save(&machine1).Error + c.Assert(err, check.IsNil) + + err = app.handlePrimarySubnetFailover() + c.Assert(err, check.IsNil) + + routes, err = app.getMachinePrimaryRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 1) + + routes, err = app.getMachinePrimaryRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 1) + + machine2.HostInfo = HostInfo(tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{prefix, prefix2}, + }) + err = app.db.Save(&machine2).Error + c.Assert(err, check.IsNil) + + err = app.processMachineRoutes(&machine2) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine2, prefix.String()) + c.Assert(err, check.IsNil) + + err = app.handlePrimarySubnetFailover() + c.Assert(err, check.IsNil) + + routes, err = app.getMachinePrimaryRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 0) + + routes, err = app.getMachinePrimaryRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 2) +}