diff --git a/routes.go b/routes.go index 36d67a90..221db60f 100644 --- a/routes.go +++ b/routes.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" + "github.com/rs/zerolog/log" "gorm.io/gorm" ) @@ -34,6 +35,10 @@ func (r *Route) String() string { return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String()) } +func (r *Route) isExitRoute() bool { + return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 +} + func (rs Routes) toPrefixes() []netip.Prefix { prefixes := make([]netip.Prefix, len(rs)) for i, r := range rs { @@ -54,6 +59,23 @@ func (h *Headscale) isUniquePrefix(route Route) bool { return count == 0 } +func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { + var route Route + err := h.db. + Preload("Machine"). + Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). + First(&route).Error + if err != nil && err != gorm.ErrRecordNotFound { + return nil, err + } + + if err == gorm.ErrRecordNotFound { + return nil, gorm.ErrRecordNotFound + } + + return &route, nil +} + // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { @@ -120,3 +142,104 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { return nil } + +func (h *Headscale) handlePrimarySubnetFailover() error { + // first, get all the enabled routes + var routes []Route + err := h.db. + Preload("Machine"). + Where("advertised = ? AND enabled = ?", true, true). + Find(&routes).Error + if err != nil && err != gorm.ErrRecordNotFound { + log.Error().Err(err).Msg("error getting routes") + } + + for _, route := range routes { + if route.isExitRoute() { + continue + } + + if !route.IsPrimary { + _, err := h.getPrimaryRoute(netip.Prefix(route.Prefix)) + if h.isUniquePrefix(route) || err == gorm.ErrRecordNotFound { + route.IsPrimary = true + err := h.db.Save(&route).Error + if err != nil { + log.Error().Err(err).Msg("error marking route as primary") + + return err + } + continue + } + } + + if route.IsPrimary { + if route.Machine.isOnline() { + continue + } + + // machine offline, find a new primary + log.Info(). + Str("machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Msgf("machine offline, finding a new primary subnet") + + // find a new primary route + var newPrimaryRoutes []Route + err := h.db. + Preload("Machine"). + Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", + route.Prefix, + route.MachineID, + true, true). + Find(&newPrimaryRoutes).Error + if err != nil && err != gorm.ErrRecordNotFound { + log.Error().Err(err).Msg("error finding new primary route") + + return err + } + + var newPrimaryRoute *Route + for _, r := range newPrimaryRoutes { + if r.Machine.isOnline() { + newPrimaryRoute = &r + break + } + } + + if newPrimaryRoute == nil { + log.Warn(). + Str("machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Msgf("no alternative primary route found") + continue + } + + log.Info(). + Str("old_machine", route.Machine.Hostname). + Str("prefix", netip.Prefix(route.Prefix).String()). + Str("new_machine", newPrimaryRoute.Machine.Hostname). + Msgf("found new primary route") + + // disable the old primary route + route.IsPrimary = false + err = h.db.Save(&route).Error + if err != nil { + log.Error().Err(err).Msg("error disabling old primary route") + + return err + } + + // enable the new primary route + newPrimaryRoute.IsPrimary = true + err = h.db.Save(&newPrimaryRoute).Error + if err != nil { + log.Error().Err(err).Msg("error enabling new primary route") + + return err + } + } + } + + return nil +}