Rework map session

This commit restructures the map session in to a struct
holding the state of what is needed during its lifetime.

For streaming sessions, the event loop is structured a
bit differently not hammering the clients with updates
but rather batching them over a short, configurable time
which should significantly improve cpu usage, and potentially
flakyness.

The use of Patch updates has been dialed back a little as
it does not look like its a 100% ready for prime time. Nodes
are now updated with full changes, except for a few things
like online status.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-02-23 10:59:24 +01:00 committed by Juan Font
parent dd693c444c
commit 58c94d2bd3
35 changed files with 1803 additions and 1716 deletions

View file

@ -43,7 +43,8 @@ jobs:
- TestTaildrop - TestTaildrop
- TestResolveMagicDNS - TestResolveMagicDNS
- TestExpireNode - TestExpireNode
- TestNodeOnlineLastSeenStatus - TestNodeOnlineStatus
- TestPingAllByIPManyUpDown
- TestEnablingRoutes - TestEnablingRoutes
- TestHASubnetRouterFailover - TestHASubnetRouterFailover
- TestEnableDisableAutoApprovedRoute - TestEnableDisableAutoApprovedRoute

2
go.mod
View file

@ -150,6 +150,7 @@ require (
github.com/opencontainers/image-spec v1.1.0-rc6 // indirect github.com/opencontainers/image-spec v1.1.0-rc6 // indirect
github.com/opencontainers/runc v1.1.12 // indirect github.com/opencontainers/runc v1.1.12 // indirect
github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
@ -161,6 +162,7 @@ require (
github.com/safchain/ethtool v0.3.0 // indirect github.com/safchain/ethtool v0.3.0 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sasha-s/go-deadlock v0.3.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect github.com/spf13/afero v1.11.0 // indirect

4
go.sum
View file

@ -336,6 +336,8 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI=
github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ=
github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o=
github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA= github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA=
github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ= github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ=
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
@ -392,6 +394,8 @@ github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6g
github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0=
github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM=
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=

View file

@ -28,6 +28,7 @@ import (
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp" "github.com/juanfont/headscale/hscontrol/derp"
derpServer "github.com/juanfont/headscale/hscontrol/derp/server" derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@ -38,6 +39,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
zl "github.com/rs/zerolog" zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -77,6 +79,11 @@ const (
registerCacheCleanup = time.Minute * 20 registerCacheCleanup = time.Minute * 20
) )
func init() {
deadlock.Opts.DeadlockTimeout = 15 * time.Second
deadlock.Opts.PrintAllCurrentGoroutines = true
}
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg *types.Config cfg *types.Config
@ -89,6 +96,7 @@ type Headscale struct {
ACLPolicy *policy.ACLPolicy ACLPolicy *policy.ACLPolicy
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier nodeNotifier *notifier.Notifier
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
@ -96,8 +104,10 @@ type Headscale struct {
registrationCache *cache.Cache registrationCache *cache.Cache
shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup pollNetMapStreamWG sync.WaitGroup
mapSessions map[types.NodeID]*mapSession
mapSessionMu deadlock.Mutex
} }
var ( var (
@ -129,6 +139,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
registrationCache: registrationCache, registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{}, pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(), nodeNotifier: notifier.NewNotifier(),
mapSessions: make(map[types.NodeID]*mapSession),
} }
app.db, err = db.NewHeadscaleDatabase( app.db, err = db.NewHeadscaleDatabase(
@ -199,16 +210,16 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, target, http.StatusFound) http.Redirect(w, req, target, http.StatusFound)
} }
// expireEphemeralNodes deletes ephemeral node records that have not been // deleteExpireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
var update types.StateUpdate
var changed bool
for range ticker.C { for range ticker.C {
var removed []types.NodeID
var changed []types.NodeID
if err := h.db.DB.Transaction(func(tx *gorm.DB) error { if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
update, changed = db.ExpireEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil return nil
}); err != nil { }); err != nil {
@ -216,9 +227,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
continue continue
} }
if changed && update.Valid() { if removed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na") ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, update) h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: removed,
})
}
if changed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changed,
})
} }
} }
} }
@ -243,8 +265,9 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
continue continue
} }
log.Trace().Str("nodes", update.ChangeNodes.String()).Msgf("expiring nodes") if changed {
if changed && update.Valid() { log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na") ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update) h.nodeNotifier.NotifyAll(ctx, update)
} }
@ -272,14 +295,11 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
} }
stateUpdate := types.StateUpdate{ ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated, Type: types.StateDERPUpdated,
DERPMap: h.DERPMap, DERPMap: h.DERPMap,
} })
if stateUpdate.Valid() {
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, stateUpdate)
}
} }
} }
} }
@ -502,6 +522,7 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP) h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server // When embedded DERP is enabled we always need a STUN server
@ -533,7 +554,7 @@ func (h *Headscale) Serve() error {
// TODO(kradalby): These should have cancel channels and be cleaned // TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown. // up on shutdown.
go h.expireEphemeralNodes(updateInterval) go h.deleteExpireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval) go h.expireExpiredMachines(updateInterval)
if zl.GlobalLevel() == zl.TraceLevel { if zl.GlobalLevel() == zl.TraceLevel {
@ -686,6 +707,9 @@ func (h *Headscale) Serve() error {
// no good way to handle streaming timeouts, therefore we need to // no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections // keep this at unlimited and be careful to clean up connections
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming // https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming
// TODO(kradalby): this timeout can now be set per handler with http.ResponseController:
// https://www.alexedwards.net/blog/how-to-use-the-http-responsecontroller-type
// replace this so only the longpoller has no timeout.
WriteTimeout: 0, WriteTimeout: 0,
} }
@ -742,7 +766,6 @@ func (h *Headscale) Serve() error {
} }
// Handle common process-killing signals so we can gracefully shut down: // Handle common process-killing signals so we can gracefully shut down:
h.shutdownChan = make(chan struct{})
sigc := make(chan os.Signal, 1) sigc := make(chan os.Signal, 1)
signal.Notify(sigc, signal.Notify(sigc,
syscall.SIGHUP, syscall.SIGHUP,
@ -785,8 +808,6 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully") Msg("Received signal to stop, shutting down gracefully")
close(h.shutdownChan)
h.pollNetMapStreamWG.Wait() h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers // Gracefully shut down servers

View file

@ -352,13 +352,8 @@ func (h *Headscale) handleAuthKey(
} }
} }
mkey := node.MachineKey ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
update := types.StateUpdateExpire(node.ID, registerRequest.Expiry) h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
if update.Valid() {
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
}
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
@ -538,11 +533,8 @@ func (h *Headscale) handleNodeLogOut(
return return
} }
stateUpdate := types.StateUpdateExpire(node.ID, now) ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
if stateUpdate.Valid() { h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -572,7 +564,7 @@ func (h *Headscale) handleNodeLogOut(
} }
if node.IsEphemeral() { if node.IsEphemeral() {
err = h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -580,13 +572,16 @@ func (h *Headscale) handleNodeLogOut(
Msg("Cannot delete ephemeral node from the database") Msg("Cannot delete ephemeral node from the database")
} }
stateUpdate := types.StateUpdate{ ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved, Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, Removed: []types.NodeID{node.ID},
} })
if stateUpdate.Valid() { if changedNodes != nil {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
h.nodeNotifier.NotifyAll(ctx, stateUpdate) Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
} }
return return

View file

@ -34,27 +34,22 @@ var (
) )
) )
func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, node) return ListPeers(rx, nodeID)
}) })
} }
// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. // ListPeers returns all peers of node, regardless of any Policy or if the node is expired.
func ListPeers(tx *gorm.DB, node *types.Node) (types.Nodes, error) { func ListPeers(tx *gorm.DB, nodeID types.NodeID) (types.Nodes, error) {
log.Trace().
Caller().
Str("node", node.Hostname).
Msg("Finding direct peers")
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Preload("AuthKey.User"). Preload("AuthKey.User").
Preload("User"). Preload("User").
Preload("Routes"). Preload("Routes").
Where("node_key <> ?", Where("id <> ?",
node.NodeKey.String()).Find(&nodes).Error; err != nil { nodeID).Find(&nodes).Error; err != nil {
return types.Nodes{}, err return types.Nodes{}, err
} }
@ -119,14 +114,14 @@ func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
return nil, ErrNodeNotFound return nil, ErrNodeNotFound
} }
func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id) return GetNodeByID(rx, id)
}) })
} }
// GetNodeByID finds a Node by ID and returns the Node struct. // GetNodeByID finds a Node by ID and returns the Node struct.
func GetNodeByID(tx *gorm.DB, id uint64) (*types.Node, error) { func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
mach := types.Node{} mach := types.Node{}
if result := tx. if result := tx.
Preload("AuthKey"). Preload("AuthKey").
@ -197,7 +192,7 @@ func GetNodeByAnyKey(
} }
func (hsdb *HSDatabase) SetTags( func (hsdb *HSDatabase) SetTags(
nodeID uint64, nodeID types.NodeID,
tags []string, tags []string,
) error { ) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
@ -208,7 +203,7 @@ func (hsdb *HSDatabase) SetTags(
// SetTags takes a Node struct pointer and update the forced tags. // SetTags takes a Node struct pointer and update the forced tags.
func SetTags( func SetTags(
tx *gorm.DB, tx *gorm.DB,
nodeID uint64, nodeID types.NodeID,
tags []string, tags []string,
) error { ) error {
if len(tags) == 0 { if len(tags) == 0 {
@ -256,7 +251,7 @@ func RenameNode(tx *gorm.DB,
return nil return nil
} }
func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error { func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return NodeSetExpiry(tx, nodeID, expiry) return NodeSetExpiry(tx, nodeID, expiry)
}) })
@ -264,13 +259,13 @@ func (hsdb *HSDatabase) NodeSetExpiry(nodeID uint64, expiry time.Time) error {
// NodeSetExpiry takes a Node struct and a new expiry time. // NodeSetExpiry takes a Node struct and a new expiry time.
func NodeSetExpiry(tx *gorm.DB, func NodeSetExpiry(tx *gorm.DB,
nodeID uint64, expiry time.Time, nodeID types.NodeID, expiry time.Time,
) error { ) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
} }
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.MachinePublic]bool) error { func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
return hsdb.Write(func(tx *gorm.DB) error { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteNode(tx, node, isConnected) return DeleteNode(tx, node, isConnected)
}) })
} }
@ -279,24 +274,24 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected map[key.Machine
// Caller is responsible for notifying all of change. // Caller is responsible for notifying all of change.
func DeleteNode(tx *gorm.DB, func DeleteNode(tx *gorm.DB,
node *types.Node, node *types.Node,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
) error { ) ([]types.NodeID, error) {
err := deleteNodeRoutes(tx, node, map[key.MachinePublic]bool{}) changed, err := deleteNodeRoutes(tx, node, isConnected)
if err != nil { if err != nil {
return err return changed, err
} }
// Unscoped causes the node to be fully removed from the database. // Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&node).Error; err != nil { if err := tx.Unscoped().Delete(&node).Error; err != nil {
return err return changed, err
} }
return nil return changed, nil
} }
// UpdateLastSeen sets a node's last seen field indicating that we // SetLastSeen sets a node's last seen field indicating that we
// have recently communicating with this node. // have recently communicating with this node.
func UpdateLastSeen(tx *gorm.DB, nodeID uint64, lastSeen time.Time) error { func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
} }
@ -606,7 +601,7 @@ func enableRoutes(tx *gorm.DB,
return &types.StateUpdate{ return &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node}, ChangeNodes: []types.NodeID{node.ID},
Message: "created in db.enableRoutes", Message: "created in db.enableRoutes",
}, nil }, nil
} }
@ -681,17 +676,18 @@ func GenerateGivenName(
return givenName, nil return givenName, nil
} }
func ExpireEphemeralNodes(tx *gorm.DB, func DeleteExpiredEphemeralNodes(tx *gorm.DB,
inactivityThreshhold time.Duration, inactivityThreshhold time.Duration,
) (types.StateUpdate, bool) { ) ([]types.NodeID, []types.NodeID) {
users, err := ListUsers(tx) users, err := ListUsers(tx)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error listing users") log.Error().Err(err).Msg("Error listing users")
return types.StateUpdate{}, false return nil, nil
} }
expired := make([]tailcfg.NodeID, 0) var expired []types.NodeID
var changedNodes []types.NodeID
for _, user := range users { for _, user := range users {
nodes, err := ListNodesByUser(tx, user.Name) nodes, err := ListNodesByUser(tx, user.Name)
if err != nil { if err != nil {
@ -700,40 +696,36 @@ func ExpireEphemeralNodes(tx *gorm.DB,
Str("user", user.Name). Str("user", user.Name).
Msg("Error listing nodes in user") Msg("Error listing nodes in user")
return types.StateUpdate{}, false return nil, nil
} }
for idx, node := range nodes { for idx, node := range nodes {
if node.IsEphemeral() && node.LastSeen != nil && if node.IsEphemeral() && node.LastSeen != nil &&
time.Now(). time.Now().
After(node.LastSeen.Add(inactivityThreshhold)) { After(node.LastSeen.Add(inactivityThreshhold)) {
expired = append(expired, tailcfg.NodeID(node.ID)) expired = append(expired, node.ID)
log.Info(). log.Info().
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
// empty isConnected map as ephemeral nodes are not routes // empty isConnected map as ephemeral nodes are not routes
err = DeleteNode(tx, nodes[idx], map[key.MachinePublic]bool{}) changed, err := DeleteNode(tx, nodes[idx], nil)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("🤮 Cannot delete ephemeral node from the database") Msg("🤮 Cannot delete ephemeral node from the database")
} }
changedNodes = append(changedNodes, changed...)
} }
} }
// TODO(kradalby): needs to be moved out of transaction // TODO(kradalby): needs to be moved out of transaction
} }
if len(expired) > 0 {
return types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
}, true
}
return types.StateUpdate{}, false return expired, changedNodes
} }
func ExpireExpiredNodes(tx *gorm.DB, func ExpireExpiredNodes(tx *gorm.DB,
@ -754,35 +746,12 @@ func ExpireExpiredNodes(tx *gorm.DB,
return time.Unix(0, 0), types.StateUpdate{}, false return time.Unix(0, 0), types.StateUpdate{}, false
} }
for index, node := range nodes { for _, node := range nodes {
if node.IsExpired() && if node.IsExpired() && node.Expiry.After(lastCheck) {
// TODO(kradalby): Replace this, it is very spammy
// It will notify about all nodes that has been expired.
// It should only notify about expired nodes since _last check_.
node.Expiry.After(lastCheck) {
expired = append(expired, &tailcfg.PeerChange{ expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID), NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry, KeyExpiry: node.Expiry,
}) })
now := time.Now()
// Do not use setNodeExpiry as that has a notifier hook, which
// can cause a deadlock, we are updating all changed nodes later
// and there is no point in notifiying twice.
if err := tx.Model(&nodes[index]).Updates(types.Node{
Expiry: &now,
}).Error; err != nil {
log.Error().
Err(err).
Str("node", node.Hostname).
Str("name", node.GivenName).
Msg("🤮 Cannot expire node")
} else {
log.Info().
Str("node", node.Hostname).
Str("name", node.GivenName).
Msg("Node successfully expired")
}
} }
} }

View file

@ -120,7 +120,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
} }
db.DB.Save(&node) db.DB.Save(&node)
err = db.DeleteNode(&node, map[key.MachinePublic]bool{}) _, err = db.DeleteNode(&node, types.NodeConnectedMap{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(user.Name, "testnode3") _, err = db.getNode(user.Name, "testnode3")
@ -142,7 +142,7 @@ func (s *Suite) TestListPeers(c *check.C) {
machineKey := key.NewMachine() machineKey := key.NewMachine()
node := types.Node{ node := types.Node{
ID: uint64(index), ID: types.NodeID(index),
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostname: "testnode" + strconv.Itoa(index), Hostname: "testnode" + strconv.Itoa(index),
@ -156,7 +156,7 @@ func (s *Suite) TestListPeers(c *check.C) {
node0ByID, err := db.GetNodeByID(0) node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfNode0, err := db.ListPeers(node0ByID) peersOfNode0, err := db.ListPeers(node0ByID.ID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(peersOfNode0), check.Equals, 9) c.Assert(len(peersOfNode0), check.Equals, 9)
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
machineKey := key.NewMachine() machineKey := key.NewMachine()
node := types.Node{ node := types.Node{
ID: uint64(index), ID: types.NodeID(index),
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
IPAddresses: types.NodeAddresses{ IPAddresses: types.NodeAddresses{
@ -232,16 +232,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
adminPeers, err := db.ListPeers(adminNode) adminPeers, err := db.ListPeers(adminNode.ID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testPeers, err := db.ListPeers(testNode) testPeers, err := db.ListPeers(testNode.ID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers) adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers) testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
@ -586,7 +586,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// TODO(kradalby): Check state update // TODO(kradalby): Check state update
_, err = db.EnableAutoApprovedRoutes(pol, node0ByID) err = db.EnableAutoApprovedRoutes(pol, node0ByID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes, err := db.GetEnabledRoutes(node0ByID) enabledRoutes, err := db.GetEnabledRoutes(node0ByID)

View file

@ -92,10 +92,6 @@ func CreatePreAuthKey(
} }
} }
if err != nil {
return nil, err
}
return &key, nil return &key, nil
} }

View file

@ -148,7 +148,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
db.DB.Transaction(func(tx *gorm.DB) error { db.DB.Transaction(func(tx *gorm.DB) error {
ExpireEphemeralNodes(tx, time.Second*20) DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil return nil
}) })
@ -182,7 +182,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
db.DB.Transaction(func(tx *gorm.DB) error { db.DB.Transaction(func(tx *gorm.DB) error {
ExpireEphemeralNodes(tx, time.Second*20) DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil return nil
}) })

View file

@ -8,7 +8,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/key"
) )
var ErrRouteIsNotAvailable = errors.New("route is not available") var ErrRouteIsNotAvailable = errors.New("route is not available")
@ -124,8 +123,8 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
func DisableRoute(tx *gorm.DB, func DisableRoute(tx *gorm.DB,
id uint64, id uint64,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
) (*types.StateUpdate, error) { ) ([]types.NodeID, error) {
route, err := GetRoute(tx, id) route, err := GetRoute(tx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -137,16 +136,15 @@ func DisableRoute(tx *gorm.DB,
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate var update []types.NodeID
if !route.IsExitRoute() { if !route.IsExitRoute() {
update, err = failoverRouteReturnUpdate(tx, isConnected, route) route.Enabled = false
err = tx.Save(route).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
route.Enabled = false update, err = failoverRouteTx(tx, isConnected, route)
route.IsPrimary = false
err = tx.Save(route).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,6 +158,7 @@ func DisableRoute(tx *gorm.DB,
if routes[i].IsExitRoute() { if routes[i].IsExitRoute() {
routes[i].Enabled = false routes[i].Enabled = false
routes[i].IsPrimary = false routes[i].IsPrimary = false
err = tx.Save(&routes[i]).Error err = tx.Save(&routes[i]).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -168,26 +167,11 @@ func DisableRoute(tx *gorm.DB,
} }
} }
if routes == nil {
routes, err = GetNodeRoutes(tx, &node)
if err != nil {
return nil, err
}
}
node.Routes = routes
// If update is empty, it means that one was not created // If update is empty, it means that one was not created
// by failover (as a failover was not necessary), create // by failover (as a failover was not necessary), create
// one and return to the caller. // one and return to the caller.
if update == nil { if update == nil {
update = &types.StateUpdate{ update = []types.NodeID{node.ID}
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DisableRoute",
}
} }
return update, nil return update, nil
@ -195,9 +179,9 @@ func DisableRoute(tx *gorm.DB,
func (hsdb *HSDatabase) DeleteRoute( func (hsdb *HSDatabase) DeleteRoute(
id uint64, id uint64,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
) (*types.StateUpdate, error) { ) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteRoute(tx, id, isConnected) return DeleteRoute(tx, id, isConnected)
}) })
} }
@ -205,8 +189,8 @@ func (hsdb *HSDatabase) DeleteRoute(
func DeleteRoute( func DeleteRoute(
tx *gorm.DB, tx *gorm.DB,
id uint64, id uint64,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
) (*types.StateUpdate, error) { ) ([]types.NodeID, error) {
route, err := GetRoute(tx, id) route, err := GetRoute(tx, id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -218,9 +202,9 @@ func DeleteRoute(
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update *types.StateUpdate var update []types.NodeID
if !route.IsExitRoute() { if !route.IsExitRoute() {
update, err = failoverRouteReturnUpdate(tx, isConnected, route) update, err = failoverRouteTx(tx, isConnected, route)
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
@ -229,7 +213,7 @@ func DeleteRoute(
return nil, err return nil, err
} }
} else { } else {
routes, err := GetNodeRoutes(tx, &node) routes, err = GetNodeRoutes(tx, &node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -259,35 +243,37 @@ func DeleteRoute(
node.Routes = routes node.Routes = routes
if update == nil { if update == nil {
update = &types.StateUpdate{ update = []types.NodeID{node.ID}
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{
&node,
},
Message: "called from db.DeleteRoute",
}
} }
return update, nil return update, nil
} }
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected map[key.MachinePublic]bool) error { func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
routes, err := GetNodeRoutes(tx, node) routes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return err return nil, err
} }
var changed []types.NodeID
for i := range routes { for i := range routes {
if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil {
return err return nil, err
} }
// TODO(kradalby): This is a bit too aggressive, we could probably // TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all. // figure out which routes needs to be failed over rather than all.
failoverRouteReturnUpdate(tx, isConnected, &routes[i]) chn, err := failoverRouteTx(tx, isConnected, &routes[i])
if err != nil {
return changed, err
}
if chn != nil {
changed = append(changed, chn...)
}
} }
return nil return changed, nil
} }
// isUniquePrefix returns if there is another node providing the same route already. // isUniquePrefix returns if there is another node providing the same route already.
@ -400,7 +386,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
for prefix, exists := range advertisedRoutes { for prefix, exists := range advertisedRoutes {
if !exists { if !exists {
route := types.Route{ route := types.Route{
NodeID: node.ID, NodeID: node.ID.Uint64(),
Prefix: types.IPPrefix(prefix), Prefix: types.IPPrefix(prefix),
Advertised: true, Advertised: true,
Enabled: false, Enabled: false,
@ -415,19 +401,23 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
return sendUpdate, nil return sendUpdate, nil
} }
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route // FailoverRouteIfAvailable takes a node and checks if the node's route
// currently have a functioning host that exposes the network. // currently have a functioning host that exposes the network.
func EnsureFailoverRouteIsAvailable( // If it does not, it is failed over to another suitable route if there
// is one.
func FailoverRouteIfAvailable(
tx *gorm.DB, tx *gorm.DB,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
node *types.Node, node *types.Node,
) (*types.StateUpdate, error) { ) (*types.StateUpdate, error) {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Msgf("ROUTE DEBUG ENTERED FAILOVER")
nodeRoutes, err := GetNodeRoutes(tx, node) nodeRoutes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("nodeRoutes", nodeRoutes).Msgf("ROUTE DEBUG NO ROUTES")
return nil, nil return nil, nil
} }
var changedNodes types.Nodes var changedNodes []types.NodeID
for _, nodeRoute := range nodeRoutes { for _, nodeRoute := range nodeRoutes {
routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix))
if err != nil { if err != nil {
@ -438,71 +428,39 @@ func EnsureFailoverRouteIsAvailable(
if route.IsPrimary { if route.IsPrimary {
// if we have a primary route, and the node is connected // if we have a primary route, and the node is connected
// nothing needs to be done. // nothing needs to be done.
if isConnected[route.Node.MachineKey] { log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG CHECKING IF ONLINE")
continue if isConnected[route.Node.ID] {
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG IS ONLINE")
return nil, nil
} }
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG NOT ONLINE, FAILING OVER")
// if not, we need to failover the route // if not, we need to failover the route
update, err := failoverRouteReturnUpdate(tx, isConnected, &route) changedIDs, err := failoverRouteTx(tx, isConnected, &route)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil { if changedIDs != nil {
changedNodes = append(changedNodes, update.ChangeNodes...) changedNodes = append(changedNodes, changedIDs...)
} }
} }
} }
} }
log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("changedNodes", changedNodes).Msgf("ROUTE DEBUG")
if len(changedNodes) != 0 { if len(changedNodes) != 0 {
return &types.StateUpdate{ return &types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: changedNodes, ChangeNodes: changedNodes,
Message: "called from db.EnsureFailoverRouteIsAvailable", Message: "called from db.FailoverRouteIfAvailable",
}, nil }, nil
} }
return nil, nil return nil, nil
} }
func failoverRouteReturnUpdate( // failoverRouteTx takes a route that is no longer available,
tx *gorm.DB,
isConnected map[key.MachinePublic]bool,
r *types.Route,
) (*types.StateUpdate, error) {
changedKeys, err := failoverRoute(tx, isConnected, r)
if err != nil {
return nil, err
}
log.Trace().
Interface("isConnected", isConnected).
Interface("changedKeys", changedKeys).
Msg("building route failover")
if len(changedKeys) == 0 {
return nil, nil
}
var nodes types.Nodes
for _, key := range changedKeys {
node, err := GetNodeByMachineKey(tx, key)
if err != nil {
return nil, err
}
nodes = append(nodes, node)
}
return &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.failoverRouteReturnUpdate",
}, nil
}
// failoverRoute takes a route that is no longer available,
// this can be either from: // this can be either from:
// - being disabled // - being disabled
// - being deleted // - being deleted
@ -510,11 +468,11 @@ func failoverRouteReturnUpdate(
// //
// and tries to find a new route to take over its place. // and tries to find a new route to take over its place.
// If the given route was not primary, it returns early. // If the given route was not primary, it returns early.
func failoverRoute( func failoverRouteTx(
tx *gorm.DB, tx *gorm.DB,
isConnected map[key.MachinePublic]bool, isConnected types.NodeConnectedMap,
r *types.Route, r *types.Route,
) ([]key.MachinePublic, error) { ) ([]types.NodeID, error) {
if r == nil { if r == nil {
return nil, nil return nil, nil
} }
@ -535,11 +493,64 @@ func failoverRoute(
return nil, err return nil, err
} }
fo := failoverRoute(isConnected, r, routes)
if fo == nil {
return nil, nil
}
err = tx.Save(fo.old).Error
if err != nil {
log.Error().Err(err).Msg("disabling old primary route")
return nil, err
}
err = tx.Save(fo.new).Error
if err != nil {
log.Error().Err(err).Msg("saving new primary route")
return nil, err
}
log.Trace().
Str("hostname", fo.new.Node.Hostname).
Msgf("set primary to new route, was: id(%d), host(%s), now: id(%d), host(%s)", fo.old.ID, fo.old.Node.Hostname, fo.new.ID, fo.new.Node.Hostname)
// Return a list of the machinekeys of the changed nodes.
return []types.NodeID{fo.old.Node.ID, fo.new.Node.ID}, nil
}
type failover struct {
old *types.Route
new *types.Route
}
func failoverRoute(
isConnected types.NodeConnectedMap,
routeToReplace *types.Route,
altRoutes types.Routes,
) *failover {
if routeToReplace == nil {
return nil
}
// This route is not a primary route, and it is not
// being served to nodes.
if !routeToReplace.IsPrimary {
return nil
}
// We do not have to failover exit nodes
if routeToReplace.IsExitRoute() {
return nil
}
var newPrimary *types.Route var newPrimary *types.Route
// Find a new suitable route // Find a new suitable route
for idx, route := range routes { for idx, route := range altRoutes {
if r.ID == route.ID { if routeToReplace.ID == route.ID {
continue continue
} }
@ -547,8 +558,8 @@ func failoverRoute(
continue continue
} }
if isConnected[route.Node.MachineKey] { if isConnected != nil && isConnected[route.Node.ID] {
newPrimary = &routes[idx] newPrimary = &altRoutes[idx]
break break
} }
} }
@ -559,48 +570,23 @@ func failoverRoute(
// the one currently marked as primary is the // the one currently marked as primary is the
// best we got. // best we got.
if newPrimary == nil { if newPrimary == nil {
return nil, nil return nil
} }
log.Trace(). routeToReplace.IsPrimary = false
Str("hostname", newPrimary.Node.Hostname).
Msg("found new primary, updating db")
// Remove primary from the old route
r.IsPrimary = false
err = tx.Save(&r).Error
if err != nil {
log.Error().Err(err).Msg("error disabling new primary route")
return nil, err
}
log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("removed primary from old route")
// Set primary for the new primary
newPrimary.IsPrimary = true newPrimary.IsPrimary = true
err = tx.Save(&newPrimary).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
return nil, err return &failover{
old: routeToReplace,
new: newPrimary,
} }
log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("set primary to new route")
// Return a list of the machinekeys of the changed nodes.
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
} }
func (hsdb *HSDatabase) EnableAutoApprovedRoutes( func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy, aclPolicy *policy.ACLPolicy,
node *types.Node, node *types.Node,
) (*types.StateUpdate, error) { ) error {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { return hsdb.Write(func(tx *gorm.DB) error {
return EnableAutoApprovedRoutes(tx, aclPolicy, node) return EnableAutoApprovedRoutes(tx, aclPolicy, node)
}) })
} }
@ -610,9 +596,9 @@ func EnableAutoApprovedRoutes(
tx *gorm.DB, tx *gorm.DB,
aclPolicy *policy.ACLPolicy, aclPolicy *policy.ACLPolicy,
node *types.Node, node *types.Node,
) (*types.StateUpdate, error) { ) error {
if len(node.IPAddresses) == 0 { if len(node.IPAddresses) == 0 {
return nil, nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs
} }
routes, err := GetNodeAdvertisedRoutes(tx, node) routes, err := GetNodeAdvertisedRoutes(tx, node)
@ -623,7 +609,7 @@ func EnableAutoApprovedRoutes(
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("Could not get advertised routes for node") Msg("Could not get advertised routes for node")
return nil, err return err
} }
log.Trace().Interface("routes", routes).Msg("routes for autoapproving") log.Trace().Interface("routes", routes).Msg("routes for autoapproving")
@ -641,10 +627,10 @@ func EnableAutoApprovedRoutes(
if err != nil { if err != nil {
log.Err(err). log.Err(err).
Str("advertisedRoute", advertisedRoute.String()). Str("advertisedRoute", advertisedRoute.String()).
Uint64("nodeId", node.ID). Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to resolve autoApprovers for advertised route") Msg("Failed to resolve autoApprovers for advertised route")
return nil, err return err
} }
log.Trace(). log.Trace().
@ -665,7 +651,7 @@ func EnableAutoApprovedRoutes(
Str("alias", approvedAlias). Str("alias", approvedAlias).
Msg("Failed to expand alias when processing autoApprovers policy") Msg("Failed to expand alias when processing autoApprovers policy")
return nil, err return err
} }
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first // approvedIPs should contain all of node's IPs if it matches the rule, so check for first
@ -676,25 +662,17 @@ func EnableAutoApprovedRoutes(
} }
} }
update := &types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{},
Message: "created in db.EnableAutoApprovedRoutes",
}
for _, approvedRoute := range approvedRoutes { for _, approvedRoute := range approvedRoutes {
perHostUpdate, err := EnableRoute(tx, uint64(approvedRoute.ID)) _, err := EnableRoute(tx, uint64(approvedRoute.ID))
if err != nil { if err != nil {
log.Err(err). log.Err(err).
Str("approvedRoute", approvedRoute.String()). Str("approvedRoute", approvedRoute.String()).
Uint64("nodeId", node.ID). Uint64("nodeId", node.ID.Uint64()).
Msg("Failed to enable approved route") Msg("Failed to enable approved route")
return nil, err return err
} }
update.ChangeNodes = append(update.ChangeNodes, perHostUpdate.ChangeNodes...)
} }
return update, nil return nil
} }

View file

@ -13,7 +13,6 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func (s *Suite) TestGetRoutes(c *check.C) { func (s *Suite) TestGetRoutes(c *check.C) {
@ -262,7 +261,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// TODO(kradalby): check stateupdate // TODO(kradalby): check stateupdate
_, err = db.DeleteRoute(uint64(routes[0].ID), map[key.MachinePublic]bool{}) _, err = db.DeleteRoute(uint64(routes[0].ID), nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1) enabledRoutes1, err := db.GetEnabledRoutes(&node1)
@ -272,20 +271,13 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } var ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
func TestFailoverRoute(t *testing.T) { func TestFailoverRouteTx(t *testing.T) {
machineKeys := []key.MachinePublic{
key.NewMachine().Public(),
key.NewMachine().Public(),
key.NewMachine().Public(),
key.NewMachine().Public(),
}
tests := []struct { tests := []struct {
name string name string
failingRoute types.Route failingRoute types.Route
routes types.Routes routes types.Routes
isConnected map[key.MachinePublic]bool isConnected types.NodeConnectedMap
want []key.MachinePublic want []types.NodeID
wantErr bool wantErr bool
}{ }{
{ {
@ -301,10 +293,8 @@ func TestFailoverRoute(t *testing.T) {
Model: gorm.Model{ Model: gorm.Model{
ID: 1, ID: 1,
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{},
MachineKey: machineKeys[0],
},
IsPrimary: false, IsPrimary: false,
}, },
routes: types.Routes{}, routes: types.Routes{},
@ -317,10 +307,8 @@ func TestFailoverRoute(t *testing.T) {
Model: gorm.Model{ Model: gorm.Model{
ID: 1, ID: 1,
}, },
Prefix: ipp("0.0.0.0/0"), Prefix: ipp("0.0.0.0/0"),
Node: types.Node{ Node: types.Node{},
MachineKey: machineKeys[0],
},
IsPrimary: true, IsPrimary: true,
}, },
routes: types.Routes{}, routes: types.Routes{},
@ -335,7 +323,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
}, },
@ -346,7 +334,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
}, },
@ -362,7 +350,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -374,7 +362,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -385,19 +373,19 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{ isConnected: types.NodeConnectedMap{
machineKeys[0]: false, 1: false,
machineKeys[1]: true, 2: true,
}, },
want: []key.MachinePublic{ want: []types.NodeID{
machineKeys[0], 1,
machineKeys[1], 2,
}, },
wantErr: false, wantErr: false,
}, },
@ -409,7 +397,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
@ -421,7 +409,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -432,7 +420,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
@ -449,7 +437,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -461,7 +449,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
@ -472,7 +460,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -483,20 +471,19 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[2], ID: 3,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{ isConnected: types.NodeConnectedMap{
machineKeys[0]: true, 1: true,
machineKeys[1]: true, 2: true,
machineKeys[2]: true, 3: true,
}, },
want: []key.MachinePublic{ want: []types.NodeID{
machineKeys[1], 2, 1,
machineKeys[0],
}, },
wantErr: false, wantErr: false,
}, },
@ -508,7 +495,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -520,7 +507,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -532,15 +519,15 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[3], ID: 4,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{ isConnected: types.NodeConnectedMap{
machineKeys[0]: true, 1: true,
machineKeys[3]: false, 4: false,
}, },
want: nil, want: nil,
wantErr: false, wantErr: false,
@ -553,7 +540,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -565,7 +552,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -577,7 +564,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[3], ID: 4,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: true, Enabled: true,
@ -588,20 +575,20 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
}, },
}, },
isConnected: map[key.MachinePublic]bool{ isConnected: types.NodeConnectedMap{
machineKeys[0]: false, 1: false,
machineKeys[1]: true, 2: true,
machineKeys[3]: false, 4: false,
}, },
want: []key.MachinePublic{ want: []types.NodeID{
machineKeys[0], 1,
machineKeys[1], 2,
}, },
wantErr: false, wantErr: false,
}, },
@ -613,7 +600,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -625,7 +612,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[0], ID: 1,
}, },
IsPrimary: true, IsPrimary: true,
Enabled: true, Enabled: true,
@ -637,7 +624,7 @@ func TestFailoverRoute(t *testing.T) {
}, },
Prefix: ipp("10.0.0.0/24"), Prefix: ipp("10.0.0.0/24"),
Node: types.Node{ Node: types.Node{
MachineKey: machineKeys[1], ID: 2,
}, },
IsPrimary: false, IsPrimary: false,
Enabled: false, Enabled: false,
@ -670,8 +657,8 @@ func TestFailoverRoute(t *testing.T) {
} }
} }
got, err := Write(db.DB, func(tx *gorm.DB) ([]key.MachinePublic, error) { got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return failoverRoute(tx, tt.isConnected, &tt.failingRoute) return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute)
}) })
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -687,230 +674,177 @@ func TestFailoverRoute(t *testing.T) {
} }
} }
// func TestDisableRouteFailover(t *testing.T) { func TestFailoverRoute(t *testing.T) {
// machineKeys := []key.MachinePublic{ r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
// key.NewMachine().Public(), return types.Route{
// key.NewMachine().Public(), Model: gorm.Model{
// key.NewMachine().Public(), ID: id,
// key.NewMachine().Public(), },
// } Node: types.Node{
ID: nid,
},
Prefix: prefix,
Enabled: enabled,
IsPrimary: primary,
}
}
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
ro := r(id, nid, prefix, enabled, primary)
return &ro
}
tests := []struct {
name string
failingRoute types.Route
routes types.Routes
isConnected types.NodeConnectedMap
want *failover
}{
{
name: "no-route",
failingRoute: types.Route{},
routes: types.Routes{},
want: nil,
},
{
name: "no-prime",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, false),
// tests := []struct { routes: types.Routes{},
// name string want: nil,
// nodes types.Nodes },
{
name: "exit-node",
failingRoute: r(1, 1, ipp("0.0.0.0/0"), false, true),
routes: types.Routes{},
want: nil,
},
{
name: "no-failover-single-route",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), false, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), false, true),
},
want: nil,
},
{
name: "failover-primary",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
},
isConnected: types.NodeConnectedMap{
1: false,
2: true,
},
want: &failover{
old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
new: rp(2, 2, ipp("10.0.0.0/24"), true, true),
},
},
{
name: "failover-none-primary",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, false),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false),
},
want: nil,
},
{
name: "failover-primary-multi-route",
failingRoute: r(2, 2, ipp("10.0.0.0/24"), true, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, false),
r(2, 2, ipp("10.0.0.0/24"), true, true),
r(3, 3, ipp("10.0.0.0/24"), true, false),
},
isConnected: types.NodeConnectedMap{
1: true,
2: true,
3: true,
},
want: &failover{
old: rp(2, 2, ipp("10.0.0.0/24"), true, false),
new: rp(1, 1, ipp("10.0.0.0/24"), true, true),
},
},
{
name: "failover-primary-no-online",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 4, ipp("10.0.0.0/24"), true, false),
},
isConnected: types.NodeConnectedMap{
1: true,
4: false,
},
want: nil,
},
{
name: "failover-primary-one-not-online",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 4, ipp("10.0.0.0/24"), true, false),
r(3, 2, ipp("10.0.0.0/24"), true, false),
},
isConnected: types.NodeConnectedMap{
1: false,
2: true,
4: false,
},
want: &failover{
old: rp(1, 1, ipp("10.0.0.0/24"), true, false),
new: rp(3, 2, ipp("10.0.0.0/24"), true, true),
},
},
{
name: "failover-primary-none-enabled",
failingRoute: r(1, 1, ipp("10.0.0.0/24"), true, true),
routes: types.Routes{
r(1, 1, ipp("10.0.0.0/24"), true, false),
r(2, 2, ipp("10.0.0.0/24"), false, true),
},
want: nil,
},
}
// routeID uint64 cmps := append(
// isConnected map[key.MachinePublic]bool util.Comparers,
cmp.Comparer(func(x, y types.IPPrefix) bool {
return netip.Prefix(x) == netip.Prefix(y)
}),
)
// wantMachineKey key.MachinePublic for _, tt := range tests {
// wantErr string t.Run(tt.name, func(t *testing.T) {
// }{ gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes)
// {
// name: "single-route",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// Node: types.Node{
// MachineKey: machineKeys[0],
// },
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[0],
// },
// {
// name: "failover-simple",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "no-failover-offline",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: false,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// {
// name: "failover-to-online",
// nodes: types.Nodes{
// &types.Node{
// ID: 0,
// MachineKey: machineKeys[0],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 1,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: true,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// &types.Node{
// ID: 1,
// MachineKey: machineKeys[1],
// Routes: []types.Route{
// {
// Model: gorm.Model{
// ID: 2,
// },
// Prefix: ipp("10.0.0.0/24"),
// IsPrimary: false,
// },
// },
// Hostinfo: &tailcfg.Hostinfo{
// RoutableIPs: []netip.Prefix{
// netip.MustParsePrefix("10.0.0.0/24"),
// },
// },
// },
// },
// isConnected: map[key.MachinePublic]bool{
// machineKeys[0]: true,
// machineKeys[1]: true,
// },
// routeID: 1,
// wantMachineKey: machineKeys[1],
// },
// }
// for _, tt := range tests { if tt.want == nil && gotf != nil {
// t.Run(tt.name, func(t *testing.T) { t.Fatalf("expected nil, got %+v", gotf)
// datab, err := NewHeadscaleDatabase("sqlite3", ":memory:", false, []netip.Prefix{}, "") }
// assert.NoError(t, err)
// // bootstrap db if gotf == nil && tt.want != nil {
// datab.DB.Transaction(func(tx *gorm.DB) error { t.Fatalf("expected %+v, got nil", tt.want)
// for _, node := range tt.nodes { }
// err := tx.Save(node).Error
// if err != nil {
// return err
// }
// _, err = SaveNodeRoutes(tx, node) if tt.want != nil && gotf != nil {
// if err != nil { want := map[string]*types.Route{
// return err "new": tt.want.new,
// } "old": tt.want.old,
// } }
// return nil got := map[string]*types.Route{
// }) "new": gotf.new,
"old": gotf.old,
}
// got, err := Write(datab.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { if diff := cmp.Diff(want, got, cmps...); diff != "" {
// return DisableRoute(tx, tt.routeID, tt.isConnected) t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
// }) }
}
// // if (err.Error() != "") != tt.wantErr { })
// // t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) }
}
// // return
// // }
// if len(got.ChangeNodes) != 1 {
// t.Errorf("expected update with one machine, got %d", len(got.ChangeNodes))
// }
// if diff := cmp.Diff(tt.wantMachineKey, got.ChangeNodes[0].MachineKey, util.Comparers...); diff != "" {
// t.Errorf("DisableRoute() unexpected result (-want +got):\n%s", diff)
// }
// })
// }
// }

View file

@ -222,7 +222,7 @@ func (api headscaleV1APIServer) GetNode(
ctx context.Context, ctx context.Context,
request *v1.GetNodeRequest, request *v1.GetNodeRequest,
) (*v1.GetNodeResponse, error) { ) (*v1.GetNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -231,7 +231,7 @@ func (api headscaleV1APIServer) GetNode(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil return &v1.GetNodeResponse{Node: resp}, nil
} }
@ -248,12 +248,12 @@ func (api headscaleV1APIServer) SetTags(
} }
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.SetTags(tx, request.GetNodeId(), request.GetTags()) err := db.SetTags(tx, types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return db.GetNodeByID(tx, request.GetNodeId()) return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
}) })
if err != nil { if err != nil {
return &v1.SetTagsResponse{ return &v1.SetTagsResponse{
@ -261,15 +261,12 @@ func (api headscaleV1APIServer) SetTags(
}, status.Error(codes.InvalidArgument, err.Error()) }, status.Error(codes.InvalidArgument, err.Error())
} }
stateUpdate := types.StateUpdate{ ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node}, ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.SetTags", Message: "called from api.SetTags",
} }, node.ID)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -296,12 +293,12 @@ func (api headscaleV1APIServer) DeleteNode(
ctx context.Context, ctx context.Context,
request *v1.DeleteNodeRequest, request *v1.DeleteNodeRequest,
) (*v1.DeleteNodeResponse, error) { ) (*v1.DeleteNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = api.h.db.DeleteNode( changedNodes, err := api.h.db.DeleteNode(
node, node,
api.h.nodeNotifier.ConnectedMap(), api.h.nodeNotifier.ConnectedMap(),
) )
@ -309,13 +306,17 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err return nil, err
} }
stateUpdate := types.StateUpdate{ ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved, Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, Removed: []types.NodeID{node.ID},
} })
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-deletenode", node.Hostname) if changedNodes != nil {
api.h.nodeNotifier.NotifyAll(ctx, stateUpdate) api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
})
} }
return &v1.DeleteNodeResponse{}, nil return &v1.DeleteNodeResponse{}, nil
@ -330,33 +331,27 @@ func (api headscaleV1APIServer) ExpireNode(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
db.NodeSetExpiry( db.NodeSetExpiry(
tx, tx,
request.GetNodeId(), types.NodeID(request.GetNodeId()),
now, now,
) )
return db.GetNodeByID(tx, request.GetNodeId()) return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
selfUpdate := types.StateUpdate{ ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
Type: types.StateSelfUpdate, api.h.nodeNotifier.NotifyByMachineKey(
ChangeNodes: types.Nodes{node}, ctx,
} types.StateUpdate{
if selfUpdate.Valid() { Type: types.StateSelfUpdate,
ctx := types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) ChangeNodes: []types.NodeID{node.ID},
api.h.nodeNotifier.NotifyByMachineKey( },
ctx, node.ID)
selfUpdate,
node.MachineKey)
}
stateUpdate := types.StateUpdateExpire(node.ID, now) ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
if stateUpdate.Valid() { api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
ctx := types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -380,21 +375,18 @@ func (api headscaleV1APIServer) RenameNode(
return nil, err return nil, err
} }
return db.GetNodeByID(tx, request.GetNodeId()) return db.GetNodeByID(tx, types.NodeID(request.GetNodeId()))
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
stateUpdate := types.StateUpdate{ ctx = types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node}, ChangeNodes: []types.NodeID{node.ID},
Message: "called from api.RenameNode", Message: "called from api.RenameNode",
} }, node.ID)
if stateUpdate.Valid() {
ctx := types.NotifyCtx(ctx, "cli-renamenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -423,7 +415,7 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = isConnected[node.MachineKey] resp.Online = isConnected[node.ID]
response[index] = resp response[index] = resp
} }
@ -446,7 +438,7 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = isConnected[node.MachineKey] resp.Online = isConnected[node.ID]
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
node, node,
@ -463,7 +455,7 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context, ctx context.Context,
request *v1.MoveNodeRequest, request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) { ) (*v1.MoveNodeResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -503,7 +495,7 @@ func (api headscaleV1APIServer) EnableRoute(
return nil, err return nil, err
} }
if update != nil && update.Valid() { if update != nil {
ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown") ctx := types.NotifyCtx(ctx, "cli-enableroute", "unknown")
api.h.nodeNotifier.NotifyAll( api.h.nodeNotifier.NotifyAll(
ctx, *update) ctx, *update)
@ -516,17 +508,19 @@ func (api headscaleV1APIServer) DisableRoute(
ctx context.Context, ctx context.Context,
request *v1.DisableRouteRequest, request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) { ) (*v1.DisableRouteResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap() update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap())
return db.DisableRoute(tx, request.GetRouteId(), isConnected)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil && update.Valid() { if update != nil {
ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown") ctx := types.NotifyCtx(ctx, "cli-disableroute", "unknown")
api.h.nodeNotifier.NotifyAll(ctx, *update) api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
} }
return &v1.DisableRouteResponse{}, nil return &v1.DisableRouteResponse{}, nil
@ -536,7 +530,7 @@ func (api headscaleV1APIServer) GetNodeRoutes(
ctx context.Context, ctx context.Context,
request *v1.GetNodeRoutesRequest, request *v1.GetNodeRoutesRequest,
) (*v1.GetNodeRoutesResponse, error) { ) (*v1.GetNodeRoutesResponse, error) {
node, err := api.h.db.GetNodeByID(request.GetNodeId()) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -556,16 +550,19 @@ func (api headscaleV1APIServer) DeleteRoute(
request *v1.DeleteRouteRequest, request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) { ) (*v1.DeleteRouteResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap() isConnected := api.h.nodeNotifier.ConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected) return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if update != nil && update.Valid() { if update != nil {
ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown") ctx := types.NotifyCtx(ctx, "cli-deleteroute", "unknown")
api.h.nodeNotifier.NotifyWithIgnore(ctx, *update) api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: update,
})
} }
return &v1.DeleteRouteResponse{}, nil return &v1.DeleteRouteResponse{}, nil

View file

@ -68,12 +68,6 @@ func (h *Headscale) KeyHandler(
Msg("could not get capability version") Msg("could not get capability version")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return return
} }
@ -82,19 +76,6 @@ func (h *Headscale) KeyHandler(
Str("handler", "/key"). Str("handler", "/key").
Int("cap_ver", int(capVer)). Int("cap_ver", int(capVer)).
Msg("New noise client") Msg("New noise client")
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// TS2021 (Tailscale v2 protocol) requires to have a different key // TS2021 (Tailscale v2 protocol) requires to have a different key
if capVer >= NoiseCapabilityVersion { if capVer >= NoiseCapabilityVersion {

View file

@ -16,12 +16,12 @@ import (
"time" "time"
mapset "github.com/deckarep/golang-set/v2" mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/exp/maps"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/smallzstd" "tailscale.com/smallzstd"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -51,21 +51,14 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
type Mapper struct { type Mapper struct {
// Configuration // Configuration
// TODO(kradalby): figure out if this is the format we want this in // TODO(kradalby): figure out if this is the format we want this in
derpMap *tailcfg.DERPMap db *db.HSDatabase
baseDomain string cfg *types.Config
dnsCfg *tailcfg.DNSConfig derpMap *tailcfg.DERPMap
logtail bool isLikelyConnected types.NodeConnectedMap
randomClientPort bool
uid string uid string
created time.Time created time.Time
seq uint64 seq uint64
// Map isnt concurrency safe, so we need to ensure
// only one func is accessing it over time.
mu sync.Mutex
peers map[uint64]*types.Node
patches map[uint64][]patch
} }
type patch struct { type patch struct {
@ -74,35 +67,22 @@ type patch struct {
} }
func NewMapper( func NewMapper(
node *types.Node, db *db.HSDatabase,
peers types.Nodes, cfg *types.Config,
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
baseDomain string, isLikelyConnected types.NodeConnectedMap,
dnsCfg *tailcfg.DNSConfig,
logtail bool,
randomClientPort bool,
) *Mapper { ) *Mapper {
log.Debug().
Caller().
Str("node", node.Hostname).
Msg("creating new mapper")
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{ return &Mapper{
derpMap: derpMap, db: db,
baseDomain: baseDomain, cfg: cfg,
dnsCfg: dnsCfg, derpMap: derpMap,
logtail: logtail, isLikelyConnected: isLikelyConnected,
randomClientPort: randomClientPort,
uid: uid, uid: uid,
created: time.Now(), created: time.Now(),
seq: 0, seq: 0,
// TODO: populate
peers: peers.IDMap(),
patches: make(map[uint64][]patch),
} }
} }
@ -207,11 +187,10 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
// It is a separate function to make testing easier. // It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse( func (m *Mapper) fullMapResponse(
node *types.Node, node *types.Node,
peers types.Nodes,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
peers := nodeMapToList(m.peers)
resp, err := m.baseWithConfigMapResponse(node, pol, capVer) resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
if err != nil { if err != nil {
return nil, err return nil, err
@ -219,14 +198,13 @@ func (m *Mapper) fullMapResponse(
err = appendPeerChanges( err = appendPeerChanges(
resp, resp,
true, // full change
pol, pol,
node, node,
capVer, capVer,
peers, peers,
peers, peers,
m.baseDomain, m.cfg,
m.dnsCfg,
m.randomClientPort,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -240,35 +218,25 @@ func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) { ) ([]byte, error) {
m.mu.Lock() peers, err := m.ListPeers(node.ID)
defer m.mu.Unlock()
peers := maps.Keys(m.peers)
peersWithPatches := maps.Keys(m.patches)
slices.Sort(peers)
slices.Sort(peersWithPatches)
if len(peersWithPatches) > 0 {
log.Debug().
Str("node", node.Hostname).
Uints64("peers", peers).
Uints64("pending_patches", peersWithPatches).
Msgf("node requested full map response, but has pending patches")
}
resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
} }
// LiteMapResponse returns a MapResponse for the given node. // ReadOnlyResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended // Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true. // to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) LiteMapResponse( func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
@ -279,18 +247,6 @@ func (m *Mapper) LiteMapResponse(
return nil, err return nil, err
} }
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules(
pol,
node,
nodeMapToList(m.peers),
)
if err != nil {
return nil, err
}
resp.PacketFilter = policy.ReduceFilterRules(node, rules)
resp.SSHPolicy = sshPolicy
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...) return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
} }
@ -320,50 +276,74 @@ func (m *Mapper) DERPMapResponse(
func (m *Mapper) PeerChangedResponse( func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
changed types.Nodes, changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Update our internal map.
for _, node := range changed {
if patches, ok := m.patches[node.ID]; ok {
// preserve online status in case the patch has an outdated one
online := node.IsOnline
for _, p := range patches {
// TODO(kradalby): Figure if this needs to be sorted by timestamp
node.ApplyPeerChange(p.change)
}
// Ensure the patches are not applied again later
delete(m.patches, node.ID)
node.IsOnline = online
}
m.peers[node.ID] = node
}
resp := m.baseMapResponse() resp := m.baseMapResponse()
err := appendPeerChanges( peers, err := m.ListPeers(node.ID)
if err != nil {
return nil, err
}
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
changedIDs = append(changedIDs, nodeID)
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := make(types.Nodes, 0, len(changedIDs))
for _, peer := range peers {
if slices.Contains(changedIDs, peer.ID) {
changedNodes = append(changedNodes, peer)
}
}
err = appendPeerChanges(
&resp, &resp,
false, // partial change
pol, pol,
node, node,
mapRequest.Version, mapRequest.Version,
nodeMapToList(m.peers), peers,
changed, changedNodes,
m.baseDomain, m.cfg,
m.dnsCfg,
m.randomClientPort,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
resp.PeersRemoved = removedIDs
// Sending patches as a part of a PeersChanged response
// is technically not suppose to be done, but they are
// applied after the PeersChanged. The patch list
// should _only_ contain Nodes that are not in the
// PeersChanged or PeersRemoved list and the caller
// should filter them out.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
// the Peers* fields also set.
if patches != nil {
resp.PeersChangedPatch = patches
}
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...) return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
} }
@ -375,71 +355,12 @@ func (m *Mapper) PeerChangedPatchResponse(
changed []*tailcfg.PeerChange, changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
) ([]byte, error) { ) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
sendUpdate := false
// patch the internal map
for _, change := range changed {
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
peer.ApplyPeerChange(change)
sendUpdate = true
} else {
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
p := patch{
timestamp: time.Now(),
change: change,
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
m.patches[uint64(change.NodeID)] = append(patches, p)
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
}
}
if !sendUpdate {
return nil, nil
}
resp := m.baseMapResponse() resp := m.baseMapResponse()
resp.PeersChangedPatch = changed resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
} }
// TODO(kradalby): We need some integration tests for this.
func (m *Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
removed []tailcfg.NodeID,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Some nodes might have been removed already
// so we dont want to ask downstream to remove
// twice, than can cause a panic in tailscaled.
notYetRemoved := []tailcfg.NodeID{}
// remove from our internal map
for _, id := range removed {
if _, ok := m.peers[uint64(id)]; ok {
notYetRemoved = append(notYetRemoved, id)
}
delete(m.peers, uint64(id))
delete(m.patches, uint64(id))
}
resp := m.baseMapResponse()
resp.PeersRemoved = notYetRemoved
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) marshalMapResponse( func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
@ -469,10 +390,8 @@ func (m *Mapper) marshalMapResponse(
switch { switch {
case resp.Peers != nil && len(resp.Peers) > 0: case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full" responseType = "full"
case isSelfUpdate(messages...): case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self" responseType = "self"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil:
responseType = "lite"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed" responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0: case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
@ -496,11 +415,11 @@ func (m *Mapper) marshalMapResponse(
panic(err) panic(err)
} }
now := time.Now().UnixNano() now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join( mapResponsePath := path.Join(
mPath, mPath,
fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType), fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
) )
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -574,7 +493,7 @@ func (m *Mapper) baseWithConfigMapResponse(
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort) tailnode, err := tailNode(node, capVer, pol, m.cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -582,7 +501,7 @@ func (m *Mapper) baseWithConfigMapResponse(
resp.DERPMap = m.derpMap resp.DERPMap = m.derpMap
resp.Domain = m.baseDomain resp.Domain = m.cfg.BaseDomain
// Do not instruct clients to collect services we do not // Do not instruct clients to collect services we do not
// support or do anything with them // support or do anything with them
@ -591,12 +510,26 @@ func (m *Mapper) baseWithConfigMapResponse(
resp.KeepAlive = false resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{ resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.logtail, DisableLogTail: !m.cfg.LogTail.Enabled,
} }
return &resp, nil return &resp, nil
} }
func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
peers, err := m.db.ListPeers(nodeID)
if err != nil {
return nil, err
}
for _, peer := range peers {
online := m.isLikelyConnected[peer.ID]
peer.IsOnline = &online
}
return peers, nil
}
func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
ret := make(types.Nodes, 0) ret := make(types.Nodes, 0)
@ -612,42 +545,41 @@ func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes {
func appendPeerChanges( func appendPeerChanges(
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
fullChange bool,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
peers types.Nodes, peers types.Nodes,
changed types.Nodes, changed types.Nodes,
baseDomain string, cfg *types.Config,
dnsCfg *tailcfg.DNSConfig,
randomClientPort bool,
) error { ) error {
fullChange := len(peers) == len(changed)
rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( packetFilter, err := pol.CompileFilterRules(append(peers, node))
pol, if err != nil {
node, return err
peers, }
)
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
if err != nil { if err != nil {
return err return err
} }
// If there are filter rules present, see if there are any nodes that cannot // If there are filter rules present, see if there are any nodes that cannot
// access eachother at all and remove them from the peers. // access eachother at all and remove them from the peers.
if len(rules) > 0 { if len(packetFilter) > 0 {
changed = policy.FilterNodesByACL(node, changed, rules) changed = policy.FilterNodesByACL(node, changed, packetFilter)
} }
profiles := generateUserProfiles(node, changed, baseDomain) profiles := generateUserProfiles(node, changed, cfg.BaseDomain)
dnsConfig := generateDNSConfig( dnsConfig := generateDNSConfig(
dnsCfg, cfg.DNSConfig,
baseDomain, cfg.BaseDomain,
node, node,
peers, peers,
) )
tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort) tailPeers, err := tailNodes(changed, capVer, pol, cfg)
if err != nil { if err != nil {
return err return err
} }
@ -663,19 +595,9 @@ func appendPeerChanges(
resp.PeersChanged = tailPeers resp.PeersChanged = tailPeers
} }
resp.DNSConfig = dnsConfig resp.DNSConfig = dnsConfig
resp.PacketFilter = policy.ReduceFilterRules(node, rules) resp.PacketFilter = policy.ReduceFilterRules(node, packetFilter)
resp.UserProfiles = profiles resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy resp.SSHPolicy = sshPolicy
return nil return nil
} }
func isSelfUpdate(messages ...string) bool {
for _, message := range messages {
if strings.Contains(message, types.SelfUpdateIdentifier) {
return true
}
}
return false
}

View file

@ -331,13 +331,10 @@ func Test_fullMapResponse(t *testing.T) {
node *types.Node node *types.Node
peers types.Nodes peers types.Nodes
baseDomain string derpMap *tailcfg.DERPMap
dnsConfig *tailcfg.DNSConfig cfg *types.Config
derpMap *tailcfg.DERPMap want *tailcfg.MapResponse
logtail bool wantErr bool
randomClientPort bool
want *tailcfg.MapResponse
wantErr bool
}{ }{
// { // {
// name: "empty-node", // name: "empty-node",
@ -349,15 +346,17 @@ func Test_fullMapResponse(t *testing.T) {
// wantErr: true, // wantErr: true,
// }, // },
{ {
name: "no-pol-no-peers-map-response", name: "no-pol-no-peers-map-response",
pol: &policy.ACLPolicy{}, pol: &policy.ACLPolicy{},
node: mini, node: mini,
peers: types.Nodes{}, peers: types.Nodes{},
baseDomain: "", derpMap: &tailcfg.DERPMap{},
dnsConfig: &tailcfg.DNSConfig{}, cfg: &types.Config{
derpMap: &tailcfg.DERPMap{}, BaseDomain: "",
logtail: false, DNSConfig: &tailcfg.DNSConfig{},
randomClientPort: false, LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{ want: &tailcfg.MapResponse{
Node: tailMini, Node: tailMini,
KeepAlive: false, KeepAlive: false,
@ -383,11 +382,13 @@ func Test_fullMapResponse(t *testing.T) {
peers: types.Nodes{ peers: types.Nodes{
peer1, peer1,
}, },
baseDomain: "", derpMap: &tailcfg.DERPMap{},
dnsConfig: &tailcfg.DNSConfig{}, cfg: &types.Config{
derpMap: &tailcfg.DERPMap{}, BaseDomain: "",
logtail: false, DNSConfig: &tailcfg.DNSConfig{},
randomClientPort: false, LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{ want: &tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
Node: tailMini, Node: tailMini,
@ -424,11 +425,13 @@ func Test_fullMapResponse(t *testing.T) {
peer1, peer1,
peer2, peer2,
}, },
baseDomain: "", derpMap: &tailcfg.DERPMap{},
dnsConfig: &tailcfg.DNSConfig{}, cfg: &types.Config{
derpMap: &tailcfg.DERPMap{}, BaseDomain: "",
logtail: false, DNSConfig: &tailcfg.DNSConfig{},
randomClientPort: false, LogTail: types.LogTailConfig{Enabled: false},
RandomizeClientPort: false,
},
want: &tailcfg.MapResponse{ want: &tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
Node: tailMini, Node: tailMini,
@ -463,17 +466,15 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
mappy := NewMapper( mappy := NewMapper(
tt.node, nil,
tt.peers, tt.cfg,
tt.derpMap, tt.derpMap,
tt.baseDomain, nil,
tt.dnsConfig,
tt.logtail,
tt.randomClientPort,
) )
got, err := mappy.fullMapResponse( got, err := mappy.fullMapResponse(
tt.node, tt.node,
tt.peers,
tt.pol, tt.pol,
0, 0,
) )

View file

@ -3,12 +3,10 @@ package mapper
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/samber/lo" "github.com/samber/lo"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -17,9 +15,7 @@ func tailNodes(
nodes types.Nodes, nodes types.Nodes,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig, cfg *types.Config,
baseDomain string,
randomClientPort bool,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes)) tNodes := make([]*tailcfg.Node, len(nodes))
@ -28,9 +24,7 @@ func tailNodes(
node, node,
capVer, capVer,
pol, pol,
dnsConfig, cfg,
baseDomain,
randomClientPort,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,9 +42,7 @@ func tailNode(
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig, cfg *types.Config,
baseDomain string,
randomClientPort bool,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
addrs := node.IPAddresses.Prefixes() addrs := node.IPAddresses.Prefixes()
@ -85,7 +77,7 @@ func tailNode(
keyExpiry = time.Time{} keyExpiry = time.Time{}
} }
hostname, err := node.GetFQDN(dnsConfig, baseDomain) hostname, err := node.GetFQDN(cfg.DNSConfig, cfg.BaseDomain)
if err != nil { if err != nil {
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
} }
@ -94,12 +86,10 @@ func tailNode(
tags = lo.Uniq(append(tags, node.ForcedTags...)) tags = lo.Uniq(append(tags, node.ForcedTags...))
tNode := tailcfg.Node{ tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: tailcfg.StableNodeID( StableID: node.ID.StableID(),
strconv.FormatUint(node.ID, util.Base10), Name: hostname,
), // in headscale, unlike tailcontrol server, IDs are permanent Cap: capVer,
Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID), User: tailcfg.UserID(node.UserID),
@ -133,7 +123,7 @@ func tailNode(
tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
} }
if randomClientPort { if cfg.RandomizeClientPort {
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
} }
} else { } else {
@ -143,7 +133,7 @@ func tailNode(
tailcfg.CapabilitySSH, tailcfg.CapabilitySSH,
} }
if randomClientPort { if cfg.RandomizeClientPort {
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort) tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort)
} }
} }

View file

@ -182,13 +182,16 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
BaseDomain: tt.baseDomain,
DNSConfig: tt.dnsConfig,
RandomizeClientPort: false,
}
got, err := tailNode( got, err := tailNode(
tt.node, tt.node,
0, 0,
tt.pol, tt.pol,
tt.dnsConfig, cfg,
tt.baseDomain,
false,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {

View file

@ -3,6 +3,7 @@ package hscontrol
import ( import (
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
@ -11,6 +12,7 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"gorm.io/gorm"
"tailscale.com/control/controlbase" "tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp" "tailscale.com/control/controlhttp"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -163,3 +165,135 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
return nil return nil
} }
const (
MinimumCapVersion tailcfg.CapabilityVersion = 58
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) NoisePollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
log.Trace().
Any("headers", req.Header).
Caller().
Msg("Headers")
body, _ := io.ReadAll(req.Body)
mapRequest := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &mapRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
// Reject unsupported versions
if mapRequest.Version < MinimumCapVersion {
log.Info().
Caller().
Int("min_version", int(MinimumCapVersion)).
Int("client_version", int(mapRequest.Version)).
Msg("unsupported client connected")
http.Error(writer, "Internal error", http.StatusBadRequest)
return
}
ns.nodeKey = mapRequest.NodeKey
node, err := ns.headscale.db.GetNodeByAnyKey(
ns.conn.Peer(),
mapRequest.NodeKey,
key.NodePublic{},
)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()).
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Uint64("node.id", node.ID.Uint64()).
Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
log.Debug().
Str("handler", "NoisePollNetMap").
Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)).
Uint64("node.id", node.ID.Uint64()).
Msg("A node sending a MapRequest with Noise protocol")
session := ns.headscale.newMapSession(req.Context(), mapRequest, writer, node)
// If a streaming mapSession exists for this node, close it
// and start a new one.
if session.isStreaming() {
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to check stream")
ns.headscale.mapSessionMu.Lock()
if oldSession, ok := ns.headscale.mapSessions[node.ID]; ok {
log.Info().
Caller().
Uint64("node.id", node.ID.Uint64()).
Msg("Node has an open streaming session, replacing")
oldSession.close()
}
ns.headscale.mapSessions[node.ID] = session
ns.headscale.mapSessionMu.Unlock()
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to check stream")
}
session.serve()
if session.isStreaming() {
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Aquiring lock to remove stream")
ns.headscale.mapSessionMu.Lock()
delete(ns.headscale.mapSessions, node.ID)
ns.headscale.mapSessionMu.Unlock()
log.Debug().
Caller().
Uint64("node.id", node.ID.Uint64()).
Int("cap_ver", int(mapRequest.Version)).
Msg("Releasing lock to remove stream")
}
}

View file

@ -3,52 +3,51 @@ package notifier
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"strings" "strings"
"sync" "sync"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/types/key"
) )
type Notifier struct { type Notifier struct {
l sync.RWMutex l sync.RWMutex
nodes map[string]chan<- types.StateUpdate nodes map[types.NodeID]chan<- types.StateUpdate
connected map[key.MachinePublic]bool connected types.NodeConnectedMap
} }
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{ return &Notifier{
nodes: make(map[string]chan<- types.StateUpdate), nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: make(map[key.MachinePublic]bool), connected: make(types.NodeConnectedMap),
} }
} }
func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Str("key", machineKey.ShortString()). Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to add node") Msg("releasing lock to add node")
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
n.nodes[machineKey.String()] = c n.nodes[nodeID] = c
n.connected[machineKey] = true n.connected[nodeID] = true
log.Trace(). log.Trace().
Str("machine_key", machineKey.ShortString()). Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)). Int("open_chans", len(n.nodes)).
Msg("Added new channel") Msg("Added new channel")
} }
func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { func (n *Notifier) RemoveNode(nodeID types.NodeID) {
log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Str("key", machineKey.ShortString()). Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to remove node") Msg("releasing lock to remove node")
n.l.Lock() n.l.Lock()
@ -58,26 +57,32 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
return return
} }
delete(n.nodes, machineKey.String()) delete(n.nodes, nodeID)
n.connected[machineKey] = false n.connected[nodeID] = false
log.Trace(). log.Trace().
Str("machine_key", machineKey.ShortString()). Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)). Int("open_chans", len(n.nodes)).
Msg("Removed channel") Msg("Removed channel")
} }
// IsConnected reports if a node is connected to headscale and has a // IsConnected reports if a node is connected to headscale and has a
// poll session open. // poll session open.
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
return n.connected[machineKey] return n.connected[nodeID]
}
// IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesnt lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return n.connected[nodeID]
} }
// TODO(kradalby): This returns a pointer and can be dangerous. // TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() map[key.MachinePublic]bool { func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
return n.connected return n.connected
} }
@ -88,19 +93,23 @@ func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
func (n *Notifier) NotifyWithIgnore( func (n *Notifier) NotifyWithIgnore(
ctx context.Context, ctx context.Context,
update types.StateUpdate, update types.StateUpdate,
ignore ...string, ignoreNodeIDs ...types.NodeID,
) { ) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Interface("type", update.Type). Str("type", update.Type.String()).
Msg("releasing lock, finished notifying") Msg("releasing lock, finished notifying")
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
for key, c := range n.nodes { if update.Type == types.StatePeerChangedPatch {
if util.IsStringInSlice(ignore, key) { log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
}
for nodeID, c := range n.nodes {
if slices.Contains(ignoreNodeIDs, nodeID) {
continue continue
} }
@ -108,17 +117,17 @@ func (n *Notifier) NotifyWithIgnore(
case <-ctx.Done(): case <-ctx.Done():
log.Error(). log.Error().
Err(ctx.Err()). Err(ctx.Err()).
Str("mkey", key). Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled") Msgf("update not sent, context cancelled")
return return
case c <- update: case c <- update:
log.Trace(). log.Trace().
Str("mkey", key). Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan") Msgf("update successfully sent on chan")
} }
} }
@ -127,33 +136,33 @@ func (n *Notifier) NotifyWithIgnore(
func (n *Notifier) NotifyByMachineKey( func (n *Notifier) NotifyByMachineKey(
ctx context.Context, ctx context.Context,
update types.StateUpdate, update types.StateUpdate,
mKey key.MachinePublic, nodeID types.NodeID,
) { ) {
log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
defer log.Trace(). defer log.Trace().
Caller(). Caller().
Interface("type", update.Type). Str("type", update.Type.String()).
Msg("releasing lock, finished notifying") Msg("releasing lock, finished notifying")
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
if c, ok := n.nodes[mKey.String()]; ok { if c, ok := n.nodes[nodeID]; ok {
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Error(). log.Error().
Err(ctx.Err()). Err(ctx.Err()).
Str("mkey", mKey.String()). Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled") Msgf("update not sent, context cancelled")
return return
case c <- update: case c <- update:
log.Trace(). log.Trace().
Str("mkey", mKey.String()). Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan") Msgf("update successfully sent on chan")
} }
} }
@ -166,7 +175,7 @@ func (n *Notifier) String() string {
str := []string{"Notifier, in map:\n"} str := []string{"Notifier, in map:\n"}
for k, v := range n.nodes { for k, v := range n.nodes {
str = append(str, fmt.Sprintf("\t%s: %v\n", k, v)) str = append(str, fmt.Sprintf("\t%d: %v\n", k, v))
} }
return strings.Join(str, "") return strings.Join(str, "")

View file

@ -537,11 +537,8 @@ func (h *Headscale) validateNodeForOIDCCallback(
util.LogErr(err, "Failed to write response") util.LogErr(err, "Failed to write response")
} }
stateUpdate := types.StateUpdateExpire(node.ID, expiry) ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
if stateUpdate.Valid() { h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
ctx := types.NotifyCtx(context.Background(), "oidc-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
}
return nil, true, nil return nil, true, nil
} }

View file

@ -114,7 +114,7 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
return &policy, nil return &policy, nil
} }
func GenerateFilterAndSSHRules( func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy, policy *ACLPolicy,
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
@ -124,40 +124,31 @@ func GenerateFilterAndSSHRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.generateFilterRules(node, peers) rules, err := policy.CompileFilterRules(append(peers, node))
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
var sshPolicy *tailcfg.SSHPolicy sshPolicy, err := policy.CompileSSHPolicy(node, peers)
sshRules, err := policy.generateSSHRules(node, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().
Interface("SSH", sshRules).
Str("node", node.GivenName).
Msg("SSH rules")
if sshPolicy == nil {
sshPolicy = &tailcfg.SSHPolicy{}
}
sshPolicy.Rules = sshRules
return rules, sshPolicy, nil return rules, sshPolicy, nil
} }
// generateFilterRules takes a set of nodes and an ACLPolicy and generates a // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules( func (pol *ACLPolicy) CompileFilterRules(
node *types.Node, nodes types.Nodes,
peers types.Nodes,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil {
return tailcfg.FilterAllowAll, nil
}
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
nodes := append(peers, node)
for index, acl := range pol.ACLs { for index, acl := range pol.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
@ -279,10 +270,14 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
return ret return ret
} }
func (pol *ACLPolicy) generateSSHRules( func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
) ([]*tailcfg.SSHRule, error) { ) (*tailcfg.SSHPolicy, error) {
if pol == nil {
return nil, nil
}
rules := []*tailcfg.SSHRule{} rules := []*tailcfg.SSHRule{}
acceptAction := tailcfg.SSHAction{ acceptAction := tailcfg.SSHAction{
@ -393,7 +388,9 @@ func (pol *ACLPolicy) generateSSHRules(
}) })
} }
return rules, nil return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
} }
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {

View file

@ -385,11 +385,12 @@ acls:
return return
} }
rules, err := pol.generateFilterRules(&types.Node{ rules, err := pol.CompileFilterRules(types.Nodes{
IPAddresses: types.NodeAddresses{ &types.Node{
netip.MustParseAddr("100.100.100.100"), IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.100.100.100"),
},
}, },
}, types.Nodes{
&types.Node{ &types.Node{
IPAddresses: types.NodeAddresses{ IPAddresses: types.NodeAddresses{
netip.MustParseAddr("200.200.200.200"), netip.MustParseAddr("200.200.200.200"),
@ -546,7 +547,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(&types.Node{}, types.Nodes{}) rules, err := pol.CompileFilterRules(types.Nodes{})
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -562,7 +563,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -581,7 +582,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -597,7 +598,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }
@ -1724,8 +1725,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
pol ACLPolicy pol ACLPolicy
} }
type args struct { type args struct {
node *types.Node nodes types.Nodes
peers types.Nodes
} }
tests := []struct { tests := []struct {
name string name string
@ -1755,13 +1755,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
}, },
args: args{ args: args{
node: &types.Node{ nodes: types.Nodes{
IPAddresses: types.NodeAddresses{ &types.Node{
netip.MustParseAddr("100.64.0.1"), IPAddresses: types.NodeAddresses{
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
},
}, },
}, },
peers: types.Nodes{},
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
@ -1800,14 +1801,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
}, },
args: args{ args: args{
node: &types.Node{ nodes: types.Nodes{
IPAddresses: types.NodeAddresses{ &types.Node{
netip.MustParseAddr("100.64.0.1"), IPAddresses: types.NodeAddresses{
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
},
User: types.User{Name: "mickael"},
}, },
User: types.User{Name: "mickael"},
},
peers: types.Nodes{
&types.Node{ &types.Node{
IPAddresses: types.NodeAddresses{ IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.2"),
@ -1846,9 +1847,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.generateFilterRules( got, err := tt.field.pol.CompileFilterRules(
tt.args.node, tt.args.nodes,
tt.args.peers,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
@ -1980,9 +1980,8 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rules, _ := tt.pol.generateFilterRules( rules, _ := tt.pol.CompileFilterRules(
tt.node, append(tt.peers, tt.node),
tt.peers,
) )
got := ReduceFilterRules(tt.node, rules) got := ReduceFilterRules(tt.node, rules)
@ -2883,7 +2882,7 @@ func TestSSHRules(t *testing.T) {
node types.Node node types.Node
peers types.Nodes peers types.Nodes
pol ACLPolicy pol ACLPolicy
want []*tailcfg.SSHRule want *tailcfg.SSHPolicy
}{ }{
{ {
name: "peers-can-connect", name: "peers-can-connect",
@ -2946,7 +2945,7 @@ func TestSSHRules(t *testing.T) {
}, },
}, },
}, },
want: []*tailcfg.SSHRule{ want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{
{ {
Principals: []*tailcfg.SSHPrincipal{ Principals: []*tailcfg.SSHPrincipal{
{ {
@ -2991,7 +2990,7 @@ func TestSSHRules(t *testing.T) {
}, },
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true}, Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
}, },
}, }},
}, },
{ {
name: "peers-cannot-connect", name: "peers-cannot-connect",
@ -3042,13 +3041,13 @@ func TestSSHRules(t *testing.T) {
}, },
}, },
}, },
want: []*tailcfg.SSHRule{}, want: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.generateSSHRules(&tt.node, tt.peers) got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
assert.NoError(t, err) assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3155,7 +3154,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3206,7 +3205,7 @@ func TestInvalidTagValidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3265,7 +3264,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3335,7 +3334,7 @@ func TestValidTagInvalidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{nodes2}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{

File diff suppressed because it is too large Load diff

View file

@ -1,96 +0,0 @@
package hscontrol
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
MinimumCapVersion tailcfg.CapabilityVersion = 58
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) NoisePollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
log.Trace().
Any("headers", req.Header).
Caller().
Msg("Headers")
body, _ := io.ReadAll(req.Body)
mapRequest := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &mapRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
// Reject unsupported versions
if mapRequest.Version < MinimumCapVersion {
log.Info().
Caller().
Int("min_version", int(MinimumCapVersion)).
Int("client_version", int(mapRequest.Version)).
Msg("unsupported client connected")
http.Error(writer, "Internal error", http.StatusBadRequest)
return
}
ns.nodeKey = mapRequest.NodeKey
node, err := ns.headscale.db.GetNodeByAnyKey(
ns.conn.Peer(),
mapRequest.NodeKey,
key.NodePublic{},
)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
log.Debug().
Str("handler", "NoisePollNetMap").
Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)).
Msg("A node sending a MapRequest with Noise protocol")
ns.headscale.handlePoll(writer, req.Context(), node, mapRequest)
}

View file

@ -90,6 +90,25 @@ func (i StringList) Value() (driver.Value, error) {
type StateUpdateType int type StateUpdateType int
func (su StateUpdateType) String() string {
switch su {
case StateFullUpdate:
return "StateFullUpdate"
case StatePeerChanged:
return "StatePeerChanged"
case StatePeerChangedPatch:
return "StatePeerChangedPatch"
case StatePeerRemoved:
return "StatePeerRemoved"
case StateSelfUpdate:
return "StateSelfUpdate"
case StateDERPUpdated:
return "StateDERPUpdated"
}
return "unknown state update type"
}
const ( const (
StateFullUpdate StateUpdateType = iota StateFullUpdate StateUpdateType = iota
// StatePeerChanged is used for updates that needs // StatePeerChanged is used for updates that needs
@ -118,7 +137,7 @@ type StateUpdate struct {
// ChangeNodes must be set when Type is StatePeerAdded // ChangeNodes must be set when Type is StatePeerAdded
// and StatePeerChanged and contains the full node // and StatePeerChanged and contains the full node
// object for added nodes. // object for added nodes.
ChangeNodes Nodes ChangeNodes []NodeID
// ChangePatches must be set when Type is StatePeerChangedPatch // ChangePatches must be set when Type is StatePeerChangedPatch
// and contains a populated PeerChange object. // and contains a populated PeerChange object.
@ -127,7 +146,7 @@ type StateUpdate struct {
// Removed must be set when Type is StatePeerRemoved and // Removed must be set when Type is StatePeerRemoved and
// contain a list of the nodes that has been removed from // contain a list of the nodes that has been removed from
// the network. // the network.
Removed []tailcfg.NodeID Removed []NodeID
// DERPMap must be set when Type is StateDERPUpdated and // DERPMap must be set when Type is StateDERPUpdated and
// contain the new DERP Map. // contain the new DERP Map.
@ -138,39 +157,6 @@ type StateUpdate struct {
Message string Message string
} }
// Valid reports if a StateUpdate is correctly filled and
// panics if the mandatory fields for a type is not
// filled.
// Reports true if valid.
func (su *StateUpdate) Valid() bool {
switch su.Type {
case StatePeerChanged:
if su.ChangeNodes == nil {
panic("Mandatory field ChangeNodes is not set on StatePeerChanged update")
}
case StatePeerChangedPatch:
if su.ChangePatches == nil {
panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update")
}
case StatePeerRemoved:
if su.Removed == nil {
panic("Mandatory field Removed is not set on StatePeerRemove update")
}
case StateSelfUpdate:
if su.ChangeNodes == nil || len(su.ChangeNodes) != 1 {
panic(
"Mandatory field ChangeNodes is not set for StateSelfUpdate or has more than one node",
)
}
case StateDERPUpdated:
if su.DERPMap == nil {
panic("Mandatory field DERPMap is not set on StateDERPUpdated update")
}
}
return true
}
// Empty reports if there are any updates in the StateUpdate. // Empty reports if there are any updates in the StateUpdate.
func (su *StateUpdate) Empty() bool { func (su *StateUpdate) Empty() bool {
switch su.Type { switch su.Type {
@ -185,12 +171,12 @@ func (su *StateUpdate) Empty() bool {
return false return false
} }
func StateUpdateExpire(nodeID uint64, expiry time.Time) StateUpdate { func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
return StateUpdate{ return StateUpdate{
Type: StatePeerChangedPatch, Type: StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{ ChangePatches: []*tailcfg.PeerChange{
{ {
NodeID: tailcfg.NodeID(nodeID), NodeID: nodeID.NodeID(),
KeyExpiry: &expiry, KeyExpiry: &expiry,
}, },
}, },

View file

@ -69,6 +69,8 @@ type Config struct {
CLI CLIConfig CLI CLIConfig
ACL ACLConfig ACL ACLConfig
Tuning Tuning
} }
type SqliteConfig struct { type SqliteConfig struct {
@ -161,6 +163,11 @@ type LogConfig struct {
Level zerolog.Level Level zerolog.Level
} }
type Tuning struct {
BatchChangeDelay time.Duration
NodeMapSessionBufferedChanSize int
}
func LoadConfig(path string, isFile bool) error { func LoadConfig(path string, isFile bool) error {
if isFile { if isFile {
viper.SetConfigFile(path) viper.SetConfigFile(path)
@ -220,6 +227,9 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("node_update_check_interval", "10s") viper.SetDefault("node_update_check_interval", "10s")
viper.SetDefault("tuning.batch_change_delay", "800ms")
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
if IsCLIConfigured() { if IsCLIConfigured() {
return nil return nil
} }
@ -719,6 +729,12 @@ func GetHeadscaleConfig() (*Config, error) {
}, },
Log: GetLogConfig(), Log: GetLogConfig(),
// TODO(kradalby): Document these settings when more stable
Tuning: Tuning{
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
NodeMapSessionBufferedChanSize: viper.GetInt("tuning.node_mapsession_buffered_chan_size"),
},
}, nil }, nil
} }

View file

@ -7,11 +7,13 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"sort" "sort"
"strconv"
"strings" "strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher" "github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/netipx" "go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -27,9 +29,24 @@ var (
ErrNodeUserHasNoName = errors.New("node user has no name") ErrNodeUserHasNoName = errors.New("node user has no name")
) )
type NodeID uint64
type NodeConnectedMap map[NodeID]bool
func (id NodeID) StableID() tailcfg.StableNodeID {
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
}
func (id NodeID) NodeID() tailcfg.NodeID {
return tailcfg.NodeID(id)
}
func (id NodeID) Uint64() uint64 {
return uint64(id)
}
// Node is a Headscale client. // Node is a Headscale client.
type Node struct { type Node struct {
ID uint64 `gorm:"primary_key"` ID NodeID `gorm:"primary_key"`
// MachineKeyDatabaseField is the string representation of MachineKey // MachineKeyDatabaseField is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the // it is _only_ used for reading and writing the key to the
@ -198,7 +215,7 @@ func (node Node) IsExpired() bool {
return false return false
} }
return time.Now().UTC().After(*node.Expiry) return time.Since(*node.Expiry) > 0
} }
// IsEphemeral returns if the node is registered as an Ephemeral node. // IsEphemeral returns if the node is registered as an Ephemeral node.
@ -319,7 +336,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
func (node *Node) Proto() *v1.Node { func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{ nodeProto := &v1.Node{
Id: node.ID, Id: uint64(node.ID),
MachineKey: node.MachineKey.String(), MachineKey: node.MachineKey.String(),
NodeKey: node.NodeKey.String(), NodeKey: node.NodeKey.String(),
@ -486,8 +503,8 @@ func (nodes Nodes) String() string {
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (nodes Nodes) IDMap() map[uint64]*Node { func (nodes Nodes) IDMap() map[NodeID]*Node {
ret := map[uint64]*Node{} ret := map[NodeID]*Node{}
for _, node := range nodes { for _, node := range nodes {
ret[node.ID] = node ret[node.ID] = node

View file

@ -83,7 +83,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -142,7 +142,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()

View file

@ -53,7 +53,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -92,7 +92,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()

View file

@ -65,7 +65,7 @@ func TestPingAllByIP(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -103,7 +103,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -135,7 +135,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr) clientIPs := make(map[TailscaleClient][]netip.Addr)
for _, client := range allClients { for _, client := range allClients {
@ -176,7 +176,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allClients, err = scenario.ListTailscaleClients() allClients, err = scenario.ListTailscaleClients()
assertNoErrListClients(t, err) assertNoErrListClients(t, err)
@ -329,7 +329,7 @@ func TestPingAllByHostname(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allHostnames, err := scenario.ListTailscaleClientsFQDNs() allHostnames, err := scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) assertNoErrListFQDN(t, err)
@ -539,7 +539,7 @@ func TestResolveMagicDNS(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
// Poor mans cache // Poor mans cache
_, err = scenario.ListTailscaleClientsFQDNs() _, err = scenario.ListTailscaleClientsFQDNs()
@ -609,7 +609,7 @@ func TestExpireNode(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -711,7 +711,7 @@ func TestExpireNode(t *testing.T) {
} }
} }
func TestNodeOnlineLastSeenStatus(t *testing.T) { func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
@ -723,7 +723,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
"user1": len(MustTestVersions), "user1": len(MustTestVersions),
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online"))
assertNoErrHeadscaleEnv(t, err) assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients() allClients, err := scenario.ListTailscaleClients()
@ -735,7 +735,7 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
err = scenario.WaitForTailscaleSync() err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err) assertNoErrSync(t, err)
assertClientsState(t, allClients) // assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String() return x.String()
@ -755,8 +755,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
headscale, err := scenario.Headscale() headscale, err := scenario.Headscale()
assertNoErr(t, err) assertNoErr(t, err)
keepAliveInterval := 60 * time.Second
// Duration is chosen arbitrarily, 10m is reported in #1561 // Duration is chosen arbitrarily, 10m is reported in #1561
testDuration := 12 * time.Minute testDuration := 12 * time.Minute
start := time.Now() start := time.Now()
@ -780,11 +778,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
err = json.Unmarshal([]byte(result), &nodes) err = json.Unmarshal([]byte(result), &nodes)
assertNoErr(t, err) assertNoErr(t, err)
now := time.Now()
// Threshold with some leeway
lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second))
// Verify that headscale reports the nodes as online // Verify that headscale reports the nodes as online
for _, node := range nodes { for _, node := range nodes {
// All nodes should be online // All nodes should be online
@ -795,18 +788,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
node.GetName(), node.GetName(),
time.Since(start), time.Since(start),
) )
lastSeen := node.GetLastSeen().AsTime()
// All nodes should have been last seen between now and the keepAliveInterval
assert.Truef(
t,
lastSeen.After(lastSeenThreshold),
"node (%s) lastSeen (%v) was not %s after the threshold (%v)",
node.GetName(),
lastSeen,
keepAliveInterval,
lastSeenThreshold,
)
} }
// Verify that all nodes report all nodes to be online // Verify that all nodes report all nodes to be online
@ -834,15 +815,6 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
client.Hostname(), client.Hostname(),
time.Since(start), time.Since(start),
) )
// from docs: last seen to tailcontrol; only present if offline
// assert.Nilf(
// t,
// peerStatus.LastSeen,
// "expected node %s to not have LastSeen set, got %s",
// peerStatus.HostName,
// peerStatus.LastSeen,
// )
} }
} }
@ -850,3 +822,87 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
} }
// TestPingAllByIPManyUpDown is a variant of the PingAll
// test which will take the tailscale node up and down
// five times ensuring they are able to restablish connectivity.
func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
// TODO(kradalby): it does not look like the user thing works, only second
// get created? maybe only when many?
spec := map[string]int{
"user1": len(MustTestVersions),
"user2": len(MustTestVersions),
}
headscaleConfig := map[string]string{
"HEADSCALE_DERP_URLS": "",
"HEADSCALE_DERP_SERVER_ENABLED": "true",
"HEADSCALE_DERP_SERVER_REGION_ID": "999",
"HEADSCALE_DERP_SERVER_REGION_CODE": "headscale",
"HEADSCALE_DERP_SERVER_REGION_NAME": "Headscale Embedded DERP",
"HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR": "0.0.0.0:3478",
"HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH": "/tmp/derp.key",
// Envknob for enabling DERP debug logs
"DERP_DEBUG_LOGS": "true",
"DERP_PROBER_DEBUG_LOGS": "true",
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{},
hsic.WithTestName("pingallbyip"),
hsic.WithConfigEnv(headscaleConfig),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
)
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
// assertClientsState(t, allClients)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
for run := range 3 {
t.Logf("Starting DownUpPing run %d", run+1)
for _, client := range allClients {
t.Logf("taking down %q", client.Hostname())
client.Down()
}
time.Sleep(5 * time.Second)
for _, client := range allClients {
t.Logf("bringing up %q", client.Hostname())
client.Up()
}
time.Sleep(5 * time.Second)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
}

View file

@ -212,7 +212,11 @@ func TestEnablingRoutes(t *testing.T) {
if route.GetId() == routeToBeDisabled.GetId() { if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, false, route.GetEnabled()) assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
// since this is the only route of this cidr,
// it will not failover, and remain Primary
// until something can replace it.
assert.Equal(t, true, route.GetIsPrimary())
} else { } else {
assert.Equal(t, true, route.GetEnabled()) assert.Equal(t, true, route.GetEnabled())
assert.Equal(t, true, route.GetIsPrimary()) assert.Equal(t, true, route.GetIsPrimary())
@ -291,6 +295,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
client := allClients[2] client := allClients[2]
t.Logf("Advertise route from r1 (%s) and r2 (%s), making it HA, n1 is primary", subRouter1.Hostname(), subRouter2.Hostname())
// advertise HA route on node 1 and 2 // advertise HA route on node 1 and 2
// ID 1 will be primary // ID 1 will be primary
// ID 2 will be secondary // ID 2 will be secondary
@ -384,12 +389,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled()) assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
// Node 2 is not primary // Node 2 is not primary
assert.Equal(t, true, enablingRoutes[1].GetAdvertised()) assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
assert.Equal(t, true, enablingRoutes[1].GetEnabled()) assert.Equal(t, true, enablingRoutes[1].GetEnabled())
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary()) assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
// Verify that the client has routes from the primary machine // Verify that the client has routes from the primary machine
srs1, err := subRouter1.Status() srs1, err := subRouter1.Status()
@ -401,6 +406,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey] srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
assertNotNil(t, srs1PeerStatus.PrimaryRoutes) assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -411,7 +419,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
) )
// Take down the current primary // Take down the current primary
t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("taking down subnet router r1 (%s)", subRouter1.Hostname())
t.Logf("expecting r2 (%s) to take over as primary", subRouter2.Hostname())
err = subRouter1.Down() err = subRouter1.Down()
assertNoErr(t, err) assertNoErr(t, err)
@ -435,15 +444,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterMove[0].GetAdvertised()) assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
assert.Equal(t, true, routesAfterMove[0].GetEnabled()) assert.Equal(t, true, routesAfterMove[0].GetEnabled())
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary()) assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfterMove[1].GetAdvertised()) assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
assert.Equal(t, true, routesAfterMove[1].GetEnabled()) assert.Equal(t, true, routesAfterMove[1].GetEnabled())
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary()) assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
// TODO(kradalby): Check client status
// Route is expected to be on SR2
srs2, err = subRouter2.Status() srs2, err = subRouter2.Status()
@ -453,6 +459,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.True(t, srs2PeerStatus.Online, "r1 down, r2 up")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes) assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
@ -465,7 +474,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
// Take down subnet router 2, leaving none available // Take down subnet router 2, leaving none available
t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname()) t.Logf("taking down subnet router r2 (%s)", subRouter2.Hostname())
t.Logf("expecting r2 (%s) to remain primary, no other available", subRouter2.Hostname())
err = subRouter2.Down() err = subRouter2.Down()
assertNoErr(t, err) assertNoErr(t, err)
@ -489,14 +499,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised()) assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled()) assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary()) assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
// Node 2 is primary // Node 2 is primary
// if the node goes down, but no other suitable route is // if the node goes down, but no other suitable route is
// available, keep the last known good route. // available, keep the last known good route.
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised()) assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled()) assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary()) assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
// TODO(kradalby): Check client status // TODO(kradalby): Check client status
// Both are expected to be down // Both are expected to be down
@ -508,6 +518,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.False(t, srs1PeerStatus.Online, "r1 down, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 down, r2 down")
assert.Nil(t, srs1PeerStatus.PrimaryRoutes) assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes) assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
@ -520,7 +533,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
// Bring up subnet router 1, making the route available from there. // Bring up subnet router 1, making the route available from there.
t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("bringing up subnet router r1 (%s)", subRouter1.Hostname())
t.Logf("expecting r1 (%s) to take over as primary (only one online)", subRouter1.Hostname())
err = subRouter1.Up() err = subRouter1.Up()
assertNoErr(t, err) assertNoErr(t, err)
@ -544,12 +558,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised()) assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[0].GetEnabled()) assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary()) assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
// Node 2 is not primary // Node 2 is not primary
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised()) assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[1].GetEnabled()) assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary()) assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -558,6 +572,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.True(t, srs1PeerStatus.Online, "r1 is back up, r2 down")
assert.False(t, srs2PeerStatus.Online, "r1 is back up, r2 down")
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -570,7 +587,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
// Bring up subnet router 2, should result in no change. // Bring up subnet router 2, should result in no change.
t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname()) t.Logf("bringing up subnet router r2 (%s)", subRouter2.Hostname())
t.Logf("both online, expecting r1 (%s) to still be primary (no flapping)", subRouter1.Hostname())
err = subRouter2.Up() err = subRouter2.Up()
assertNoErr(t, err) assertNoErr(t, err)
@ -594,12 +612,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised()) assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[0].GetEnabled()) assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary()) assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
// Node 2 is primary // Node 2 is primary
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised()) assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[1].GetEnabled()) assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary()) assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
// Verify that the route is announced from subnet router 1 // Verify that the route is announced from subnet router 1
clientStatus, err = client.Status() clientStatus, err = client.Status()
@ -608,6 +626,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.True(t, srs1PeerStatus.Online, "r1 up, r2 up")
assert.True(t, srs2PeerStatus.Online, "r1 up, r2 up")
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes) assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
@ -620,7 +641,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
// Disable the route of subnet router 1, making it failover to 2 // Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
t.Logf("expecting route to failover to r2 (%s), which is still available", subRouter2.Hostname())
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
@ -648,7 +670,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2) assert.Len(t, routesAfterDisabling1, 2)
t.Logf("routes after disabling1 %#v", routesAfterDisabling1) t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
// Node 1 is not primary // Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
@ -680,6 +702,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// enable the route of subnet router 1, no change expected // enable the route of subnet router 1, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname()) t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
@ -736,7 +759,8 @@ func TestHASubnetRouterFailover(t *testing.T) {
} }
// delete the route of subnet router 2, failover to one expected // delete the route of subnet router 2, failover to one expected
t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname()) t.Logf("deleting route in subnet router r2 (%s)", subRouter2.Hostname())
t.Logf("expecting route to failover to r1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute( _, err = headscale.Execute(
[]string{ []string{
"headscale", "headscale",
@ -764,7 +788,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err) assertNoErr(t, err)
assert.Len(t, routesAfterDeleting2, 1) assert.Len(t, routesAfterDeleting2, 1)
t.Logf("routes after deleting2 %#v", routesAfterDeleting2) t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
// Node 1 is primary // Node 1 is primary
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised()) assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())

View file

@ -50,6 +50,8 @@ var (
tailscaleVersions2021 = map[string]bool{ tailscaleVersions2021 = map[string]bool{
"head": true, "head": true,
"unstable": true, "unstable": true,
"1.60": true, // CapVer: 82
"1.58": true, // CapVer: 82
"1.56": true, // CapVer: 82 "1.56": true, // CapVer: 82
"1.54": true, // CapVer: 79 "1.54": true, // CapVer: 79
"1.52": true, // CapVer: 79 "1.52": true, // CapVer: 79

View file

@ -27,7 +27,7 @@ type TailscaleClient interface {
Down() error Down() error
IPs() ([]netip.Addr, error) IPs() ([]netip.Addr, error)
FQDN() (string, error) FQDN() (string, error)
Status() (*ipnstate.Status, error) Status(...bool) (*ipnstate.Status, error)
Netmap() (*netmap.NetworkMap, error) Netmap() (*netmap.NetworkMap, error)
Netcheck() (*netcheck.Report, error) Netcheck() (*netcheck.Report, error)
WaitForNeedsLogin() error WaitForNeedsLogin() error

View file

@ -9,6 +9,7 @@ import (
"log" "log"
"net/netip" "net/netip"
"net/url" "net/url"
"os"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -503,7 +504,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
} }
// Status returns the ipnstate.Status of the Tailscale instance. // Status returns the ipnstate.Status of the Tailscale instance.
func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) { func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
command := []string{ command := []string{
"tailscale", "tailscale",
"status", "status",
@ -521,60 +522,70 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err) return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
} }
err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755)
if err != nil {
return nil, fmt.Errorf("status netmap to /tmp/control: %w", err)
}
return &status, err return &status, err
} }
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. // Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
// Only works with Tailscale 1.56 and newer. // Only works with Tailscale 1.56 and newer.
// Panics if version is lower then minimum. // Panics if version is lower then minimum.
// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
// if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) {
// panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version)) panic(fmt.Sprintf("tsic.Netmap() called with unsupported version: %s", t.version))
// } }
// command := []string{ command := []string{
// "tailscale", "tailscale",
// "debug", "debug",
// "netmap", "netmap",
// } }
// result, stderr, err := t.Execute(command) result, stderr, err := t.Execute(command)
// if err != nil { if err != nil {
// fmt.Printf("stderr: %s\n", stderr) fmt.Printf("stderr: %s\n", stderr)
// return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err) return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err)
// } }
// var nm netmap.NetworkMap var nm netmap.NetworkMap
// err = json.Unmarshal([]byte(result), &nm) err = json.Unmarshal([]byte(result), &nm)
// if err != nil { if err != nil {
// return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err) return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
// } }
// return &nm, err err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755)
// } if err != nil {
return nil, fmt.Errorf("saving netmap to /tmp/control: %w", err)
}
return &nm, err
}
// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. // Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance.
// This implementation is based on getting the netmap from `tailscale debug watch-ipn` // This implementation is based on getting the netmap from `tailscale debug watch-ipn`
// as there seem to be some weirdness omitting endpoint and DERP info if we use // as there seem to be some weirdness omitting endpoint and DERP info if we use
// Patch updates. // Patch updates.
// This implementation works on all supported versions. // This implementation works on all supported versions.
func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { // func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
// watch-ipn will only give an update if something is happening, // // watch-ipn will only give an update if something is happening,
// since we send keep alives, the worst case for this should be // // since we send keep alives, the worst case for this should be
// 1 minute, but set a slightly more conservative time. // // 1 minute, but set a slightly more conservative time.
ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute) // ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute)
notify, err := t.watchIPN(ctx) // notify, err := t.watchIPN(ctx)
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
if notify.NetMap == nil { // if notify.NetMap == nil {
return nil, fmt.Errorf("no netmap present in ipn.Notify") // return nil, fmt.Errorf("no netmap present in ipn.Notify")
} // }
return notify.NetMap, nil // return notify.NetMap, nil
} // }
// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until // watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until
// it gets one that has a netmap.NetworkMap. // it gets one that has a netmap.NetworkMap.

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -154,11 +155,11 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) {
func assertValidNetmap(t *testing.T, client TailscaleClient) { func assertValidNetmap(t *testing.T, client TailscaleClient) {
t.Helper() t.Helper()
// if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) {
// t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version())
// return return
// } }
t.Logf("Checking netmap of %q", client.Hostname()) t.Logf("Checking netmap of %q", client.Hostname())
@ -175,7 +176,11 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) assert.NotEmptyf(t, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname())
assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) assert.NotEmptyf(t, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname())
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname()) if netmap.SelfNode.Online() != nil {
assert.Truef(t, *netmap.SelfNode.Online(), "%q is not online", client.Hostname())
} else {
t.Errorf("Online should not be nil for %s", client.Hostname())
}
assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) assert.Falsef(t, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname())
assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) assert.Falsef(t, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname())
@ -213,7 +218,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
// This test is not suitable for ACL/partial connection tests. // This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) { func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper() t.Helper()
status, err := client.Status() status, err := client.Status(true)
if err != nil { if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err) t.Fatalf("getting status for %q: %s", client.Hostname(), err)
} }