diff --git a/CHANGELOG.md b/CHANGELOG.md index a5441a5a..a7908eaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - - The latest supported client is 1.36 + - The latest supported client is 1.38 - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. - Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611) diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index dfaf512f..3f3322e2 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -6,25 +6,11 @@ import ( "github.com/efekarakus/termcolor" "github.com/juanfont/headscale/cmd/headscale/cli" - "github.com/pkg/profile" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) func main() { - if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { - if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { - err := os.MkdirAll(profilePath, os.ModePerm) - if err != nil { - log.Fatal().Err(err).Msg("failed to create profiling directory") - } - - defer profile.Start(profile.ProfilePath(profilePath)).Stop() - } else { - defer profile.Start().Stop() - } - } - var colors bool switch l := termcolor.SupportLevel(os.Stderr); l { case termcolor.Level16M: diff --git a/hscontrol/app.go b/hscontrol/app.go index 75dfddee..91d53263 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -33,6 +33,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" + "github.com/pkg/profile" "github.com/prometheus/client_golang/prometheus/promhttp" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -48,6 +49,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -61,7 +63,7 @@ var ( "unknown value for Lets Encrypt challenge type", ) errEmptyInitialDERPMap = errors.New( - "initial DERPMap is empty, Headscale requries at least one entry", + "initial DERPMap is empty, Headscale requires at least one entry", ) ) @@ -166,7 +168,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { cfg.DBtype, dbString, app.dbDebug, - app.nodeNotifier, cfg.IPPrefixes, cfg.BaseDomain) if err != nil { @@ -234,8 +235,23 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { // seen for longer than h.cfg.EphemeralNodeInactivityTimeout. func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) + + var update types.StateUpdate + var changed bool for range ticker.C { - h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) + + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring ephemeral nodes") + continue + } + + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -246,9 +262,24 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) { ticker := time.NewTicker(interval) lastCheck := time.Unix(0, 0) + var update types.StateUpdate + var changed bool for range ticker.C { - lastCheck = h.db.ExpireExpiredNodes(lastCheck) + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) + + return nil + }); err != nil { + log.Error().Err(err).Msg("database error while expiring nodes") + continue + } + + log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes") + if changed && update.Valid() { + ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") + h.nodeNotifier.NotifyAll(ctx, update) + } } } @@ -278,7 +309,8 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { DERPMap: h.DERPMap, } if stateUpdate.Valid() { - h.nodeNotifier.NotifyAll(stateUpdate) + ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) } } } @@ -485,6 +517,19 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches a GIN server with the Headscale API. func (h *Headscale) Serve() error { + if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { + if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { + err := os.MkdirAll(profilePath, os.ModePerm) + if err != nil { + log.Fatal().Err(err).Msg("failed to create profiling directory") + } + + defer profile.Start(profile.ProfilePath(profilePath)).Stop() + } else { + defer profile.Start().Stop() + } + } + var err error // Fetch an initial DERP Map before we start serving @@ -753,7 +798,8 @@ func (h *Headscale) Serve() error { Str("path", aclPath). Msg("ACL policy successfully reloaded, notifying nodes of change") - h.nodeNotifier.NotifyAll(types.StateUpdate{ + ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, }) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 9b44c2d8..3e9557a8 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -1,6 +1,7 @@ package hscontrol import ( + "context" "encoding/json" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -243,8 +245,6 @@ func (h *Headscale) handleRegister( // handleAuthKey contains the logic to manage auth key client registration // When using Noise, the machineKey is Zero. -// -// TODO: check if any locks are needed around IP allocation. func (h *Headscale) handleAuthKey( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, @@ -311,6 +311,9 @@ func (h *Headscale) handleAuthKey( nodeKey := registerRequest.NodeKey + var update types.StateUpdate + var mkey key.MachinePublic + // retrieve node information if it exist // The error is not important, because if it does not // exist, then this is a new node and we will move @@ -324,7 +327,7 @@ func (h *Headscale) handleAuthKey( node.NodeKey = nodeKey node.AuthKeyID = uint(pak.ID) - err := h.db.NodeSetExpiry(node, registerRequest.Expiry) + err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -335,10 +338,13 @@ func (h *Headscale) handleAuthKey( return } + mkey = node.MachineKey + update = types.StateUpdateExpire(node.ID, registerRequest.Expiry) + aclTags := pak.Proto().GetAclTags() if len(aclTags) > 0 { // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(node, aclTags) + err = h.db.SetTags(node.ID, aclTags) if err != nil { log.Error(). @@ -370,6 +376,7 @@ func (h *Headscale) handleAuthKey( Hostname: registerRequest.Hostinfo.Hostname, GivenName: givenName, UserID: pak.User.ID, + User: pak.User, MachineKey: machineKey, RegisterMethod: util.RegisterMethodAuthKey, Expiry: ®isterRequest.Expiry, @@ -393,9 +400,18 @@ func (h *Headscale) handleAuthKey( return } + + mkey = node.MachineKey + update = types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from auth.handleAuthKey", + } } - err = h.db.UsePreAuthKey(pak) + err = h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.UsePreAuthKey(tx, pak) + }) if err != nil { log.Error(). Caller(). @@ -437,6 +453,13 @@ func (h *Headscale) handleAuthKey( Caller(). Err(err). Msg("Failed to write response") + return + } + + // TODO(kradalby): if notifying after register make sense. + if update.Valid() { + ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String()) } log.Info(). @@ -502,7 +525,7 @@ func (h *Headscale) handleNodeLogOut( Msg("Client requested logout") now := time.Now() - err := h.db.NodeSetExpiry(&node, now) + err := h.db.NodeSetExpiry(node.ID, now) if err != nil { log.Error(). Caller(). @@ -513,17 +536,10 @@ func (h *Headscale) handleNodeLogOut( return } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &now, - }, - }, - } + stateUpdate := types.StateUpdateExpire(node.ID, now) if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } resp.AuthURL = "" @@ -554,7 +570,7 @@ func (h *Headscale) handleNodeLogOut( } if node.IsEphemeral() { - err = h.db.DeleteNode(&node) + err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) if err != nil { log.Error(). Err(err). @@ -562,6 +578,15 @@ func (h *Headscale) handleNodeLogOut( Msg("Cannot delete ephemeral node from the database") } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return } @@ -633,7 +658,9 @@ func (h *Headscale) handleNodeKeyRefresh( Str("node", node.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey) + err := h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) + }) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/db/addresses.go b/hscontrol/db/addresses.go index beccf843..58578702 100644 --- a/hscontrol/db/addresses.go +++ b/hscontrol/db/addresses.go @@ -13,16 +13,23 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" + "gorm.io/gorm" ) var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.NodeAddresses, error) { + return getAvailableIPs(rx, hsdb.ipPrefixes) + }) +} + +func getAvailableIPs(rx *gorm.DB, ipPrefixes []netip.Prefix) (types.NodeAddresses, error) { var ips types.NodeAddresses var err error - for _, ipPrefix := range hsdb.ipPrefixes { + for _, ipPrefix := range ipPrefixes { var ip *netip.Addr - ip, err = hsdb.getAvailableIP(ipPrefix) + ip, err = getAvailableIP(rx, ipPrefix) if err != nil { return ips, err } @@ -32,8 +39,8 @@ func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { return ips, err } -func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := hsdb.getUsedIPs() +func getAvailableIP(rx *gorm.DB, ipPrefix netip.Prefix) (*netip.Addr, error) { + usedIps, err := getUsedIPs(rx) if err != nil { return nil, err } @@ -64,12 +71,12 @@ func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, erro } } -func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { +func getUsedIPs(rx *gorm.DB) (*netipx.IPSet, error) { // FIXME: This really deserves a better data model, // but this was quick to get running and it should be enough // to begin experimenting with a dual stack tailnet. var addressesSlices []string - hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) + rx.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) var ips netipx.IPSetBuilder for _, slice := range addressesSlices { diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go index 07059eab..ef336590 100644 --- a/hscontrol/db/addresses_test.go +++ b/hscontrol/db/addresses_test.go @@ -7,10 +7,16 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := db.getAvailableIPs() + tx := db.DB.Begin() + defer tx.Rollback() + + ips, err := getAvailableIPs(tx, []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }) c.Assert(err, check.IsNil) @@ -30,7 +36,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -41,10 +47,13 @@ func (s *Suite) TestGetUsedIps(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - usedIps, err := db.getUsedIPs() + db.Write(func(tx *gorm.DB) error { + return tx.Save(&node).Error + }) + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected := netip.MustParseAddr("10.27.0.1") @@ -63,19 +72,23 @@ func (s *Suite) TestGetUsedIps(c *check.C) { } func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := db.CreateUser("test-ip-multi") + user, err := db.CreateUser("test-ip") c.Assert(err, check.IsNil) + ipPrefixes := []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + } + for index := 1; index <= 350; index++ { - db.ipAllocationMutex.Lock() + tx := db.DB.Begin() - ips, err := db.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := CreatePreAuthKey(tx, user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = getNode(tx, "test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -86,12 +99,13 @@ func (s *Suite) TestGetMultiIp(c *check.C) { AuthKeyID: uint(pak.ID), IPAddresses: ips, } - db.db.Save(&node) - - db.ipAllocationMutex.Unlock() + tx.Save(&node) + c.Assert(tx.Commit().Error, check.IsNil) } - usedIps, err := db.getUsedIPs() + usedIps, err := Read(db.DB, func(rx *gorm.DB) (*netipx.IPSet, error) { + return getUsedIPs(rx) + }) c.Assert(err, check.IsNil) expected0 := netip.MustParseAddr("10.27.0.1") @@ -162,7 +176,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) node := types.Node{ @@ -172,7 +186,7 @@ func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) ips2, err := db.getAvailableIPs() c.Assert(err, check.IsNil) diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index bc8dc2bb..51083145 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,9 +22,6 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -49,7 +46,7 @@ func (hsdb *HSDatabase) CreateAPIKey( Expiration: expiration, } - if err := hsdb.db.Save(&key).Error; err != nil { + if err := hsdb.DB.Save(&key).Error; err != nil { return "", nil, fmt.Errorf("failed to save API key to database: %w", err) } @@ -58,11 +55,8 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - keys := []types.APIKey{} - if err := hsdb.db.Find(&keys).Error; err != nil { + if err := hsdb.DB.Find(&keys).Error; err != nil { return nil, err } @@ -71,11 +65,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -84,11 +75,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,10 +86,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { + if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -110,10 +95,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -121,9 +103,6 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 030a6f0b..df7b0a4c 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -7,12 +7,10 @@ import ( "fmt" "net/netip" "strings" - "sync" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" - "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -36,12 +34,7 @@ type KV struct { } type HSDatabase struct { - db *gorm.DB - notifier *notifier.Notifier - - mu sync.RWMutex - - ipAllocationMutex sync.Mutex + DB *gorm.DB ipPrefixes []netip.Prefix baseDomain string @@ -52,7 +45,6 @@ type HSDatabase struct { func NewHeadscaleDatabase( dbType, connectionAddr string, debug bool, - notifier *notifier.Notifier, ipPrefixes []netip.Prefix, baseDomain string, ) (*HSDatabase, error) { @@ -147,7 +139,9 @@ func NewHeadscaleDatabase( DiscoKey string } var results []result - err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error + err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes"). + Find(&results). + Error if err != nil { return err } @@ -180,7 +174,8 @@ func NewHeadscaleDatabase( } if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { - log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...") + log.Info(). + Msgf("Database has legacy enabled_routes column in node, migrating...") type NodeAux struct { ID uint64 @@ -317,8 +312,7 @@ func NewHeadscaleDatabase( } db := HSDatabase{ - db: dbConn, - notifier: notifier, + DB: dbConn, ipPrefixes: ipPrefixes, baseDomain: baseDomain, @@ -376,7 +370,7 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - sqlDB, err := hsdb.db.DB() + sqlDB, err := hsdb.DB.DB() if err != nil { return err } @@ -385,10 +379,48 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error { } func (hsdb *HSDatabase) Close() error { - db, err := hsdb.db.DB() + db, err := hsdb.DB.DB() if err != nil { return err } return db.Close() } + +func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error { + rx := hsdb.DB.Begin() + defer rx.Rollback() + return fn(rx) +} + +func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { + rx := db.Begin() + defer rx.Rollback() + ret, err := fn(rx) + if err != nil { + var no T + return no, err + } + return ret, nil +} + +func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { + tx := hsdb.DB.Begin() + defer tx.Rollback() + if err := fn(tx); err != nil { + return err + } + + return tx.Commit().Error +} + +func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) { + tx := db.Begin() + defer tx.Rollback() + ret, err := fn(tx) + if err != nil { + var no T + return no, err + } + return ret, tx.Commit().Error +} diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index e2a82cc3..a7474299 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -34,22 +34,21 @@ var ( ) ) -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPeers(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListPeers(rx, node) + }) } -func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { +// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. +func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) { log.Trace(). Caller(). Str("node", node.Hostname). Msg("Finding direct peers") nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -64,16 +63,15 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodes() +func (hsdb *HSDatabase) ListNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) } -func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { - nodes := []types.Node{} - if err := hsdb.db. +func ListNodes(tx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -85,16 +83,9 @@ func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByGivenName(givenName) -} - -func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) { +func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -106,12 +97,15 @@ func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, err return nodes, nil } -// GetNode finds a Node by name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return getNode(rx, user, name) + }) +} - nodes, err := hsdb.ListNodesByUser(user) +// getNode finds a Node by name and user and returns the Node struct. +func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, user) if err != nil { return nil, err } @@ -125,34 +119,16 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { return nil, ErrNodeNotFound } -// GetNodeByGivenName finds a Node by given name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByGivenName( - user string, - givenName string, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if err := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).First(&node).Error; err != nil { - return nil, err - } - - return nil, ErrNodeNotFound +func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByID(rx, id) + }) } // GetNodeByID finds a Node by ID and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -164,21 +140,19 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { return &mach, nil } -// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByMachineKey( - machineKey key.MachinePublic, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeByMachineKey(machineKey) +func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByMachineKey(rx, machineKey) + }) } -func (hsdb *HSDatabase) getNodeByMachineKey( +// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. +func GetNodeByMachineKey( + tx *gorm.DB, machineKey key.MachinePublic, ) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -190,36 +164,24 @@ func (hsdb *HSDatabase) getNodeByMachineKey( return &mach, nil } -// GetNodeByNodeKey finds a Node by its current NodeKey. -func (hsdb *HSDatabase) GetNodeByNodeKey( +func (hsdb *HSDatabase) GetNodeByAnyKey( + machineKey key.MachinePublic, nodeKey key.NodePublic, + oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if result := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "node_key = ?", - nodeKey.String()); result.Error != nil { - return nil, result.Error - } - - return &node, nil + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) + }) } // GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByAnyKey( +// TODO(kradalby): see if we can remove this. +func GetNodeByAnyKey( + tx *gorm.DB, machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - node := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). @@ -234,60 +196,44 @@ func (hsdb *HSDatabase) GetNodeByAnyKey( return &node, nil } -func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - if result := hsdb.db.Find(node).First(&node); result.Error != nil { - return result.Error - } - - return nil +func (hsdb *HSDatabase) SetTags( + nodeID uint64, + tags []string, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return SetTags(tx, nodeID, tags) + }) } // SetTags takes a Node struct pointer and update the forced tags. -func (hsdb *HSDatabase) SetTags( - node *types.Node, +func SetTags( + tx *gorm.DB, + nodeID uint64, tags []string, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - if len(tags) == 0 { return nil } - newTags := []string{} + newTags := types.StringList{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } - if err := hsdb.db.Model(node).Updates(types.Node{ - ForcedTags: newTags, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil { return fmt.Errorf("failed to update tags for node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.SetTags", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } // RenameNode takes a Node struct and a new GivenName for the nodes // and renames it. -func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameNode(tx *gorm.DB, + nodeID uint64, newName string, +) error { err := util.CheckForFQDNRules( newName, ) @@ -295,129 +241,74 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { log.Error(). Caller(). Str("func", "RenameNode"). - Str("node", node.Hostname). + Uint64("nodeID", nodeID). Str("newName", newName). Err(err). Msg("failed to rename node") return err } - node.GivenName = newName - if err := hsdb.db.Model(node).Updates(types.Node{ - GivenName: newName, - }).Error; err != nil { + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.RenameNode", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } +func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetExpiry(tx, nodeID, expiry) + }) +} + // NodeSetExpiry takes a Node struct and a new expiry time. -func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.nodeSetExpiry(node, expiry) +func NodeSetExpiry(tx *gorm.DB, + nodeID uint64, expiry time.Time, +) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } -func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error { - if err := hsdb.db.Model(node).Updates(types.Node{ - Expiry: &expiry, - }).Error; err != nil { - return fmt.Errorf( - "failed to refresh node (update expiration) in the database: %w", - err, - ) - } - - node.Expiry = &expiry - - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &expiry, - }, - }, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - - return nil +func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DeleteNode(tx, node, isConnected) + }) } // DeleteNode deletes a Node from the database. -func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.deleteNode(node) -} - -func (hsdb *HSDatabase) deleteNode(node *types.Node) error { - err := hsdb.deleteNodeRoutes(node) +// Caller is responsible for notifying all of change. +func DeleteNode(tx *gorm.DB, + node *types.Node, + isConnected map[key.MachinePublic]bool, +) error { + err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{}) if err != nil { return err } // Unscoped causes the node to be fully removed from the database. - if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil { + if err := tx.Unscoped().Delete(&node).Error; err != nil { return err } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - return nil } // UpdateLastSeen sets a node's last seen field indicating that we // have recently communicating with this node. -// This is mostly used to indicate if a node is online and is not -// extremely important to make sure is fully correct and to avoid -// holding up the hot path, does not contain any locks and isnt -// concurrency safe. But that should be ok. -func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error { - return hsdb.db.Model(node).Updates(types.Node{ - LastSeen: node.LastSeen, - }).Error +func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error } -func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( +func RegisterNodeFromAuthCallback( + tx *gorm.DB, cache *cache.Cache, mkey key.MachinePublic, userName string, nodeExpiry *time.Time, registrationMethod string, + ipPrefixes []netip.Prefix, ) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - log.Debug(). Str("machine_key", mkey.ShortString()). Str("userName", userName). @@ -427,7 +318,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( if nodeInterface, ok := cache.Get(mkey.String()); ok { if registrationNode, ok := nodeInterface.(types.Node); ok { - user, err := hsdb.getUser(userName) + user, err := GetUser(tx, userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register node from auth callback, %w", @@ -442,14 +333,17 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( } registrationNode.UserID = user.ID + registrationNode.User = *user registrationNode.RegisterMethod = registrationMethod if nodeExpiry != nil { registrationNode.Expiry = nodeExpiry } - node, err := hsdb.registerNode( + node, err := RegisterNode( + tx, registrationNode, + ipPrefixes, ) if err == nil { @@ -465,15 +359,14 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( return nil, ErrNodeNotFoundRegistrationCache } -// RegisterNode is executed from the CLI to register a new Node using its MachineKey. func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.registerNode(node) + return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { + return RegisterNode(tx, node, hsdb.ipPrefixes) + }) } -func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { +// RegisterNode is executed from the CLI to register a new Node using its MachineKey. +func RegisterNode(tx *gorm.DB, node types.Node, ipPrefixes []netip.Prefix) (*types.Node, error) { log.Debug(). Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). @@ -485,7 +378,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache if len(node.IPAddresses) > 0 { - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register existing node in the database: %w", err) } @@ -500,10 +393,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { return &node, nil } - hsdb.ipAllocationMutex.Lock() - defer hsdb.ipAllocationMutex.Unlock() - - ips, err := hsdb.getAvailableIPs() + ips, err := getAvailableIPs(tx, ipPrefixes) if err != nil { log.Error(). Caller(). @@ -516,7 +406,7 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { node.IPAddresses = ips - if err := hsdb.db.Save(&node).Error; err != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } @@ -530,61 +420,50 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { } // NodeSetNodeKey sets the node key of a node and saves it to the database. -func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(node).Updates(types.Node{ +func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { + return tx.Model(node).Updates(types.Node{ NodeKey: nodeKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } -// NodeSetMachineKey sets the node key of a node and saves it to the database. func (hsdb *HSDatabase) NodeSetMachineKey( node *types.Node, machineKey key.MachinePublic, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetMachineKey(tx, node, machineKey) + }) +} - if err := hsdb.db.Model(node).Updates(types.Node{ +// NodeSetMachineKey sets the node key of a node and saves it to the database. +func NodeSetMachineKey( + tx *gorm.DB, + node *types.Node, + machineKey key.MachinePublic, +) error { + return tx.Model(node).Updates(types.Node{ MachineKey: machineKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } // NodeSave saves a node object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. -func (hsdb *HSDatabase) NodeSave(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +// TODO(kradalby): Remove this func, just use Save. +func NodeSave(tx *gorm.DB, node *types.Node) error { + return tx.Save(node).Error +} - if err := hsdb.db.Save(node).Error; err != nil { - return err - } - - return nil +func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetAdvertisedRoutes(rx, node) + }) } // GetAdvertisedRoutes returns the routes that are be advertised by the given node. -func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { +func GetAdvertisedRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -605,18 +484,17 @@ func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, e return prefixes, nil } -// GetEnabledRoutes returns the routes that are enabled for the node. func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getEnabledRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) ([]netip.Prefix, error) { + return GetEnabledRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { +// GetEnabledRoutes returns the routes that are enabled for the node. +func GetEnabledRoutes(tx *gorm.DB, node *types.Node) ([]netip.Prefix, error) { routes := types.Routes{} - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). Find(&routes).Error @@ -638,16 +516,13 @@ func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, erro return prefixes, nil } -func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func IsRoutesEnabled(tx *gorm.DB, node *types.Node, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := hsdb.getEnabledRoutes(node) + enabledRoutes, err := GetEnabledRoutes(tx, node) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -663,26 +538,37 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool return false } +func (hsdb *HSDatabase) enableRoutes( + node *types.Node, + routeStrs ...string, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return enableRoutes(tx, node, routeStrs...) + }) +} + // enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { +func enableRoutes(tx *gorm.DB, + node *types.Node, routeStrs ...string, +) (*types.StateUpdate, error) { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netip.ParsePrefix(routeStr) if err != nil { - return err + return nil, err } newRoutes[index] = route } - advertisedRoutes, err := hsdb.getAdvertisedRoutes(node) + advertisedRoutes, err := GetAdvertisedRoutes(tx, node) if err != nil { - return err + return nil, err } for _, newRoute := range newRoutes { if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { - return fmt.Errorf( + return nil, fmt.Errorf( "route (%s) is not available on node %s: %w", node.Hostname, newRoute, ErrNodeRouteIsNotAvailable, @@ -693,7 +579,7 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Separate loop so we don't leave things in a half-updated state for _, prefix := range newRoutes { route := types.Route{} - err := hsdb.db.Preload("Node"). + err := tx.Preload("Node"). Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). First(&route).Error if err == nil { @@ -702,23 +588,23 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro // Mark already as primary if there is only this node offering this subnet // (and is not an exit route) if !route.IsExitRoute() { - route.IsPrimary = hsdb.isUniquePrefix(route) + route.IsPrimary = isUniquePrefix(tx, route) } - err = hsdb.db.Save(&route).Error + err = tx.Save(&route).Error if err != nil { - return fmt.Errorf("failed to enable route: %w", err) + return nil, fmt.Errorf("failed to enable route: %w", err) } } else { - return fmt.Errorf("failed to find route: %w", err) + return nil, fmt.Errorf("failed to find route: %w", err) } } // Ensure the node has the latest routes when notifying the other // nodes - nRoutes, err := hsdb.getNodeRoutes(node) + nRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return fmt.Errorf("failed to read back routes: %w", err) + return nil, fmt.Errorf("failed to read back routes: %w", err) } node.Routes = nRoutes @@ -729,30 +615,11 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro Strs("routes", routeStrs). Msg("enabling routes") - stateUpdate := types.StateUpdate{ + return &types.StateUpdate{ Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, - Message: "called from db.enableRoutes", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore( - stateUpdate, node.MachineKey.String()) - } - - // Send an update to the node itself with to ensure it - // has an updated packetfilter allowing the new route - // if it is defined in the ACL. - selfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if selfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey( - selfUpdate, - node.MachineKey) - } - - return nil + Message: "created in db.enableRoutes", + }, nil } func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { @@ -785,16 +652,23 @@ func (hsdb *HSDatabase) GenerateGivenName( mkey key.MachinePublic, suppliedName string, ) (string, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() + return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { + return GenerateGivenName(rx, mkey, suppliedName) + }) +} +func GenerateGivenName( + tx *gorm.DB, + mkey key.MachinePublic, + suppliedName string, +) (string, error) { givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := hsdb.listNodesByGivenName(givenName) + nodes, err := listNodesByGivenName(tx, givenName) if err != nil { return "", err } @@ -818,29 +692,28 @@ func (hsdb *HSDatabase) GenerateGivenName( return givenName, nil } -func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - users, err := hsdb.listUsers() +func ExpireEphemeralNodes(tx *gorm.DB, + inactivityThreshhold time.Duration, +) (types.StateUpdate, bool) { + users, err := ListUsers(tx) if err != nil { log.Error().Err(err).Msg("Error listing users") - return + return types.StateUpdate{}, false } + expired := make([]tailcfg.NodeID, 0) for _, user := range users { - nodes, err := hsdb.listNodesByUser(user.Name) + nodes, err := ListNodesByUser(tx, user.Name) if err != nil { log.Error(). Err(err). Str("user", user.Name). Msg("Error listing nodes in user") - return + return types.StateUpdate{}, false } - expired := make([]tailcfg.NodeID, 0) for idx, node := range nodes { if node.IsEphemeral() && node.LastSeen != nil && time.Now(). @@ -851,7 +724,8 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) Str("node", node.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.deleteNode(nodes[idx]) + // empty isConnected map as ephemeral nodes are not routes + err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{}) if err != nil { log.Error(). Err(err). @@ -861,33 +735,35 @@ func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) } } - if len(expired) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: expired, - }) - } + // TODO(kradalby): needs to be moved out of transaction } + if len(expired) > 0 { + return types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: expired, + }, true + } + + return types.StateUpdate{}, false } -func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func ExpireExpiredNodes(tx *gorm.DB, + lastCheck time.Time, +) (time.Time, types.StateUpdate, bool) { // use the time of the start of the function to ensure we // dont miss some nodes by returning it _after_ we have // checked everything. started := time.Now() - expiredNodes := make([]*types.Node, 0) + expired := make([]*tailcfg.PeerChange, 0) - nodes, err := hsdb.listNodes() + nodes, err := ListNodes(tx) if err != nil { log.Error(). Err(err). Msg("Error listing nodes to find expired nodes") - return time.Unix(0, 0) + return time.Unix(0, 0), types.StateUpdate{}, false } for index, node := range nodes { if node.IsExpired() && @@ -895,13 +771,17 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // It will notify about all nodes that has been expired. // It should only notify about expired nodes since _last check_. node.Expiry.After(lastCheck) { - expiredNodes = append(expiredNodes, &nodes[index]) + expired = append(expired, &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: node.Expiry, + }) + now := time.Now() // Do not use setNodeExpiry as that has a notifier hook, which // can cause a deadlock, we are updating all changed nodes later // and there is no point in notifiying twice. - if err := hsdb.db.Model(&nodes[index]).Updates(types.Node{ - Expiry: &started, + if err := tx.Model(&nodes[index]).Updates(types.Node{ + Expiry: &now, }).Error; err != nil { log.Error(). Err(err). @@ -917,33 +797,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { } } - expired := make([]*tailcfg.PeerChange, len(expiredNodes)) - for idx, node := range expiredNodes { - expired[idx] = &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &started, - } + if len(expired) > 0 { + return started, types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: expired, + }, true } - // Inform the peers of a node with a lightweight update. - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - // Inform the node itself that it has expired. - for _, node := range expiredNodes { - stateSelfUpdate := types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: types.Nodes{node}, - } - if stateSelfUpdate.Valid() { - hsdb.notifier.NotifyByMachineKey(stateSelfUpdate, node.MachineKey) - } - } - - return started + return started, types.StateUpdate{}, false } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 140c264b..5e8eb294 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -23,7 +23,7 @@ func (s *Suite) TestGetNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -38,9 +38,9 @@ func (s *Suite) TestGetNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) } @@ -66,40 +66,12 @@ func (s *Suite) TestGetNodeByID(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByID(0) c.Assert(err, check.IsNil) } -func (s *Suite) TestGetNodeByNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -124,7 +96,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) c.Assert(err, check.IsNil) @@ -146,12 +118,12 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(1), } - db.db.Save(&node) + db.DB.Save(&node) - err = db.DeleteNode(&node) + err = db.DeleteNode(&node, map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) - _, err = db.GetNode(user.Name, "testnode3") + _, err = db.getNode(user.Name, "testnode3") c.Assert(err, check.NotNil) } @@ -178,7 +150,7 @@ func (s *Suite) TestListPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) } node0ByID, err := db.GetNodeByID(0) @@ -228,7 +200,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(stor[index%2].key.ID), } - db.db.Save(&node) + db.DB.Save(&node) } aclPolicy := &policy.ACLPolicy{ @@ -295,7 +267,7 @@ func (s *Suite) TestExpireNode(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -311,16 +283,19 @@ func (s *Suite) TestExpireNode(c *check.C) { AuthKeyID: uint(pak.ID), Expiry: &time.Time{}, } - db.db.Save(node) + db.DB.Save(node) - nodeFromDB, err := db.GetNode("test", "testnode") + nodeFromDB, err := db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, false) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB, now) + err = db.NodeSetExpiry(nodeFromDB.ID, now) + c.Assert(err, check.IsNil) + + nodeFromDB, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -354,7 +329,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("user-1", "testnode") + _, err = db.getNode("user-1", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -372,7 +347,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") @@ -397,7 +372,7 @@ func (s *Suite) TestSetTags(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "testnode") + _, err = db.getNode("test", "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -412,21 +387,21 @@ func (s *Suite) TestSetTags(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(node) + db.DB.Save(node) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = db.SetTags(node, sTags) + err = db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) // assign duplicat tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = db.SetTags(node, eTags) + err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") + node, err = db.getNode("test", "testnode") c.Assert(err, check.IsNil) c.Assert( node.ForcedTags, @@ -601,7 +576,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -610,7 +585,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { node0ByID, err := db.GetNodeByID(0) c.Assert(err, check.IsNil) - err = db.EnableAutoApprovedRoutes(pol, node0ByID) + // TODO(kradalby): Check state update + _, err = db.EnableAutoApprovedRoutes(pol, node0ByID) c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index e743988f..0fdb8221 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -20,7 +20,6 @@ var ( ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") ) -// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( userName string, reusable bool, @@ -28,11 +27,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { + return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) + }) +} - user, err := hsdb.GetUser(userName) +// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. +func CreatePreAuthKey( + tx *gorm.DB, + userName string, + reusable bool, + ephemeral bool, + expiration *time.Time, + aclTags []string, +) (*types.PreAuthKey, error) { + user, err := GetUser(tx, userName) if err != nil { return nil, err } @@ -48,7 +57,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey( } now := time.Now().UTC() - kstr, err := hsdb.generateKey() + kstr, err := generateKey() if err != nil { return nil, err } @@ -63,29 +72,25 @@ func (hsdb *HSDatabase) CreatePreAuthKey( Expiration: expiration, } - err = hsdb.db.Transaction(func(db *gorm.DB) error { - if err := db.Save(&key).Error; err != nil { - return fmt.Errorf("failed to create key in the database: %w", err) - } + if err := tx.Save(&key).Error; err != nil { + return nil, fmt.Errorf("failed to create key in the database: %w", err) + } - if len(aclTags) > 0 { - seenTags := map[string]bool{} + if len(aclTags) > 0 { + seenTags := map[string]bool{} - for _, tag := range aclTags { - if !seenTags[tag] { - if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { - return fmt.Errorf( - "failed to ceate key tag in the database: %w", - err, - ) - } - seenTags[tag] = true + for _, tag := range aclTags { + if !seenTags[tag] { + if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + return nil, fmt.Errorf( + "failed to ceate key tag in the database: %w", + err, + ) } + seenTags[tag] = true } } - - return nil - }) + } if err != nil { return nil, err @@ -94,22 +99,21 @@ func (hsdb *HSDatabase) CreatePreAuthKey( return &key, nil } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPreAuthKeys(userName) + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { + return ListPreAuthKeys(rx, userName) + }) } -func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - user, err := hsdb.getUser(userName) +// ListPreAuthKeys returns the list of PreAuthKeys for a user. +func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { + user, err := GetUser(tx, userName) if err != nil { return nil, err } keys := []types.PreAuthKey{} - if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -117,11 +121,8 @@ func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, er } // GetPreAuthKey returns a PreAuthKey for a given key. -func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - pak, err := hsdb.ValidatePreAuthKey(key) +func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) { + pak, err := ValidatePreAuthKey(tx, key) if err != nil { return nil, err } @@ -135,15 +136,8 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.destroyPreAuthKey(pak) -} - -func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { - return hsdb.db.Transaction(func(db *gorm.DB) error { +func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { + return tx.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error } @@ -156,12 +150,15 @@ func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return ExpirePreAuthKey(tx, k) + }) +} - if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -169,26 +166,26 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { } // UsePreAuthKey marks a PreAuthKey as used. -func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { k.Used = true - if err := hsdb.db.Save(k).Error; err != nil { + if err := tx.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } return nil } +func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { + return ValidatePreAuthKey(rx, k) + }) +} + // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. -func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { pak := types.PreAuthKey{} - if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -204,7 +201,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Where(&types.Node{AuthKeyID: uint(pak.ID)}). Find(&nodes).Error; err != nil { @@ -218,7 +215,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) return &pak, nil } -func (hsdb *HSDatabase) generateKey() (string, error) { +func generateKey() (string, error) { size := 24 bytes := make([]byte, size) if _, err := rand.Read(bytes); err != nil { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index df9c2a10..003a396f 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -6,6 +6,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" + "gorm.io/gorm" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -41,7 +42,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { user, err := db.CreateUser("test2") c.Assert(err, check.IsNil) - now := time.Now() + now := time.Now().Add(-5 * time.Second) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) c.Assert(err, check.IsNil) @@ -82,7 +83,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) @@ -103,7 +104,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) key, err := db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.IsNil) @@ -138,19 +139,22 @@ func (*Suite) TestEphemeralKey(c *check.C) { LastSeen: &now, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) _, err = db.ValidatePreAuthKey(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.IsNil) - db.ExpireEphemeralNodes(time.Second * 20) + db.DB.Transaction(func(tx *gorm.DB) error { + ExpireEphemeralNodes(tx, time.Second*20) + return nil + }) // The machine record should have been deleted - _, err = db.GetNode("test7", "testest") + _, err = db.getNode("test7", "testest") c.Assert(err, check.NotNil) } @@ -178,7 +182,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true - db.db.Save(&pak) + db.DB.Save(&pak) _, err = db.ValidatePreAuthKey(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 8ee91d6b..1ee144a7 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -7,23 +7,15 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" - "github.com/samber/lo" "gorm.io/gorm" "tailscale.com/types/key" ) var ErrRouteIsNotAvailable = errors.New("route is not available") -func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoutes() -} - -func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { +func GetRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Find(&routes).Error @@ -34,9 +26,9 @@ func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { +func getAdvertisedAndEnabledRoutes(tx *gorm.DB) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("advertised = ? AND enabled = ?", true, true). @@ -48,9 +40,9 @@ func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { +func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("prefix = ?", types.IPPrefix(pref)). @@ -62,16 +54,9 @@ func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, erro return routes, nil } -func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { +func GetNodeAdvertisedRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ? AND advertised = true", node.ID). @@ -84,15 +69,14 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, } func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeRoutes(node) + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodeRoutes(rx, node) + }) } -func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { +func GetNodeRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). Where("node_id = ?", node.ID). @@ -104,16 +88,9 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { return routes, nil } -func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoute(id) -} - -func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { +func GetRoute(tx *gorm.DB, id uint64) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Preload("Node.User"). First(&route, id).Error @@ -124,40 +101,34 @@ func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { return &route, nil } -func (hsdb *HSDatabase) EnableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.enableRoute(id) -} - -func (hsdb *HSDatabase) enableRoute(id uint64) error { - route, err := hsdb.getRoute(id) +func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if route.IsExitRoute() { - return hsdb.enableRoutes( + return enableRoutes( + tx, &route.Node, types.ExitRouteV4.String(), types.ExitRouteV6.String(), ) } - return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String()) + return enableRoutes(tx, &route.Node, netip.Prefix(route.Prefix).String()) } -func (hsdb *HSDatabase) DisableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - route, err := hsdb.getRoute(id) +func DisableRoute(tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -166,64 +137,79 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err = hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return err + return nil, err } route.Enabled = false route.IsPrimary = false - err = hsdb.db.Save(route).Error + err = tx.Save(route).Error if err != nil { - return err + return nil, err } } else { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } for i := range routes { if routes[i].IsExitRoute() { routes[i].Enabled = false routes[i].IsPrimary = false - err = hsdb.db.Save(&routes[i]).Error + err = tx.Save(&routes[i]).Error if err != nil { - return err + return nil, err } } } } if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DisableRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create + // one and return to the caller. + if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DisableRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +func (hsdb *HSDatabase) DeleteRoute( + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return DeleteRoute(tx, id, isConnected) + }) +} - route, err := hsdb.getRoute(id) +func DeleteRoute( + tx *gorm.DB, + id uint64, + isConnected map[key.MachinePublic]bool, +) (*types.StateUpdate, error) { + route, err := GetRoute(tx, id) if err != nil { - return err + return nil, err } var routes types.Routes @@ -232,19 +218,20 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 + var update *types.StateUpdate if !route.IsExitRoute() { - err := hsdb.failoverRouteWithNotify(route) + update, err = failoverRouteReturnUpdate(tx, isConnected, route) if err != nil { - return nil + return nil, nil } - if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&route).Error; err != nil { + return nil, err } } else { - routes, err := hsdb.getNodeRoutes(&node) + routes, err := GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } routesToDelete := types.Routes{} @@ -254,56 +241,59 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { } } - if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { - return err + if err := tx.Unscoped().Delete(&routesToDelete).Error; err != nil { + return nil, err } } + // If update is empty, it means that one was not created + // by failover (as a failover was not necessary), create + // one and return to the caller. if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) + routes, err = GetNodeRoutes(tx, &node) if err != nil { - return err + return nil, err } } node.Routes = routes - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DeleteRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + if update == nil { + update = &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{ + &node, + }, + Message: "called from db.DeleteRoute", + } } - return nil + return update, nil } -func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) +func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error { + routes, err := GetNodeRoutes(tx, node) if err != nil { return err } for i := range routes { - if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { + if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { return err } // TODO(kradalby): This is a bit too aggressive, we could probably // figure out which routes needs to be failed over rather than all. - hsdb.failoverRouteWithNotify(&routes[i]) + failoverRouteReturnUpdate(tx, isConnected, &routes[i]) } return nil } // isUniquePrefix returns if there is another node providing the same route already. -func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { +func isUniquePrefix(tx *gorm.DB, route types.Route) bool { var count int64 - hsdb.db. - Model(&types.Route{}). + tx.Model(&types.Route{}). Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", route.Prefix, route.NodeID, @@ -312,9 +302,9 @@ func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { return count == 0 } -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { +func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) { var route types.Route - err := hsdb.db. + err := tx. Preload("Node"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). First(&route).Error @@ -329,14 +319,17 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro return &route, nil } +func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetNodePrimaryRoutes(rx, node) + }) +} + // getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodePrimaryRoutes(tx *gorm.DB, node *types.Node) (types.Routes, error) { var routes types.Routes - err := hsdb.db. + err := tx. Preload("Node"). Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true). Find(&routes).Error @@ -347,22 +340,21 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er return routes, nil } +func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (bool, error) { + return SaveNodeRoutes(tx, node) + }) +} + // SaveNodeRoutes takes a node and updates the database with // the new routes. // It returns a bool whether an update should be sent as the // saved route impacts nodes. -func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.saveNodeRoutes(node) -} - -func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { +func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { sendUpdate := false currentRoutes := types.Routes{} - err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error + err := tx.Where("node_id = ?", node.ID).Find(¤tRoutes).Error if err != nil { return sendUpdate, err } @@ -382,7 +374,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { currentRoutes[pos].Advertised = true - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -398,7 +390,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { } else if route.Advertised { currentRoutes[pos].Advertised = false currentRoutes[pos].Enabled = false - err := hsdb.db.Save(¤tRoutes[pos]).Error + err := tx.Save(¤tRoutes[pos]).Error if err != nil { return sendUpdate, err } @@ -413,7 +405,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { Advertised: true, Enabled: false, } - err := hsdb.db.Create(&route).Error + err := tx.Create(&route).Error if err != nil { return sendUpdate, err } @@ -425,127 +417,89 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { // EnsureFailoverRouteIsAvailable takes a node and checks if the node's route // currently have a functioning host that exposes the network. -func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { - nodeRoutes, err := hsdb.getNodeRoutes(node) +func EnsureFailoverRouteIsAvailable( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + node *types.Node, +) (*types.StateUpdate, error) { + nodeRoutes, err := GetNodeRoutes(tx, node) if err != nil { - return nil + return nil, nil } + var changedNodes types.Nodes for _, nodeRoute := range nodeRoutes { - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) if err != nil { - return err + return nil, err } for _, route := range routes { if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if isConnected[route.Node.MachineKey] { continue } // if not, we need to failover the route - err := hsdb.failoverRouteWithNotify(&route) + update, err := failoverRouteReturnUpdate(tx, isConnected, &route) if err != nil { - return err + return nil, err + } + + if update != nil { + changedNodes = append(changedNodes, update.ChangeNodes...) } } } } - return nil -} - -func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) - if err != nil { - return nil - } - - var changedKeys []key.MachinePublic - - for _, route := range routes { - changed, err := hsdb.failoverRoute(&route) - if err != nil { - return err - } - - changedKeys = append(changedKeys, changed...) - } - - changedKeys = lo.Uniq(changedKeys) - - var nodes types.Nodes - - for _, key := range changedKeys { - node, err := hsdb.GetNodeByMachineKey(key) - if err != nil { - return err - } - - nodes = append(nodes, node) - } - - if nodes != nil { - stateUpdate := types.StateUpdate{ + if len(changedNodes) != 0 { + return &types.StateUpdate{ Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.FailoverNodeRoutesWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } + ChangeNodes: changedNodes, + Message: "called from db.EnsureFailoverRouteIsAvailable", + }, nil } - return nil + return nil, nil } -func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { - changedKeys, err := hsdb.failoverRoute(r) +func failoverRouteReturnUpdate( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) (*types.StateUpdate, error) { + changedKeys, err := failoverRoute(tx, isConnected, r) if err != nil { - return err + return nil, err } + log.Trace(). + Interface("isConnected", isConnected). + Interface("changedKeys", changedKeys). + Msg("building route failover") + if len(changedKeys) == 0 { - return nil + return nil, nil } var nodes types.Nodes - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("loading machines with new primary routes from db") - for _, key := range changedKeys { - node, err := hsdb.getNodeByMachineKey(key) + node, err := GetNodeByMachineKey(tx, key) if err != nil { - return err + return nil, err } nodes = append(nodes, node) } - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notifying peers about primary route change") - - if nodes != nil { - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.failoverRouteWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - } - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notified peers about primary route change") - - return nil + return &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: nodes, + Message: "called from db.failoverRouteReturnUpdate", + }, nil } // failoverRoute takes a route that is no longer available, @@ -556,12 +510,16 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { // // and tries to find a new route to take over its place. // If the given route was not primary, it returns early. -func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { +func failoverRoute( + tx *gorm.DB, + isConnected map[key.MachinePublic]bool, + r *types.Route, +) ([]key.MachinePublic, error) { if r == nil { return nil, nil } - // This route is not a primary route, and it isnt + // This route is not a primary route, and it is not // being served to nodes. if !r.IsPrimary { return nil, nil @@ -572,7 +530,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return nil, nil } - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) + routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix)) if err != nil { return nil, err } @@ -589,14 +547,14 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro continue } - if hsdb.notifier.IsConnected(route.Node.MachineKey) { + if isConnected[route.Node.MachineKey] { newPrimary = &routes[idx] break } } // If a new route was not found/available, - // return with an error. + // return without an error. // We do not want to update the database as // the one currently marked as primary is the // best we got. @@ -610,7 +568,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Remove primary from the old route r.IsPrimary = false - err = hsdb.db.Save(&r).Error + err = tx.Save(&r).Error if err != nil { log.Error().Err(err).Msg("error disabling new primary route") @@ -623,7 +581,7 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro // Set primary for the new primary newPrimary.IsPrimary = true - err = hsdb.db.Save(&newPrimary).Error + err = tx.Save(&newPrimary).Error if err != nil { log.Error().Err(err).Msg("error enabling new primary route") @@ -638,25 +596,26 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil } -// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, node *types.Node, -) error { - if len(aclPolicy.AutoApprovers.ExitNode) == 0 && len(aclPolicy.AutoApprovers.Routes) == 0 { - // No autoapprovers configured - return nil - } +) (*types.StateUpdate, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return EnableAutoApprovedRoutes(tx, aclPolicy, node) + }) +} +// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. +func EnableAutoApprovedRoutes( + tx *gorm.DB, + aclPolicy *policy.ACLPolicy, + node *types.Node, +) (*types.StateUpdate, error) { if len(node.IPAddresses) == 0 { - // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs - return nil + return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs } - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - routes, err := hsdb.getNodeAdvertisedRoutes(node) + routes, err := GetNodeAdvertisedRoutes(tx, node) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -664,7 +623,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("node", node.Hostname). Msg("Could not get advertised routes for node") - return err + return nil, err } log.Trace().Interface("routes", routes).Msg("routes for autoapproving") @@ -685,7 +644,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Uint64("nodeId", node.ID). Msg("Failed to resolve autoApprovers for advertised route") - return err + return nil, err } log.Trace(). @@ -706,7 +665,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( Str("alias", approvedAlias). Msg("Failed to expand alias when processing autoApprovers policy") - return err + return nil, err } // approvedIPs should contain all of node's IPs if it matches the rule, so check for first @@ -717,17 +676,25 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } } + update := &types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{}, + Message: "created in db.EnableAutoApprovedRoutes", + } + for _, approvedRoute := range approvedRoutes { - err := hsdb.enableRoute(uint64(approvedRoute.ID)) + perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). Uint64("nodeId", node.ID). Msg("Failed to enable approved route") - return err + return nil, err } + + update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...) } - return nil + return update, nil } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 15456071..3b544aa7 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" @@ -24,7 +23,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_get_route_node") + _, err = db.getNode("test", "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -42,7 +41,7 @@ func (s *Suite) TestGetRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) su, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -52,10 +51,11 @@ func (s *Suite) TestGetRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(advertisedRoutes), check.Equals, 1) - err = db.enableRoutes(&node, "192.168.0.0/24") + // TODO(kradalby): check state update + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) } @@ -66,7 +66,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -91,7 +91,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo, } - db.db.Save(&node) + db.DB.Save(&node) sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) @@ -106,10 +106,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = db.enableRoutes(&node, "192.168.0.0/24") + _, err = db.enableRoutes(&node, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enabledRoutes, err := db.GetEnabledRoutes(&node) @@ -117,14 +117,14 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = db.enableRoutes(&node, "10.0.0.0/24") + _, err = db.enableRoutes(&node, "10.0.0.0/24") c.Assert(err, check.IsNil) enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = db.enableRoutes(&node, "150.0.10.0/25") + _, err = db.enableRoutes(&node, "150.0.10.0/25") c.Assert(err, check.IsNil) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node) @@ -139,7 +139,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -163,16 +163,16 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo1, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, route.String()) + _, err = db.enableRoutes(&node1, route.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, route2.String()) + _, err = db.enableRoutes(&node1, route2.String()) c.Assert(err, check.IsNil) hostInfo2 := tailcfg.Hostinfo{ @@ -186,13 +186,13 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { AuthKeyID: uint(pak.ID), Hostinfo: &hostInfo2, } - db.db.Save(&node2) + db.DB.Save(&node2) sendUpdate, err = db.SaveNodeRoutes(&node2) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node2, route2.String()) + _, err = db.enableRoutes(&node2, route2.String()) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -219,7 +219,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.GetNode("test", "test_enable_route_node") + _, err = db.getNode("test", "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( @@ -246,22 +246,23 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { Hostinfo: &hostInfo1, LastSeen: &now, } - db.db.Save(&node1) + db.DB.Save(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) c.Assert(sendUpdate, check.Equals, false) - err = db.enableRoutes(&node1, prefix.String()) + _, err = db.enableRoutes(&node1, prefix.String()) c.Assert(err, check.IsNil) - err = db.enableRoutes(&node1, prefix2.String()) + _, err = db.enableRoutes(&node1, prefix2.String()) c.Assert(err, check.IsNil) routes, err := db.GetNodeRoutes(&node1) c.Assert(err, check.IsNil) - err = db.DeleteRoute(uint64(routes[0].ID)) + // TODO(kradalby): check stateupdate + _, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{}) c.Assert(err, check.IsNil) enabledRoutes1, err := db.GetEnabledRoutes(&node1) @@ -269,17 +270,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { c.Assert(len(enabledRoutes1), check.Equals, 1) } +var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } + func TestFailoverRoute(t *testing.T) { - ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Count/verify updates - var sink chan types.StateUpdate - - go func() { - for range sink { - } - }() - machineKeys := []key.MachinePublic{ key.NewMachine().Public(), key.NewMachine().Public(), @@ -291,6 +284,7 @@ func TestFailoverRoute(t *testing.T) { name string failingRoute types.Route routes types.Routes + isConnected map[key.MachinePublic]bool want []key.MachinePublic wantErr bool }{ @@ -397,6 +391,10 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], @@ -491,6 +489,11 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[1]: true, + machineKeys[2]: true, + }, want: []key.MachinePublic{ machineKeys[1], machineKeys[0], @@ -535,6 +538,10 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: true, + machineKeys[3]: false, + }, want: nil, wantErr: false, }, @@ -587,6 +594,11 @@ func TestFailoverRoute(t *testing.T) { Enabled: true, }, }, + isConnected: map[key.MachinePublic]bool{ + machineKeys[0]: false, + machineKeys[1]: true, + machineKeys[3]: false, + }, want: []key.MachinePublic{ machineKeys[0], machineKeys[1], @@ -641,13 +653,10 @@ func TestFailoverRoute(t *testing.T) { tmpDir, err := os.MkdirTemp("", "failover-db-test") assert.NoError(t, err) - notif := notifier.NewNotifier() - db, err = NewHeadscaleDatabase( "sqlite3", tmpDir+"/headscale_test.db", false, - notif, []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, @@ -655,23 +664,15 @@ func TestFailoverRoute(t *testing.T) { ) assert.NoError(t, err) - // Pretend that all the nodes are connected to control - for idx, key := range machineKeys { - // Pretend one node is offline - if idx == 3 { - continue - } - - notif.AddNode(key, sink) - } - for _, route := range tt.routes { - if err := db.db.Save(&route).Error; err != nil { + if err := db.DB.Save(&route).Error; err != nil { t.Fatalf("failed to create route: %s", err) } } - got, err := db.failoverRoute(&tt.failingRoute) + got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) { + return failoverRoute(tx, tt.isConnected, &tt.failingRoute) + }) if (err != nil) != tt.wantErr { t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) @@ -685,3 +686,231 @@ func TestFailoverRoute(t *testing.T) { }) } } + +// func TestDisableRouteFailover(t *testing.T) { +// machineKeys := []key.MachinePublic{ +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// key.NewMachine().Public(), +// } + +// tests := []struct { +// name string +// nodes types.Nodes + +// routeID uint64 +// isConnected map[key.MachinePublic]bool + +// wantMachineKey key.MachinePublic +// wantErr string +// }{ +// { +// name: "single-route", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// Node: types.Node{ +// MachineKey: machineKeys[0], +// }, +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[0], +// }, +// { +// name: "failover-simple", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "no-failover-offline", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: false, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// { +// name: "failover-to-online", +// nodes: types.Nodes{ +// &types.Node{ +// ID: 0, +// MachineKey: machineKeys[0], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 1, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: true, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// &types.Node{ +// ID: 1, +// MachineKey: machineKeys[1], +// Routes: []types.Route{ +// { +// Model: gorm.Model{ +// ID: 2, +// }, +// Prefix: ipp("10.0.0.0/24"), +// IsPrimary: false, +// }, +// }, +// Hostinfo: &tailcfg.Hostinfo{ +// RoutableIPs: []netip.Prefix{ +// netip.MustParsePrefix("10.0.0.0/24"), +// }, +// }, +// }, +// }, +// isConnected: map[key.MachinePublic]bool{ +// machineKeys[0]: true, +// machineKeys[1]: true, +// }, +// routeID: 1, +// wantMachineKey: machineKeys[1], +// }, +// } + +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "") +// assert.NoError(t, err) + +// // bootstrap db +// datab.DB.Transaction(func(tx *gorm.DB) error { +// for _, node := range tt.nodes { +// err := tx.Save(node).Error +// if err != nil { +// return err +// } + +// _, err = SaveNodeRoutes(tx, node) +// if err != nil { +// return err +// } +// } + +// return nil +// }) + +// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { +// return DisableRoute(tx, tt.routeID, tt.isConnected) +// }) + +// // if (err.Error() != "") != tt.wantErr { +// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) + +// // return +// // } + +// if len(got.ChangeNodes) != 1 { +// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes)) +// } + +// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" { +// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff) +// } +// }) +// } +// } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c384918..d4b11b14 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -6,7 +6,6 @@ import ( "os" "testing" - "github.com/juanfont/headscale/hscontrol/notifier" "gopkg.in/check.v1" ) @@ -48,7 +47,6 @@ func (s *Suite) ResetDB(c *check.C) { "sqlite3", tmpDir+"/headscale_test.db", false, - notifier.NewNotifier(), []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), }, diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 27a1406b..99e93393 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,22 +15,25 @@ var ( ErrUserStillHasNodes = errors.New("user not empty: node(s) found") ) +func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) { + return CreateUser(tx, name) + }) +} + // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func CreateUser(tx *gorm.DB, name string) (*types.User, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } user := types.User{} - if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { + if err := tx.Where("name = ?", name).First(&user).Error; err == nil { return nil, ErrUserExists } user.Name = name - if err := hsdb.db.Create(&user).Error; err != nil { + if err := tx.Create(&user).Error; err != nil { log.Error(). Str("func", "CreateUser"). Err(err). @@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { return &user, nil } +func (hsdb *HSDatabase) DestroyUser(name string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DestroyUser(tx, name) + }) +} + // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. -func (hsdb *HSDatabase) DestroyUser(name string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - user, err := hsdb.getUser(name) +func DestroyUser(tx *gorm.DB, name string) error { + user, err := GetUser(tx, name) if err != nil { return ErrUserNotFound } - nodes, err := hsdb.listNodesByUser(name) + nodes, err := ListNodesByUser(tx, name) if err != nil { return err } @@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.listPreAuthKeys(name) + keys, err := ListPreAuthKeys(tx, name) if err != nil { return err } for _, key := range keys { - err = hsdb.destroyPreAuthKey(key) + err = DestroyPreAuthKey(tx, key) if err != nil { return err } } - if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { + if result := tx.Unscoped().Delete(&user); result.Error != nil { return result.Error } return nil } +func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return RenameUser(tx, oldName, newName) + }) +} + // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameUser(tx *gorm.DB, oldName, newName string) error { var err error - oldUser, err := hsdb.getUser(oldName) + oldUser, err := GetUser(tx, oldName) if err != nil { return err } @@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.getUser(newName) + _, err = GetUser(tx, newName) if err == nil { return ErrUserExists } @@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { oldUser.Name = newName - if result := hsdb.db.Save(&oldUser); result.Error != nil { + if result := tx.Save(&oldUser); result.Error != nil { return result.Error } return nil } -// GetUser fetches a user by name. func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getUser(name) + return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUser(rx, name) + }) } -func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { +func GetUser(tx *gorm.DB, name string) (*types.User, error) { user := types.User{} - if result := hsdb.db.First(&user, "name = ?", name); errors.Is( + if result := tx.First(&user, "name = ?", name); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { return &user, nil } -// ListUsers gets all the existing users. func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listUsers() + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { + return ListUsers(rx) + }) } -func (hsdb *HSDatabase) listUsers() ([]types.User, error) { +// ListUsers gets all the existing users. +func ListUsers(tx *gorm.DB) ([]types.User, error) { users := []types.User{} - if err := hsdb.db.Find(&users).Error; err != nil { + if err := tx.Find(&users).Error; err != nil { return nil, err } @@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) { } // ListNodesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByUser(name) -} - -func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) { +func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.getUser(name) + user, err := GetUser(tx, name) if err != nil { return nil, err } nodes := types.Nodes{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -// AssignNodeToUser assigns a Node to a user. func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return AssignNodeToUser(tx, node, username) + }) +} +// AssignNodeToUser assigns a Node to a user. +func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.getUser(username) + user, err := GetUser(tx, username) if err != nil { return err } node.User = *user - if result := hsdb.db.Save(&node); result.Error != nil { + if result := tx.Save(&node); result.Error != nil { return result.Error } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ca3b49f..b36e8613 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -36,7 +36,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { err = db.DestroyUser("test") c.Assert(err, check.IsNil) - result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) + result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) // destroying a user also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) @@ -53,7 +53,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) err = db.DestroyUser("test") c.Assert(err, check.Equals, ErrUserStillHasNodes) @@ -105,7 +105,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { RegisterMethod: util.RegisterMethodAuthKey, AuthKeyID: uint(pak.ID), } - db.db.Save(&node) + db.DB.Save(&node) c.Assert(node.UserID, check.Equals, oldUser.ID) err = db.AssignNodeToUser(&node, newUser.Name) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index ad325c7a..52a63e9f 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -211,7 +211,7 @@ func DERPProbeHandler( // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // An example implementation is found here https://derp.tailscale.com/bootstrap-dns -// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL +// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL. func DERPBootstrapDNSHandler( derpMap *tailcfg.DERPMap, ) func(http.ResponseWriter, *http.Request) { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index ffd3a576..c12ba73d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -8,11 +8,13 @@ import ( "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -136,12 +138,14 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) - if err != nil { - return nil, err - } + err := api.h.db.DB.Transaction(func(tx *gorm.DB) error { + preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) + if err != nil { + return err + } - err = api.h.db.ExpirePreAuthKey(preAuthKey) + return db.ExpirePreAuthKey(tx, preAuthKey) + }) if err != nil { return nil, err } @@ -181,17 +185,31 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } - node, err := api.h.db.RegisterNodeFromAuthCallback( - api.h.registrationCache, - mkey, - request.GetUser(), - nil, - util.RegisterMethodCLI, - ) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + return db.RegisterNodeFromAuthCallback( + tx, + api.h.registrationCache, + mkey, + request.GetUser(), + nil, + util.RegisterMethodCLI, + api.h.cfg.IPPrefixes, + ) + }) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RegisterNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -217,25 +235,35 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err - } - for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { - return &v1.SetTagsResponse{ - Node: nil, - }, status.Error(codes.InvalidArgument, err.Error()) + return nil, err } } - err = api.h.db.SetTags(node, request.GetTags()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.SetTags(tx, request.GetNodeId(), request.GetTags()) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return &v1.SetTagsResponse{ Node: nil, - }, status.Error(codes.Internal, err.Error()) + }, status.Error(codes.InvalidArgument, err.Error()) + } + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.SetTags", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). @@ -270,11 +298,21 @@ func (api headscaleV1APIServer) DeleteNode( err = api.h.db.DeleteNode( node, + api.h.nodeNotifier.ConnectedMap(), ) if err != nil { return nil, err } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) + api.h.nodeNotifier.NotifyAll(ctx, stateUpdate) + } + return &v1.DeleteNodeResponse{}, nil } @@ -282,17 +320,38 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + now := time.Now() + + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + db.NodeSetExpiry( + tx, + request.GetNodeId(), + now, + ) + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - now := time.Now() + selfUpdate := types.StateUpdate{ + Type: types.StateSelfUpdate, + ChangeNodes: types.Nodes{node}, + } + if selfUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) + api.h.nodeNotifier.NotifyByMachineKey( + ctx, + selfUpdate, + node.MachineKey) + } - api.h.db.NodeSetExpiry( - node, - now, - ) + stateUpdate := types.StateUpdateExpire(node.ID, now) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } log.Trace(). Str("node", node.Hostname). @@ -306,17 +365,30 @@ func (api headscaleV1APIServer) RenameNode( ctx context.Context, request *v1.RenameNodeRequest, ) (*v1.RenameNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + err := db.RenameNode( + tx, + request.GetNodeId(), + request.GetNewName(), + ) + if err != nil { + return nil, err + } + + return db.GetNodeByID(tx, request.GetNodeId()) + }) if err != nil { return nil, err } - err = api.h.db.RenameNode( - node, - request.GetNewName(), - ) - if err != nil { - return nil, err + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from api.RenameNode", + } + if stateUpdate.Valid() { + ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname) + api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) } log.Trace(). @@ -331,8 +403,11 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + isConnected := api.h.nodeNotifier.ConnectedMap() if request.GetUser() != "" { - nodes, err := api.h.db.ListNodesByUser(request.GetUser()) + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { + return db.ListNodesByUser(rx, request.GetUser()) + }) if err != nil { return nil, err } @@ -343,7 +418,7 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] response[index] = resp } @@ -362,10 +437,10 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + resp.Online = isConnected[node.MachineKey] validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - &node, + node, ) resp.InvalidTags = invalidTags resp.ValidTags = validTags @@ -396,7 +471,9 @@ func (api headscaleV1APIServer) GetRoutes( ctx context.Context, request *v1.GetRoutesRequest, ) (*v1.GetRoutesResponse, error) { - routes, err := api.h.db.GetRoutes() + routes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Routes, error) { + return db.GetRoutes(rx) + }) if err != nil { return nil, err } @@ -410,11 +487,19 @@ func (api headscaleV1APIServer) EnableRoute( ctx context.Context, request *v1.EnableRouteRequest, ) (*v1.EnableRouteResponse, error) { - err := api.h.db.EnableRoute(request.GetRouteId()) + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnableRoute(tx, request.GetRouteId()) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown") + api.h.nodeNotifier.NotifyAll( + ctx, *update) + } + return &v1.EnableRouteResponse{}, nil } @@ -422,11 +507,19 @@ func (api headscaleV1APIServer) DisableRoute( ctx context.Context, request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { - err := api.h.db.DisableRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DisableRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") + api.h.nodeNotifier.NotifyAll(ctx, *update) + } + return &v1.DisableRouteResponse{}, nil } @@ -453,11 +546,19 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - err := api.h.db.DeleteRoute(request.GetRouteId()) + isConnected := api.h.nodeNotifier.ConnectedMap() + update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.DeleteRoute(tx, request.GetRouteId(), isConnected) + }) if err != nil { return nil, err } + if update != nil && update.Valid() { + ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") + api.h.nodeNotifier.NotifyWithIgnore(ctx, *update) + } + return &v1.DeleteRouteResponse{}, nil } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 9998f128..df0f4d9c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -272,6 +272,7 @@ func (m *Mapper) LiteMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, pol *policy.ACLPolicy, + messages ...string, ) ([]byte, error) { resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) if err != nil { @@ -290,7 +291,7 @@ func (m *Mapper) LiteMapResponse( resp.PacketFilter = policy.ReduceFilterRules(node, rules) resp.SSHPolicy = sshPolicy - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) + return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) } func (m *Mapper) KeepAliveResponse( @@ -392,9 +393,7 @@ func (m *Mapper) PeerChangedPatchResponse( } if patches, ok := m.patches[uint64(change.NodeID)]; ok { - patches := append(patches, p) - - m.patches[uint64(change.NodeID)] = patches + m.patches[uint64(change.NodeID)] = append(patches, p) } else { m.patches[uint64(change.NodeID)] = []patch{p} } @@ -470,6 +469,8 @@ func (m *Mapper) marshalMapResponse( switch { case resp.Peers != nil && len(resp.Peers) > 0: responseType = "full" + case isSelfUpdate(messages...): + responseType = "self" case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil: responseType = "lite" case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: @@ -668,3 +669,13 @@ func appendPeerChanges( return nil } + +func isSelfUpdate(messages ...string) bool { + for _, message := range messages { + if strings.Contains(message, types.SelfUpdateIdentifier) { + return true + } + } + + return false +} diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index e213a951..c10da4de 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -72,7 +72,7 @@ func tailNode( } var derp string - if node.Hostinfo.NetInfo != nil { + if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil { derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) } else { derp = "127.3.3.40:0" // Zero means disconnected or unknown. diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 77e8b199..2384a40f 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -1,6 +1,7 @@ package notifier import ( + "context" "fmt" "strings" "sync" @@ -12,26 +13,30 @@ import ( ) type Notifier struct { - l sync.RWMutex - nodes map[string]chan<- types.StateUpdate + l sync.RWMutex + nodes map[string]chan<- types.StateUpdate + connected map[key.MachinePublic]bool } func NewNotifier() *Notifier { - return &Notifier{} + return &Notifier{ + nodes: make(map[string]chan<- types.StateUpdate), + connected: make(map[key.MachinePublic]bool), + } } func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to add node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { - n.nodes = make(map[string]chan<- types.StateUpdate) - } - n.nodes[machineKey.String()] = c + n.connected[machineKey] = true log.Trace(). Str("machine_key", machineKey.ShortString()). @@ -41,16 +46,20 @@ func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpd func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node") + defer log.Trace(). + Caller(). + Str("key", machineKey.ShortString()). + Msg("releasing lock to remove node") n.l.Lock() defer n.l.Unlock() - if n.nodes == nil { + if len(n.nodes) == 0 { return } delete(n.nodes, machineKey.String()) + n.connected[machineKey] = false log.Trace(). Str("machine_key", machineKey.ShortString()). @@ -64,23 +73,28 @@ func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { n.l.RLock() defer n.l.RUnlock() - if _, ok := n.nodes[machineKey.String()]; ok { - return true - } - - return false + return n.connected[machineKey] } -func (n *Notifier) NotifyAll(update types.StateUpdate) { - n.NotifyWithIgnore(update) +// TODO(kradalby): This returns a pointer and can be dangerous. +func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool { + return n.connected } -func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { +func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) { + n.NotifyWithIgnore(ctx, update) +} + +func (n *Notifier) NotifyWithIgnore( + ctx context.Context, + update types.StateUpdate, + ignore ...string, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() @@ -90,23 +104,58 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) continue } - log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", key). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } -func (n *Notifier) NotifyByMachineKey(update types.StateUpdate, mKey key.MachinePublic) { +func (n *Notifier) NotifyByMachineKey( + ctx context.Context, + update types.StateUpdate, + mKey key.MachinePublic, +) { log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") defer log.Trace(). Caller(). Interface("type", update.Type). - Msg("releasing lock, finished notifing") + Msg("releasing lock, finished notifying") n.l.RLock() defer n.l.RUnlock() if c, ok := n.nodes[mKey.String()]; ok { - c <- update + select { + case <-ctx.Done(): + log.Error(). + Err(ctx.Err()). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update not sent, context cancelled") + + return + case c <- update: + log.Trace(). + Str("mkey", mKey.String()). + Any("origin", ctx.Value("origin")). + Any("hostname", ctx.Value("hostname")). + Msgf("update successfully sent on chan") + } } } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 568519fd..a0fc9310 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -20,6 +20,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" + "gorm.io/gorm" "tailscale.com/types/key" ) @@ -492,7 +493,7 @@ func (h *Headscale) validateNodeForOIDCCallback( Str("node", node.Hostname). Msg("node already registered, reauthenticating") - err := h.db.NodeSetExpiry(node, expiry) + err := h.db.NodeSetExpiry(node.ID, expiry) if err != nil { util.LogErr(err, "Failed to refresh node") http.Error( @@ -536,6 +537,12 @@ func (h *Headscale) validateNodeForOIDCCallback( util.LogErr(err, "Failed to write response") } + stateUpdate := types.StateUpdateExpire(node.ID, expiry) + if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String()) + } + return nil, true, nil } @@ -613,14 +620,22 @@ func (h *Headscale) registerNodeForOIDCCallback( machineKey *key.MachinePublic, expiry time.Time, ) error { - if _, err := h.db.RegisterNodeFromAuthCallback( - // TODO(kradalby): find a better way to use the cache across modules - h.registrationCache, - *machineKey, - user.Name, - &expiry, - util.RegisterMethodOIDC, - ); err != nil { + if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if _, err := db.RegisterNodeFromAuthCallback( + // TODO(kradalby): find a better way to use the cache across modules + tx, + h.registrationCache, + *machineKey, + user.Name, + &expiry, + util.RegisterMethodOIDC, + h.cfg.IPPrefixes, + ); err != nil { + return err + } + + return nil + }); err != nil { util.LogErr(err, "could not register node") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 1dd664c8..2ccc56b4 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -905,32 +905,39 @@ func (pol *ACLPolicy) TagsOfNode( validTags := make([]string, 0) invalidTags := make([]string, 0) + // TODO(kradalby): Why is this sometimes nil? coming from tailNode? + if node == nil { + return validTags, invalidTags + } + validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) - for _, tag := range node.Hostinfo.RequestTags { - owners, err := expandOwnersFromTag(pol, tag) - if errors.Is(err, ErrInvalidTag) { - invalidTagMap[tag] = true + if node.Hostinfo != nil { + for _, tag := range node.Hostinfo.RequestTags { + owners, err := expandOwnersFromTag(pol, tag) + if errors.Is(err, ErrInvalidTag) { + invalidTagMap[tag] = true - continue - } - var found bool - for _, owner := range owners { - if node.User.Name == owner { - found = true + continue + } + var found bool + for _, owner := range owners { + if node.User.Name == owner { + found = true + } + } + if found { + validTagMap[tag] = true + } else { + invalidTagMap[tag] = true } } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true + for tag := range invalidTagMap { + invalidTags = append(invalidTags, tag) + } + for tag := range validTagMap { + validTags = append(validTags, tag) } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) } return validTags, invalidTags diff --git a/hscontrol/poll.go b/hscontrol/poll.go index c867f261..f00152d4 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -4,12 +4,15 @@ import ( "context" "fmt" "net/http" + "strings" "time" + "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" xslices "golang.org/x/exp/slices" + "gorm.io/gorm" "tailscale.com/tailcfg" ) @@ -128,10 +131,14 @@ func (h *Headscale) handlePoll( if h.ACLPolicy != nil { // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + update, err := h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) if err != nil { logErr(err, "Error running auto approved routes") } + + if update != nil { + sendUpdate = true + } } } @@ -146,7 +153,7 @@ func (h *Headscale) handlePoll( } if sendUpdate { - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -161,7 +168,9 @@ func (h *Headscale) handlePoll( Message: "called from handlePoll -> update -> new hostinfo", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-hostinfochange", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } @@ -174,7 +183,9 @@ func (h *Headscale) handlePoll( ChangeNodes: types.Nodes{node}, } if selfUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", node.Hostname) h.nodeNotifier.NotifyByMachineKey( + ctx, selfUpdate, node.MachineKey) } @@ -183,7 +194,7 @@ func (h *Headscale) handlePoll( } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -195,7 +206,9 @@ func (h *Headscale) handlePoll( ChangePatches: []*tailcfg.PeerChange{&change}, } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } @@ -251,7 +264,7 @@ func (h *Headscale) handlePoll( } } - if err := h.db.NodeSave(node); err != nil { + if err := h.db.DB.Save(node).Error; err != nil { logErr(err, "Failed to persist/update node in the database") http.Error(writer, "", http.StatusInternalServerError) @@ -288,7 +301,10 @@ func (h *Headscale) handlePoll( // update ACLRules with peer informations (to update server tags if necessary) if h.ACLPolicy != nil { // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) + // This state update is ignored as it will be sent + // as part of the whole node + // TODO(kradalby): figure out if that is actually correct + _, err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) if err != nil { logErr(err, "Error running auto approved routes") } @@ -324,11 +340,17 @@ func (h *Headscale) handlePoll( Message: "called from handlePoll -> new node added", } if stateUpdate.Valid() { + ctx := types.NotifyCtx(context.Background(), "poll-newnode-peers", node.Hostname) h.nodeNotifier.NotifyWithIgnore( + ctx, stateUpdate, node.MachineKey.String()) } + if len(node.Routes) > 0 { + go h.pollFailoverRoutes(logErr, "new node", node) + } + // Set up the client stream h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -346,15 +368,9 @@ func (h *Headscale) handlePoll( keepAliveTicker := time.NewTicker(keepAliveInterval) - ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname) - - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname)) defer cancel() - if len(node.Routes) > 0 { - go h.db.EnsureFailoverRouteIsAvailable(node) - } - for { logInfo("Waiting for update on stream channel") select { @@ -403,6 +419,7 @@ func (h *Headscale) handlePoll( return } + startMapResp := time.Now() switch update.Type { case types.StateFullUpdate: logInfo("Sending Full MapResponse") @@ -411,6 +428,7 @@ func (h *Headscale) handlePoll( case types.StatePeerChanged: logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) + isConnectedMap := h.nodeNotifier.ConnectedMap() for _, node := range update.ChangeNodes { // If a node is not reported to be online, it might be // because the value is outdated, check with the notifier. @@ -418,7 +436,7 @@ func (h *Headscale) handlePoll( // this might be because it has announced itself, but not // reached the stage to actually create the notifier channel. if node.IsOnline != nil && !*node.IsOnline { - isOnline := h.nodeNotifier.IsConnected(node.MachineKey) + isOnline := isConnectedMap[node.MachineKey] node.IsOnline = &isOnline } } @@ -434,7 +452,7 @@ func (h *Headscale) handlePoll( if len(update.ChangeNodes) == 1 { logInfo("Sending SelfUpdate MapResponse") node = update.ChangeNodes[0] - data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy) + data, err = mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy, types.SelfUpdateIdentifier) } else { logInfo("SelfUpdate contained too many nodes, this is likely a bug in the code, please report.") } @@ -449,8 +467,11 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startMapResp).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished making map response") + // Only send update if there is change if data != nil { + startWrite := time.Now() _, err = writer.Write(data) if err != nil { logErr(err, "Could not write the map response") @@ -468,6 +489,7 @@ func (h *Headscale) handlePoll( return } + log.Trace().Str("node", node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", node.MachineKey.String()).Int("type", int(update.Type)).Msg("finished writing mapresp to node") log.Info(). Caller(). @@ -487,7 +509,7 @@ func (h *Headscale) handlePoll( go h.updateNodeOnlineStatus(false, node) // Failover the node's routes if any. - go h.db.FailoverNodeRoutesWithNotify(node) + go h.pollFailoverRoutes(logErr, "node closing connection", node) // The connection has been closed, so we can stop polling. return @@ -500,6 +522,22 @@ func (h *Headscale) handlePoll( } } +func (h *Headscale) pollFailoverRoutes(logErr func(error, string), where string, node *types.Node) { + update, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { + return db.EnsureFailoverRouteIsAvailable(tx, h.nodeNotifier.ConnectedMap(), node) + }) + if err != nil { + logErr(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) + + return + } + + if update != nil && !update.Empty() && update.Valid() { + ctx := types.NotifyCtx(context.Background(), fmt.Sprintf("poll-%s-routes-ensurefailover", strings.ReplaceAll(where, " ", "-")), node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, *update, node.MachineKey.String()) + } +} + // updateNodeOnlineStatus records the last seen status of a node and notifies peers // about change in their online/offline status. // It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. @@ -519,10 +557,13 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { }, } if statusUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String()) + ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-onlinestatus", node.Hostname) + h.nodeNotifier.NotifyWithIgnore(ctx, statusUpdate, node.MachineKey.String()) } - err := h.db.UpdateLastSeen(node) + err := h.db.DB.Transaction(func(tx *gorm.DB) error { + return db.UpdateLastSeen(tx, node.ID, *node.LastSeen) + }) if err != nil { log.Error().Err(err).Msg("Cannot update node LastSeen") diff --git a/hscontrol/poll_noise.go b/hscontrol/poll_noise.go index 675836a6..53b1d47e 100644 --- a/hscontrol/poll_noise.go +++ b/hscontrol/poll_noise.go @@ -13,7 +13,7 @@ import ( ) const ( - MinimumCapVersion tailcfg.CapabilityVersion = 56 + MinimumCapVersion tailcfg.CapabilityVersion = 58 ) // NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index e38d8e39..d45f9d4c 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -1,15 +1,19 @@ package types import ( + "context" "database/sql/driver" "encoding/json" "errors" "fmt" "net/netip" + "time" "tailscale.com/tailcfg" ) +const SelfUpdateIdentifier = "self-update" + var ErrCannotParsePrefix = errors.New("cannot parse prefix") type IPPrefix netip.Prefix @@ -160,3 +164,37 @@ func (su *StateUpdate) Valid() bool { return true } + +// Empty reports if there are any updates in the StateUpdate. +func (su *StateUpdate) Empty() bool { + switch su.Type { + case StatePeerChanged: + return len(su.ChangeNodes) == 0 + case StatePeerChangedPatch: + return len(su.ChangePatches) == 0 + case StatePeerRemoved: + return len(su.Removed) == 0 + } + + return false +} + +func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(nodeID), + KeyExpiry: &expiry, + }, + }, + } +} + +func NotifyCtx(ctx context.Context, origin, hostname string) context.Context { + ctx2, _ := context.WithTimeout( + context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin), + 3*time.Second, + ) + return ctx2 +} diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 7f6b40ed..0b8324f2 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,7 +2,6 @@ package types import ( "strconv" - "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -22,12 +21,13 @@ type User struct { func (n *User) TailscaleUser() *tailcfg.User { user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.UserID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", Logins: []tailcfg.LoginID{}, - Created: time.Time{}, + Created: n.CreatedAt, } return &user @@ -35,9 +35,10 @@ func (n *User) TailscaleUser() *tailcfg.User { func (n *User) TailscaleLogin() *tailcfg.Login { login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, + ID: tailcfg.LoginID(n.ID), + LoginName: n.Name, + DisplayName: n.Name, + // TODO(kradalby): See if we can fill in Gravatar here ProfilePicURL: "", } diff --git a/integration/cli_test.go b/integration/cli_test.go index d2d741e0..e6190fb7 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1240,7 +1240,7 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAll[4].GetGivenName(), "node-5") for idx := 0; idx < 3; idx++ { - _, err := headscale.Execute( + res, err := headscale.Execute( []string{ "headscale", "nodes", @@ -1251,6 +1251,8 @@ func TestNodeRenameCommand(t *testing.T) { }, ) assert.Nil(t, err) + + assert.Contains(t, res, "Node renamed") } var listAllAfterRename []v1.Node diff --git a/integration/general_test.go b/integration/general_test.go index 15c3a72c..5c98cd26 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -26,6 +26,8 @@ func TestPingAllByIP(t *testing.T) { assertNoErr(t, err) defer scenario.Shutdown() + // TODO(kradalby): it does not look like the user thing works, only second + // get created? maybe only when many? spec := map[string]int{ "user1": len(MustTestVersions), "user2": len(MustTestVersions), @@ -321,7 +323,12 @@ func TestTaildrop(t *testing.T) { t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) } } - curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} + curlCommand := []string{ + "curl", + "--unix-socket", + "/var/run/tailscale/tailscaled.sock", + "http://local-tailscaled.sock/localapi/v0/file-targets", + } err = retry(10, 1*time.Second, func() error { result, _, err := client.Execute(curlCommand) if err != nil { @@ -338,13 +345,23 @@ func TestTaildrop(t *testing.T) { for _, ft := range fts { ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) } - return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) + return fmt.Errorf( + "client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", + client.Hostname(), + len(fts), + len(allClients)-1, + ftStr, + ) } return err }) if err != nil { - t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) + t.Errorf( + "failed to query localapi for filetarget on %s, err: %s", + client.Hostname(), + err, + ) } } @@ -434,72 +451,6 @@ func TestTaildrop(t *testing.T) { } } -func TestResolveMagicDNS(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "magicdns1": len(MustTestVersions), - "magicdns2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) - assertNoErrHeadscaleEnv(t, err) - - allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - // Poor mans cache - _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) - - _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) - - for _, client := range allClients { - for _, peer := range allClients { - // It is safe to ignore this error as we handled it when caching it - peerFQDN, _ := peer.FQDN() - - command := []string{ - "tailscale", - "ip", peerFQDN, - } - result, _, err := client.Execute(command) - if err != nil { - t.Fatalf( - "failed to execute resolve/ip command %s from %s: %s", - peerFQDN, - client.Hostname(), - err, - ) - } - - ips, err := peer.IPs() - if err != nil { - t.Fatalf( - "failed to get ips for %s: %s", - peer.Hostname(), - err, - ) - } - - for _, ip := range ips { - if !strings.Contains(result, ip.String()) { - t.Fatalf("ip %s is not found in \n%s\n", ip.String(), result) - } - } - } - } -} - func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel() @@ -545,7 +496,7 @@ func TestExpireNode(t *testing.T) { // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ - "headscale", "nodes", "expire", "--identifier", "0", "--output", "json", + "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) assertNoErr(t, err) @@ -576,16 +527,38 @@ func TestExpireNode(t *testing.T) { assertNotNil(t, peerStatus.Expired) assert.NotNil(t, peerStatus.KeyExpiry) - t.Logf("node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + t.Logf( + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) if peerStatus.KeyExpiry != nil { - assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %q should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + assert.Truef( + t, + peerStatus.KeyExpiry.Before(now), + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) } - assert.Truef(t, peerStatus.Expired, "node %q should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + assert.Truef( + t, + peerStatus.Expired, + "node %q should be expired, expired is %v", + peerStatus.HostName, + peerStatus.Expired, + ) _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) if !strings.Contains(stderr, "node key has expired") { - t.Errorf("expected to be unable to ping expired host %q from %q", node.GetName(), client.Hostname()) + t.Errorf( + "expected to be unable to ping expired host %q from %q", + node.GetName(), + client.Hostname(), + ) } } else { t.Errorf("failed to find node %q with nodekey (%s) in mapresponse, should be present even if it is expired", node.GetName(), expiredNodeKey) @@ -597,7 +570,7 @@ func TestExpireNode(t *testing.T) { // NeedsLogin means that the node has understood that it is no longer // valid. - assert.Equal(t, "NeedsLogin", status.BackendState) + assert.Equalf(t, "NeedsLogin", status.BackendState, "checking node %q", status.Self.HostName) } } } @@ -690,7 +663,8 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { assert.Truef( t, lastSeen.After(lastSeenThreshold), - "lastSeen (%v) was not %s after the threshold (%v)", + "node (%s) lastSeen (%v) was not %s after the threshold (%v)", + node.GetName(), lastSeen, keepAliveInterval, lastSeenThreshold, diff --git a/integration/route_test.go b/integration/route_test.go index 741ba24e..75296fd5 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -88,9 +88,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, routes, 3) for _, route := range routes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } // Verify that no routes has been sent to the client, @@ -135,9 +135,9 @@ func TestEnablingRoutes(t *testing.T) { assert.Len(t, enablingRoutes, 3) for _, route := range enablingRoutes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } time.Sleep(5 * time.Second) @@ -191,6 +191,8 @@ func TestEnablingRoutes(t *testing.T) { }) assertNoErr(t, err) + time.Sleep(5 * time.Second) + var disablingRoutes []*v1.Route err = executeAndUnmarshal( headscale, @@ -209,16 +211,14 @@ func TestEnablingRoutes(t *testing.T) { assert.Equal(t, true, route.GetAdvertised()) if route.GetId() == routeToBeDisabled.GetId() { - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) } else { - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) + assert.Equal(t, true, route.GetEnabled()) + assert.Equal(t, true, route.GetIsPrimary()) } } - time.Sleep(5 * time.Second) - // Verify that the clients can see the new routes for _, client := range allClients { status, err := client.Status() @@ -294,7 +294,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // advertise HA route on node 1 and 2 // ID 1 will be primary // ID 2 will be secondary - for _, client := range allClients { + for _, client := range allClients[:2] { status, err := client.Status() assertNoErr(t, err) @@ -306,6 +306,8 @@ func TestHASubnetRouterFailover(t *testing.T) { } _, _, err = client.Execute(command) assertNoErrf(t, "failed to advertise route: %s", err) + } else { + t.Fatalf("failed to find route for Node %s (id: %s)", status.Self.HostName, status.Self.ID) } } @@ -328,6 +330,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routes, 2) + t.Logf("initial routes %#v", routes) + for _, route := range routes { assert.Equal(t, true, route.GetAdvertised()) assert.Equal(t, false, route.GetEnabled()) @@ -644,6 +648,8 @@ func TestHASubnetRouterFailover(t *testing.T) { assertNoErr(t, err) assert.Len(t, routesAfterDisabling1, 2) + t.Logf("routes after disabling1 %#v", routesAfterDisabling1) + // Node 1 is not primary assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) diff --git a/integration/scenario.go b/integration/scenario.go index c11af723..16ec6f47 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -56,8 +56,8 @@ var ( "1.44": true, // CapVer: 63 "1.42": true, // CapVer: 61 "1.40": true, // CapVer: 61 - "1.38": true, // CapVer: 58 - "1.36": true, // Oldest supported version, CapVer: 56 + "1.38": true, // Oldest supported version, CapVer: 58 + "1.36": false, // CapVer: 56 "1.34": false, // CapVer: 51 "1.32": false, // CapVer: 46 "1.30": false,