ensure online status and route changes are propagated (#1564)

Kristoffer Dalby 2023-12-09 18:09:24 +01:00 committed by GitHub
parent 0153e26392
commit f65f4eca35
40 changed files with 3170 additions and 857 deletions


@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestHASubnetRouterFailover
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestHASubnetRouterFailover:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestHASubnetRouterFailover
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestHASubnetRouterFailover$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"


@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestNodeOnlineLastSeenStatus
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestNodeOnlineLastSeenStatus:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestNodeOnlineLastSeenStatus
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestNodeOnlineLastSeenStatus$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

.gitignore

@ -1,5 +1,6 @@
ignored/
tailscale/
.vscode/
# Binaries for programs and plugins
*.exe


@ -27,6 +27,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
- The oldest supported client is 1.32
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
- If no DERP is configured, the server will fail to start; this can happen when it cannot load the DERPMap from a file or URL.
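For illustration, the new behaviour boils down to the guard added to (h *Headscale) Serve() later in this diff (a sketch of the same check, not additional behaviour):

    // After the DERPMap has been loaded from file/URL (and the embedded
    // DERP server, if enabled, has been merged in), refuse to start when
    // no regions are available.
    if len(h.DERPMap.Regions) == 0 {
        return errEmptyInitialDERPMap
    }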
### Changes


@ -1,47 +0,0 @@
package main
import (
"log"
"github.com/juanfont/headscale/integration"
"github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
)
func main() {
log.Printf("creating docker pool")
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("could not connect to docker: %s", err)
}
log.Printf("creating docker network")
network, err := pool.CreateNetwork("docker-integration-net")
if err != nil {
log.Fatalf("failed to create or get network: %s", err)
}
for _, version := range integration.AllVersions {
log.Printf("creating container image for Tailscale (%s)", version)
tsClient, err := tsic.New(
pool,
version,
network,
)
if err != nil {
log.Fatalf("failed to create tailscale node: %s", err)
}
err = tsClient.Shutdown()
if err != nil {
log.Fatalf("failed to shut down container: %s", err)
}
}
network.Close()
err = pool.RemoveNetwork(network)
if err != nil {
log.Fatalf("failed to remove network: %s", err)
}
}


@ -493,7 +493,7 @@ func nodesToPtables(
"Ephemeral",
"Last seen",
"Expiration",
"Online",
"Connected",
"Expired",
}
if showTags {


@ -31,7 +31,7 @@
# When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to those files.
vendorHash = "sha256-2ci6m1rKI3QdwbkqaGQlf0R+w4PhD0lkrLAu6wKj1LE=";
vendorHash = "sha256-7yqJbF0GkKa3wjiGWJ8BZSJyckrpwmCiX77/aoPGmRc=";
ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
};

go.mod

@ -35,6 +35,7 @@ require (
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20230824141953-6213f710f925
golang.org/x/crypto v0.16.0
golang.org/x/exp v0.0.0-20231127185646-65229373498e
golang.org/x/net v0.19.0
golang.org/x/oauth2 v0.15.0
golang.org/x/sync v0.5.0
@ -146,7 +147,6 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/term v0.15.0 // indirect


@ -59,6 +59,7 @@ var (
errUnsupportedLetsEncryptChallengeType = errors.New(
"unknown value for Lets Encrypt challenge type",
)
errEmptyInitialDERPMap = errors.New("initial DERPMap is empty, Headscale requires at least one entry")
)
const (
@ -193,7 +194,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
if derpServerKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("DERP server private key and noise private key are the same: %w", err)
return nil, fmt.Errorf(
"DERP server private key and noise private key are the same: %w",
err,
)
}
embeddedDERPServer, err := derpServer.NewDERPServer(
@ -259,21 +263,14 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region
}
h.nodeNotifier.NotifyAll(types.StateUpdate{
stateUpdate := types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: *h.DERPMap,
})
DERPMap: h.DERPMap,
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyAll(stateUpdate)
}
}
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
err := h.db.HandlePrimarySubnetFailover()
if err != nil {
log.Error().Err(err).Msg("failed to handle primary subnet failover")
}
}
}
@ -505,13 +502,15 @@ func (h *Headscale) Serve() error {
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
}
if len(h.DERPMap.Regions) == 0 {
return errEmptyInitialDERPMap
}
// TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown.
go h.expireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval)
go h.failoverSubnetRoutes(updateInterval)
if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
} else {


@ -16,6 +16,46 @@ import (
"tailscale.com/types/key"
)
func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("node", registerRequest.Hostinfo.Hostname).
Str("followup", registerRequest.Followup).
Time("expiry", registerRequest.Expiry).
Msg(msg)
},
func(msg string) {
log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("node", registerRequest.Hostinfo.Hostname).
Str("followup", registerRequest.Followup).
Time("expiry", registerRequest.Expiry).
Msg(msg)
},
func(err error, msg string) {
log.Error().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("node", registerRequest.Hostinfo.Hostname).
Str("followup", registerRequest.Followup).
Time("expiry", registerRequest.Expiry).
Err(err).
Msg(msg)
}
}
// handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister(
writer http.ResponseWriter,
@ -23,8 +63,11 @@ func (h *Headscale) handleRegister(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the node has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
@ -42,15 +85,9 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
log.Debug().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("Node is waiting for interactive login")
logTrace("Node is waiting for interactive login")
select {
case <-req.Context().Done():
@ -63,26 +100,14 @@ func (h *Headscale) handleRegister(
}
}
log.Info().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("New node not yet in the database")
logInfo("Node not found in database, creating new")
givenName, err := h.db.GenerateGivenName(
machineKey,
registerRequest.Hostinfo.Hostname,
)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed to generate given name for node")
logErr(err, "Failed to generate given name for node")
return
}
@ -101,11 +126,7 @@ func (h *Headscale) handleRegister(
}
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested")
logTrace("Non-zero expiry time requested")
newNode.Expiry = &registerRequest.Expiry
}
@ -419,13 +440,12 @@ func (h *Headscale) handleNewNode(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
resp := tailcfg.RegisterResponse{}
// The node registration is new, redirect the client to the registration URL
log.Debug().
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url")
logTrace("The node seems to be new, sending auth url")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf(
@ -441,10 +461,7 @@ func (h *Headscale) handleNewNode(
respBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
logErr(err, "Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
@ -454,17 +471,10 @@ func (h *Headscale) handleNewNode(
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
logErr(err, "Failed to write response")
}
log.Info().
Caller().
Str("AuthURL", resp.AuthURL).
Str("node", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url")
logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL))
}
func (h *Headscale) handleNodeLogOut(
@ -490,6 +500,19 @@ func (h *Headscale) handleNodeLogOut(
return
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &now,
},
},
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.NodeKeyExpired = true


@ -171,11 +171,13 @@ func NewHeadscaleDatabase(
dKey = "discokey:" + node.DiscoKey
}
err := db.db.Exec("UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
err := db.db.Exec(
"UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
sql.Named("mKey", mKey),
sql.Named("nKey", nKey),
sql.Named("dKey", dKey),
sql.Named("id", node.ID)).Error
sql.Named("id", node.ID),
).Error
if err != nil {
return nil, err
}


@ -61,11 +61,6 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
log.Trace().
Caller().
Str("node", node.Hostname).
Msgf("Found peers: %s", nodes.String())
return nodes, nil
}
@ -176,6 +171,12 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
return hsdb.getNodeByMachineKey(machineKey)
}
func (hsdb *HSDatabase) getNodeByMachineKey(
machineKey key.MachinePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
Preload("AuthKey").
@ -252,6 +253,10 @@ func (hsdb *HSDatabase) SetTags(
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
if len(tags) == 0 {
return nil
}
newTags := []string{}
for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) {
@ -265,10 +270,14 @@ func (hsdb *HSDatabase) SetTags(
return fmt.Errorf("failed to update tags for node in the database: %w", err)
}
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey.String())
ChangeNodes: types.Nodes{node},
Message: "called from db.SetTags",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
}
@ -301,10 +310,14 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
return fmt.Errorf("failed to rename node in the database: %w", err)
}
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey.String())
ChangeNodes: types.Nodes{node},
Message: "called from db.RenameNode",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil
}
@ -327,10 +340,18 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error
)
}
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey.String())
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &expiry,
},
},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
@ -354,10 +375,13 @@ func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
return err
}
hsdb.notifier.NotifyAll(types.StateUpdate{
stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
})
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
@ -629,20 +653,6 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
return false
}
func (hsdb *HSDatabase) ListOnlineNodes(
node *types.Node,
) (map[tailcfg.NodeID]bool, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
peers, err := hsdb.listPeers(node)
if err != nil {
return nil, err
}
return peers.OnlineNodeMap(), nil
}
// enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs))
@ -694,10 +704,30 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
}
}
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{
// Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := hsdb.getNodeRoutes(node)
if err != nil {
return fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
log.Trace().
Caller().
Str("node", node.Hostname).
Strs("routes", routeStrs).
Msg("enabling routes")
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
}, node.MachineKey.String())
ChangeNodes: types.Nodes{node},
Message: "called from db.enableRoutes",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(
stateUpdate, node.MachineKey.String())
}
return nil
}
@ -728,7 +758,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil
}
func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) {
func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
@ -823,33 +856,34 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
// checked everything.
started := time.Now()
users, err := hsdb.listUsers()
if err != nil {
log.Error().Err(err).Msg("Error listing users")
expired := make([]*tailcfg.PeerChange, 0)
return time.Unix(0, 0)
}
for _, user := range users {
nodes, err := hsdb.listNodesByUser(user.Name)
nodes, err := hsdb.listNodes()
if err != nil {
log.Error().
Err(err).
Str("user", user.Name).
Msg("Error listing nodes in user")
Msg("Error listing nodes to find expired nodes")
return time.Unix(0, 0)
}
expired := make([]tailcfg.NodeID, 0)
for index, node := range nodes {
if node.IsExpired() &&
// TODO(kradalby): Replace this, it is very spammy
// It will notify about all nodes that have expired.
// It should only notify about expired nodes since _last check_.
node.Expiry.After(lastCheck) {
expired = append(expired, tailcfg.NodeID(node.ID))
expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
now := time.Now()
err := hsdb.nodeSetExpiry(nodes[index], now)
if err != nil {
// Do not use setNodeExpiry as that has a notifier hook, which
// can cause a deadlock; we are updating all changed nodes later
// and there is no point in notifying twice.
if err := hsdb.db.Model(nodes[index]).Updates(types.Node{
Expiry: &now,
}).Error; err != nil {
log.Error().
Err(err).
Str("node", node.Hostname).
@ -864,12 +898,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
}
}
if len(expired) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: expired,
})
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: expired,
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return started


@ -603,8 +603,9 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
db.db.Save(&node)
err = db.SaveNodeRoutes(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil)


@ -7,7 +7,9 @@ import (
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"gorm.io/gorm"
"tailscale.com/types/key"
)
var ErrRouteIsNotAvailable = errors.New("route is not available")
@ -21,7 +23,38 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
var routes types.Routes
err := hsdb.db.Preload("Node").Find(&routes).Error
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
Find(&routes).Error
if err != nil {
return nil, err
}
@ -40,6 +73,7 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID).
Find(&routes).Error
if err != nil {
@ -60,6 +94,7 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("node_id = ?", node.ID).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -78,7 +113,10 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
var route types.Route
err := hsdb.db.Preload("Node").First(&route, id).Error
err := hsdb.db.
Preload("Node").
Preload("Node.User").
First(&route, id).Error
if err != nil {
return nil, err
}
@ -122,21 +160,26 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
return err
}
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() {
err = hsdb.failoverRouteWithNotify(route)
if err != nil {
return err
}
route.Enabled = false
route.IsPrimary = false
err = hsdb.db.Save(route).Error
if err != nil {
return err
}
return hsdb.handlePrimarySubnetFailover()
}
routes, err := hsdb.getNodeRoutes(&route.Node)
} else {
routes, err = hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
@ -151,8 +194,27 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
}
}
}
}
return hsdb.handlePrimarySubnetFailover()
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
Message: "called from db.DisableRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
@ -164,18 +226,23 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err
}
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() {
err := hsdb.failoverRouteWithNotify(route)
if err != nil {
return nil
}
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
}
return hsdb.handlePrimarySubnetFailover()
}
routes, err := hsdb.getNodeRoutes(&route.Node)
} else {
routes, err := hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
@ -190,8 +257,27 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err
}
}
return hsdb.handlePrimarySubnetFailover()
if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
Message: "called from db.DeleteRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
}
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
@ -204,9 +290,13 @@ func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
return err
}
// TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes need to be failed over rather than all.
hsdb.failoverRouteWithNotify(&routes[i])
}
return hsdb.handlePrimarySubnetFailover()
return nil
}
// isUniquePrefix returns if there is another node providing the same route already.
@ -259,18 +349,22 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
// SaveNodeRoutes takes a node and updates the database with
// the new routes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) error {
// It returns a bool indicating whether an update should be sent, as the
// saved route impacts nodes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.saveNodeRoutes(node)
}
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
sendUpdate := false
currentRoutes := types.Routes{}
err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil {
return err
return sendUpdate, err
}
advertisedRoutes := map[netip.Prefix]bool{}
@ -290,7 +384,14 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
currentRoutes[pos].Advertised = true
err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
return sendUpdate, err
}
// If a route that is newly "saved" is already
// enabled, set sendUpdate to true as it is now
// available.
if route.Enabled {
sendUpdate = true
}
}
advertisedRoutes[netip.Prefix(route.Prefix)] = true
@ -299,7 +400,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
currentRoutes[pos].Enabled = false
err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
return sendUpdate, err
}
}
}
@ -313,142 +414,224 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
Enabled: false,
}
err := hsdb.db.Create(&route).Error
if err != nil {
return sendUpdate, err
}
}
}
return sendUpdate, nil
}
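A minimal sketch of the intended caller pattern (the poll/registration handler itself is not part of this excerpt, so the surrounding handler, h.db and h.nodeNotifier are assumptions; the StateUpdate fields mirror those used elsewhere in this change):

    // Sketch only: persist the routes reported by the node and notify peers
    // only when an already-enabled route has become available.
    sendUpdate, err := h.db.SaveNodeRoutes(node)
    if err != nil {
        return err
    }
    if sendUpdate {
        stateUpdate := types.StateUpdate{
            Type:        types.StatePeerChanged,
            ChangeNodes: types.Nodes{node},
            Message:     "illustrative: node routes saved and already enabled",
        }
        if stateUpdate.Valid() {
            h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
        }
    }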
// EnsureFailoverRouteIsAvailable takes a node and checks that its routes
// currently have a functioning host that exposes the network.
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
nodeRoutes, err := hsdb.getNodeRoutes(node)
if err != nil {
return nil
}
for _, nodeRoute := range nodeRoutes {
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
if err != nil {
return err
}
for _, route := range routes {
if route.IsPrimary {
// if we have a primary route, and the node is connected
// nothing needs to be done.
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
continue
}
// if not, we need to failover the route
err := hsdb.failoverRouteWithNotify(&route)
if err != nil {
return err
}
}
}
}
return nil
}
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
hsdb.mu.Lock()
defer hsdb.mu.Unlock()
return hsdb.handlePrimarySubnetFailover()
}
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
// first, get all the enabled routes
var routes types.Routes
err := hsdb.db.
Preload("Node").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error getting routes")
}
changedNodes := make(types.Nodes, 0)
for pos, route := range routes {
if route.IsExitRoute() {
continue
}
node := &route.Node
if !route.IsPrimary {
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("node", route.Node.GivenName).
Msg("Setting primary route")
routes[pos].IsPrimary = true
err := hsdb.db.Save(&routes[pos]).Error
func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
routes, err := hsdb.getNodeRoutes(node)
if err != nil {
log.Error().Err(err).Msg("error marking route as primary")
return nil
}
var changedKeys []key.MachinePublic
for _, route := range routes {
changed, err := hsdb.failoverRoute(&route)
if err != nil {
return err
}
changedNodes = append(changedNodes, node)
continue
}
changedKeys = append(changedKeys, changed...)
}
if route.IsPrimary {
if route.Node.IsOnline() {
continue
}
changedKeys = lo.Uniq(changedKeys)
// node offline, find a new primary
log.Info().
Str("node", route.Node.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Msgf("node offline, finding a new primary subnet")
// find a new primary route
var newPrimaryRoutes types.Routes
err := hsdb.db.
Preload("Node").
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.NodeID,
true, true).
Find(&newPrimaryRoutes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error finding new primary route")
var nodes types.Nodes
for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
if err != nil {
return err
}
var newPrimaryRoute *types.Route
for pos, r := range newPrimaryRoutes {
if r.Node.IsOnline() {
newPrimaryRoute = &newPrimaryRoutes[pos]
nodes = append(nodes, node)
}
if nodes != nil {
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.FailoverNodeRoutesWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
return nil
}
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
changedKeys, err := hsdb.failoverRoute(r)
if err != nil {
return err
}
if len(changedKeys) == 0 {
return nil
}
var nodes types.Nodes
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("loading machines with new primary routes from db")
for _, key := range changedKeys {
node, err := hsdb.getNodeByMachineKey(key)
if err != nil {
return err
}
nodes = append(nodes, node)
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notifying peers about primary route change")
if nodes != nil {
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.failoverRouteWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notified peers about primary route change")
return nil
}
// failoverRoute takes a route that is no longer available,
// this can be either from:
// - being disabled
// - being deleted
// - host going offline
//
// and tries to find a new route to take over its place.
// If the given route was not primary, it returns early.
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
if r == nil {
return nil, nil
}
// This route is not a primary route, and it isn't
// being served to nodes.
if !r.IsPrimary {
return nil, nil
}
// We do not have to failover exit nodes
if r.IsExitRoute() {
return nil, nil
}
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
if err != nil {
return nil, err
}
var newPrimary *types.Route
// Find a new suitable route
for idx, route := range routes {
if r.ID == route.ID {
continue
}
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
newPrimary = &routes[idx]
break
}
}
if newPrimaryRoute == nil {
log.Warn().
Str("node", route.Node.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Msgf("no alternative primary route found")
continue
// If a new route was not found/available,
// return without making changes.
// We do not want to update the database, as
// the one currently marked as primary is the
// best we have.
if newPrimary == nil {
return nil, nil
}
log.Info().
Str("old_node", route.Node.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("new_node", newPrimaryRoute.Node.Hostname).
Msgf("found new primary route")
log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("found new primary, updating db")
// disable the old primary route
routes[pos].IsPrimary = false
err = hsdb.db.Save(&routes[pos]).Error
// Remove primary from the old route
r.IsPrimary = false
err = hsdb.db.Save(&r).Error
if err != nil {
log.Error().Err(err).Msg("error disabling old primary route")
log.Error().Err(err).Msg("error disabling new primary route")
return err
return nil, err
}
// enable the new primary route
newPrimaryRoute.IsPrimary = true
err = hsdb.db.Save(&newPrimaryRoute).Error
log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("removed primary from old route")
// Set primary for the new primary
newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
return err
return nil, err
}
changedNodes = append(changedNodes, node)
}
}
log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("set primary to new route")
if len(changedNodes) > 0 {
hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: changedNodes,
})
}
return nil
// Return a list of the machinekeys of the changed nodes.
return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.


@ -2,12 +2,19 @@ package db
import (
"net/netip"
"os"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetRoutes(c *check.C) {
@ -37,8 +44,9 @@ func (s *Suite) TestGetRoutes(c *check.C) {
}
db.db.Save(&node)
err = db.SaveNodeRoutes(&node)
su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(su, check.Equals, false)
advertisedRoutes, err := db.GetAdvertisedRoutes(&node)
c.Assert(err, check.IsNil)
@ -85,8 +93,9 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
}
db.db.Save(&node)
err = db.SaveNodeRoutes(&node)
sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
availableRoutes, err := db.GetAdvertisedRoutes(&node)
c.Assert(err, check.IsNil)
@ -156,8 +165,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
}
db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil)
@ -178,8 +188,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
}
db.db.Save(&node2)
err = db.SaveNodeRoutes(&node2)
sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil)
@ -201,142 +212,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
c.Assert(len(routes), check.Equals, 0)
}
func (s *Suite) TestSubnetFailover(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
prefix2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
now := time.Now()
node1 := types.Node{
ID: 1,
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
route, err := db.getPrimaryRoute(prefix)
c.Assert(err, check.IsNil)
c.Assert(route.NodeID, check.Equals, node1.ID)
hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix2},
}
node2 := types.Node{
ID: 2,
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2,
LastSeen: &now,
}
db.db.Save(&node2)
err = db.saveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node2, prefix2.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err = db.GetEnabledRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := db.GetEnabledRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
// lets make node1 lastseen 10 mins ago
before := now.Add(-10 * time.Minute)
node1.LastSeen = &before
err = db.db.Save(&node1).Error
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
node2.Hostinfo = &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
err = db.db.Save(&node2).Error
c.Assert(err, check.IsNil)
err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node2, prefix.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
}
func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
@ -373,8 +248,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
}
db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1)
sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
@ -392,3 +268,362 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
}
func TestFailoverRoute(t *testing.T) {
ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Count/verify updates
var sink chan types.StateUpdate
go func() {
for range sink {
}
}()
machineKeys := []key.MachinePublic{
key.NewMachine().Public(),
key.NewMachine().Public(),
key.NewMachine().Public(),
key.NewMachine().Public(),
}
tests := []struct {
name string
failingRoute types.Route
routes types.Routes
want []key.MachinePublic
wantErr bool
}{
{
name: "no-route",
failingRoute: types.Route{},
routes: types.Routes{},
want: nil,
wantErr: false,
},
{
name: "no-prime",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: false,
},
routes: types.Routes{},
want: nil,
wantErr: false,
},
{
name: "exit-node",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("0.0.0.0/0"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
routes: types.Routes{},
want: nil,
wantErr: false,
},
{
name: "no-failover-single-route",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
},
want: nil,
wantErr: false,
},
{
name: "failover-primary",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: false,
},
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
},
wantErr: false,
},
{
name: "failover-none-primary",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: false,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: false,
},
},
want: nil,
wantErr: false,
},
{
name: "failover-primary-multi-route",
failingRoute: types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: false,
},
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: true,
},
types.Route{
Model: gorm.Model{
ID: 3,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[2],
},
IsPrimary: false,
},
},
want: []key.MachinePublic{
machineKeys[1],
machineKeys[0],
},
wantErr: false,
},
{
name: "failover-primary-no-online",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
// Offline
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[3],
},
IsPrimary: false,
},
},
want: nil,
wantErr: false,
},
{
name: "failover-primary-one-not-online",
failingRoute: types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
routes: types.Routes{
types.Route{
Model: gorm.Model{
ID: 1,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[0],
},
IsPrimary: true,
},
// Offline
types.Route{
Model: gorm.Model{
ID: 2,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[3],
},
IsPrimary: false,
},
types.Route{
Model: gorm.Model{
ID: 3,
},
Prefix: ipp("10.0.0.0/24"),
Node: types.Node{
MachineKey: machineKeys[1],
},
IsPrimary: true,
},
},
want: []key.MachinePublic{
machineKeys[0],
machineKeys[1],
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "failover-db-test")
assert.NoError(t, err)
notif := notifier.NewNotifier()
db, err = NewHeadscaleDatabase(
"sqlite3",
tmpDir+"/headscale_test.db",
false,
notif,
[]netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
"",
)
assert.NoError(t, err)
// Pretend that all the nodes are connected to control
for idx, key := range machineKeys {
// Pretend one node is offline
if idx == 3 {
continue
}
notif.AddNode(key, sink)
}
for _, route := range tt.routes {
if err := db.db.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err)
}
}
got, err := db.failoverRoute(&tt.failingRoute)
if (err != nil) != tt.wantErr {
t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff)
}
})
}
}


@ -18,6 +18,7 @@ import (
"tailscale.com/net/stun"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
// fastStartHeader is the header (with value "1") that signals to the HTTP
@ -33,13 +34,19 @@ type DERPServer struct {
tailscaleDERP *derp.Server
}
func derpLogf() logger.Logf {
return func(format string, args ...any) {
log.Debug().Caller().Msgf(format, args...)
}
}
func NewDERPServer(
serverURL string,
derpKey key.NodePrivate,
cfg *types.DERPConfig,
) (*DERPServer, error) {
log.Trace().Caller().Msg("Creating new embedded DERP server")
server := derp.NewServer(derpKey, log.Debug().Msgf) // nolint // zerolinter complains
server := derp.NewServer(derpKey, derpLogf()) // nolint // zerolinter complains
return &DERPServer{
serverURL: serverURL,


@ -204,7 +204,13 @@ func (api headscaleV1APIServer) GetNode(
return nil, err
}
return &v1.GetNodeResponse{Node: node.Proto()}, nil
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
return &v1.GetNodeResponse{Node: resp}, nil
}
func (api headscaleV1APIServer) SetTags(
@ -333,7 +339,13 @@ func (api headscaleV1APIServer) ListNodes(
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
response[index] = node.Proto()
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
response[index] = resp
}
return &v1.ListNodesResponse{Nodes: response}, nil
@ -346,13 +358,18 @@ func (api headscaleV1APIServer) ListNodes(
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
m := node.Proto()
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
&node,
)
m.InvalidTags = invalidTags
m.ValidTags = validTags
response[index] = m
resp.InvalidTags = invalidTags
resp.ValidTags = validTags
response[index] = resp
}
return &v1.ListNodesResponse{Nodes: response}, nil


@ -8,6 +8,7 @@ import (
"net/url"
"os"
"path"
"slices"
"sort"
"strings"
"sync"
@ -21,6 +22,7 @@ import (
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"golang.org/x/exp/maps"
"tailscale.com/envknob"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
@ -45,6 +47,7 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Keep information about the previous mapresponse so we can send a diff
// - Store hashes
// - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct {
// Configuration
@ -63,6 +66,12 @@ type Mapper struct {
// only one func is accessing it over time.
mu sync.Mutex
peers map[uint64]*types.Node
patches map[uint64][]patch
}
type patch struct {
timestamp time.Time
change *tailcfg.PeerChange
}
func NewMapper(
@ -94,6 +103,7 @@ func NewMapper(
// TODO: populate
peers: peers.IDMap(),
patches: make(map[uint64][]patch),
}
}
@ -235,6 +245,19 @@ func (m *Mapper) FullMapResponse(
m.mu.Lock()
defer m.mu.Unlock()
peers := maps.Keys(m.peers)
peersWithPatches := maps.Keys(m.patches)
slices.Sort(peers)
slices.Sort(peersWithPatches)
if len(peersWithPatches) > 0 {
log.Debug().
Str("node", node.Hostname).
Uints64("peers", peers).
Uints64("pending_patches", peersWithPatches).
Msgf("node requested full map response, but has pending patches")
}
resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
if err != nil {
return nil, err
@ -272,10 +295,12 @@ func (m *Mapper) KeepAliveResponse(
func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
derpMap tailcfg.DERPMap,
derpMap *tailcfg.DERPMap,
) ([]byte, error) {
m.derpMap = derpMap
resp := m.baseMapResponse()
resp.DERPMap = &derpMap
resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
@ -285,18 +310,29 @@ func (m *Mapper) PeerChangedResponse(
node *types.Node,
changed types.Nodes,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
lastSeen := make(map[tailcfg.NodeID]bool)
// Update our internal map.
for _, node := range changed {
m.peers[node.ID] = node
if patches, ok := m.patches[node.ID]; ok {
// preserve online status in case the patch has an outdated one
online := node.IsOnline
// We have just seen the node, let the peers update their list.
lastSeen[tailcfg.NodeID(node.ID)] = true
for _, p := range patches {
// TODO(kradalby): Figure if this needs to be sorted by timestamp
node.ApplyPeerChange(p.change)
}
// Ensure the patches are not applied again later
delete(m.patches, node.ID)
node.IsOnline = online
}
m.peers[node.ID] = node
}
resp := m.baseMapResponse()
@ -316,11 +352,55 @@ func (m *Mapper) PeerChangedResponse(
return nil, err
}
// resp.PeerSeenChange = lastSeen
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
sendUpdate := false
// patch the internal map
for _, change := range changed {
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
peer.ApplyPeerChange(change)
sendUpdate = true
} else {
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
p := patch{
timestamp: time.Now(),
change: change,
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
patches := append(patches, p)
m.patches[uint64(change.NodeID)] = patches
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
}
}
if !sendUpdate {
return nil, nil
}
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
// TODO(kradalby): We need some integration tests for this.
func (m *Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
@ -329,13 +409,23 @@ func (m *Mapper) PeerRemovedResponse(
m.mu.Lock()
defer m.mu.Unlock()
// Some nodes might have been removed already,
// so we don't want to ask downstream to remove
// them twice; that can cause a panic in tailscaled.
notYetRemoved := []tailcfg.NodeID{}
// remove from our internal map
for _, id := range removed {
if _, ok := m.peers[uint64(id)]; ok {
notYetRemoved = append(notYetRemoved, id)
}
delete(m.peers, uint64(id))
delete(m.patches, uint64(id))
}
resp := m.baseMapResponse()
resp.PeersRemoved = removed
resp.PeersRemoved = notYetRemoved
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
@ -345,6 +435,7 @@ func (m *Mapper) marshalMapResponse(
resp *tailcfg.MapResponse,
node *types.Node,
compression string,
messages ...string,
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
@ -358,11 +449,25 @@ func (m *Mapper) marshalMapResponse(
if debugDumpMapResponsePath != "" {
data := map[string]interface{}{
"Messages": messages,
"MapRequest": mapRequest,
"MapResponse": resp,
}
body, err := json.Marshal(data)
responseType := "keepalive"
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
log.Error().
Caller().
@ -381,7 +486,7 @@ func (m *Mapper) marshalMapResponse(
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%d-%s-%d.json", now, m.uid, atomic.LoadUint64(&m.seq)),
fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -438,6 +543,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
// TODO(kradalby): Implement PingRequest?
}
return resp
@ -559,8 +665,5 @@ func appendPeerChanges(
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
// TODO(kradalby): This currently does not take last seen in keepalives into account
resp.OnlineChange = peers.OnlineNodeMap()
return nil
}


@ -237,7 +237,6 @@ func Test_fullMapResponse(t *testing.T) {
Tags: []string{},
PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{
tailcfg.CapabilityFileSharing,
@ -293,7 +292,6 @@ func Test_fullMapResponse(t *testing.T) {
Tags: []string{},
PrimaryRoutes: []netip.Prefix{},
LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{
tailcfg.CapabilityFileSharing,
@ -400,7 +398,6 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
@ -442,10 +439,6 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{},
Domain: "",
CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{
tailPeer1.ID: false,
tailcfg.NodeID(peer2.ID): false,
},
PacketFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},


@ -87,11 +87,9 @@ func tailNode(
hostname, err := node.GetFQDN(dnsConfig, baseDomain)
if err != nil {
return nil, err
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
}
online := node.IsOnline()
tags, _ := pol.TagsOfNode(node)
tags = lo.Uniq(append(tags, node.ForcedTags...))
@ -101,6 +99,7 @@ func tailNode(
strconv.FormatUint(node.ID, util.Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID),
@ -116,13 +115,14 @@ func tailNode(
Hostinfo: node.Hostinfo.View(),
Created: node.CreatedAt,
Online: node.IsOnline,
Tags: tags,
PrimaryRoutes: primaryPrefixes,
LastSeen: node.LastSeen,
Online: &online,
MachineAuthorized: !node.IsExpired(),
Expired: node.IsExpired(),
}
// - 74: 2023-09-18: Client understands NodeCapMap
@ -153,5 +153,11 @@ func tailNode(
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrDisableUPnP)
}
if node.IsOnline == nil || !*node.IsOnline {
// LastSeen is only set when node is
// not connected to the control server.
tNode.LastSeen = node.LastSeen
}
return &tNode, nil
}


@ -68,7 +68,6 @@ func TestTailNode(t *testing.T) {
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
PrimaryRoutes: []netip.Prefix{},
Online: new(bool),
MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{
"https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin",
@ -165,7 +164,6 @@ func TestTailNode(t *testing.T) {
},
LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{

View file

@ -1,6 +1,8 @@
package notifier
import (
"fmt"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
@ -56,6 +58,19 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
Msg("Removed channel")
}
// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
n.l.RLock()
defer n.l.RUnlock()
if _, ok := n.nodes[machineKey.String()]; ok {
return true
}
return false
}
func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update)
}
@ -79,3 +94,16 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
c <- update
}
}
func (n *Notifier) String() string {
n.l.RLock()
defer n.l.RUnlock()
str := []string{"Notifier, in map:\n"}
for k, v := range n.nodes {
str = append(str, fmt.Sprintf("\t%s: %v\n", k, v))
}
return strings.Join(str, "")
}
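A minimal sketch of how a caller might combine the new IsConnected helper with a node's in-memory IsOnline field, mirroring the pattern the poll handler uses before building a map response; the markPeersOnlineStatus helper is illustrative and not part of this change:

package example

import (
	"github.com/juanfont/headscale/hscontrol/notifier"
	"github.com/juanfont/headscale/hscontrol/types"
)

// markPeersOnlineStatus stamps each peer with whether it currently has an
// open poll session against this headscale instance.
func markPeersOnlineStatus(n *notifier.Notifier, peers types.Nodes) {
	for _, peer := range peers {
		online := n.IsConnected(peer.MachineKey)
		peer.IsOnline = &online
	}
}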

View file

@ -14,29 +14,8 @@ import (
"go4.org/netipx"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
var ipComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
var mkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String()
})
var nkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool {
return x.String() == y.String()
})
var dkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String()
})
var keyComparers []cmp.Option = []cmp.Option{
mkeyComparer, nkeyComparer, dkeyComparer,
}
func Test(t *testing.T) {
check.TestingT(t)
}
@ -969,7 +948,7 @@ func Test_listNodesInUser(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user)
if diff := cmp.Diff(test.want, got, keyComparers...); diff != "" {
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
}
})
@ -1733,7 +1712,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.nodes,
test.args.user,
)
if diff := cmp.Diff(test.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff)
}
})
@ -2085,10 +2064,6 @@ func Test_getTags(t *testing.T) {
}
func Test_getFilteredByACLPeers(t *testing.T) {
ipComparer := cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
type args struct {
nodes types.Nodes
rules []tailcfg.FilterRule
@ -2752,7 +2727,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
tt.args.nodes,
tt.args.rules,
)
if diff := cmp.Diff(tt.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" {
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
}
})

View file

@ -9,6 +9,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
xslices "golang.org/x/exp/slices"
"tailscale.com/tailcfg"
)
@ -61,7 +62,7 @@ func (h *Headscale) handlePoll(
) {
logInfo, logErr := logPollFunc(mapRequest, node)
// This is the mechanism where the node gives us inforamtion about its
// This is the mechanism where the node gives us information about its
// current configuration.
//
// If OmitPeers is true, Stream is false, and ReadOnly is false,
@ -69,6 +70,7 @@ func (h *Headscale) handlePoll(
// breaking existing long-polling (Stream == true) connections.
// In this case, the server can omit the entire response; the client
// only checks the HTTP response status code.
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly {
log.Info().
Caller().
@ -78,14 +80,85 @@ func (h *Headscale) handlePoll(
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)).
Msg("Received endpoint update")
Msg("Received update")
change := node.PeerChangeFromMapRequest(mapRequest)
online := h.nodeNotifier.IsConnected(node.MachineKey)
change.Online = &online
node.ApplyPeerChange(&change)
hostInfoChange := node.Hostinfo.Equal(mapRequest.Hostinfo)
logTracePeerChange(node.Hostname, hostInfoChange, &change)
// Check if the Hostinfo of the node has changed.
// If it has changed, check if there has been a change to
// the routable IPs of the host and update them in
// the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the route change.
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if !hostInfoChange {
oldRoutes := node.Hostinfo.RoutableIPs
newRoutes := mapRequest.Hostinfo.RoutableIPs
oldServicesCount := len(node.Hostinfo.Services)
newServicesCount := len(mapRequest.Hostinfo.Services)
now := time.Now().UTC()
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.Endpoints = mapRequest.Endpoints
sendUpdate := false
// Route changes come as part of Hostinfo, which means that
// when an update comes, the Node Route logic needs to run.
// This will require a "change" in comparison to a "patch",
// which is more costly.
if !xslices.Equal(oldRoutes, newRoutes) {
var err error
sendUpdate, err = h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
http.Error(writer, "", http.StatusInternalServerError)
return
}
}
// Services is mostly useful for discovery and not critical,
// except for peerapi, which is how nodes talk to each other.
// If peerapi was not part of the initial mapresponse, we
// need to make sure it's sent out later as it is needed for
// Taildrop.
// TODO(kradalby): Length comparison is a bit naive, replace.
if oldServicesCount != newServicesCount {
sendUpdate = true
}
if sendUpdate {
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from handlePoll -> update -> new hostinfo",
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(
stateUpdate,
node.MachineKey.String())
}
return
}
}
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
@ -94,20 +167,15 @@ func (h *Headscale) handlePoll(
return
}
err := h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
http.Error(writer, "", http.StatusInternalServerError)
return
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{&change},
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
},
stateUpdate,
node.MachineKey.String())
}
writer.WriteHeader(http.StatusOK)
if f, ok := writer.(http.Flusher); ok {
@ -115,7 +183,7 @@ func (h *Headscale) handlePoll(
}
return
} else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly {
// ReadOnly is whether the client just wants to fetch the
// MapResponse, without updating their Endpoints. The
// Endpoints field will be ignored and LastSeen will not be
@ -133,12 +201,39 @@ func (h *Headscale) handlePoll(
return
}
now := time.Now().UTC()
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
change := node.PeerChangeFromMapRequest(mapRequest)
// A stream is being set up, the node is Online
online := true
change.Online = &online
node.ApplyPeerChange(&change)
// Only save HostInfo if changed, update routes if changed
// TODO(kradalby): Remove when capver is over 68
if !node.Hostinfo.Equal(mapRequest.Hostinfo) {
oldRoutes := node.Hostinfo.RoutableIPs
newRoutes := mapRequest.Hostinfo.RoutableIPs
node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.Endpoints = mapRequest.Endpoints
if !xslices.Equal(oldRoutes, newRoutes) {
_, err := h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
http.Error(writer, "", http.StatusInternalServerError)
return
}
}
}
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
// When a node connects to control, list the peers it has at
// that given point, further updates are kept in memory in
@ -152,6 +247,11 @@ func (h *Headscale) handlePoll(
return
}
for _, peer := range peers {
online := h.nodeNotifier.IsConnected(peer.MachineKey)
peer.IsOnline = &online
}
mapp := mapper.NewMapper(
node,
peers,
@ -162,11 +262,6 @@ func (h *Headscale) handlePoll(
h.cfg.RandomizeClientPort,
)
err = h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
}
// update ACLRules with peer information (to update server tags if necessary)
if h.ACLPolicy != nil {
// update routes with peer information
@ -176,14 +271,6 @@ func (h *Headscale) handlePoll(
}
}
// TODO(kradalby): Save specific stuff, not whole object.
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
logInfo("Sending initial map")
mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
@ -208,18 +295,26 @@ func (h *Headscale) handlePoll(
return
}
h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
},
ChangeNodes: types.Nodes{node},
Message: "called from handlePoll -> new node added",
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(
stateUpdate,
node.MachineKey.String())
}
// Set up the client stream
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
updateChan := make(chan types.StateUpdate)
// Use a buffered channel in case a node is not fully ready
// to receive a message, so that we don't block the entire
// notifier.
// 12 is arbitrarily chosen.
updateChan := make(chan types.StateUpdate, 12)
defer closeChanWithLog(updateChan, node.Hostname, "updateChan")
// Register the node's update channel
@ -233,6 +328,10 @@ func (h *Headscale) handlePoll(
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if len(node.Routes) > 0 {
go h.db.EnsureFailoverRouteIsAvailable(node)
}
for {
logInfo("Waiting for update on stream channel")
select {
@ -262,14 +361,7 @@ func (h *Headscale) handlePoll(
// One alternative is to split these different channels into
// goroutines, but then you might have a problem without a lock
// if a keepalive is written at the same time as an update.
go func() {
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
return
}
}()
go h.updateNodeOnlineStatus(true, node)
case update := <-updateChan:
logInfo("Received update")
@ -279,18 +371,35 @@ func (h *Headscale) handlePoll(
var err error
switch update.Type {
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
case types.StatePeerChanged:
logInfo("Sending PeerChanged MapResponse")
data, err = mapp.PeerChangedResponse(mapRequest, node, update.Changed, h.ACLPolicy)
logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
for _, node := range update.ChangeNodes {
// If a node is not reported to be online, it might be
// because the value is outdated, so check with the notifier.
// However, if it is set to Online, and not in the notifier,
// this might be because it has announced itself, but not
// reached the stage to actually create the notifier channel.
if node.IsOnline != nil && !*node.IsOnline {
isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
node.IsOnline = &isOnline
}
}
data, err = mapp.PeerChangedResponse(mapRequest, node, update.ChangeNodes, h.ACLPolicy, update.Message)
case types.StatePeerChangedPatch:
logInfo("Sending PeerChangedPatch MapResponse")
data, err = mapp.PeerChangedPatchResponse(mapRequest, node, update.ChangePatches, h.ACLPolicy)
case types.StatePeerRemoved:
logInfo("Sending PeerRemoved MapResponse")
data, err = mapp.PeerRemovedResponse(mapRequest, node, update.Removed)
case types.StateDERPUpdated:
logInfo("Sending DERPUpdate MapResponse")
data, err = mapp.DERPMapResponse(mapRequest, node, update.DERPMap)
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
}
if err != nil {
@ -299,6 +408,8 @@ func (h *Headscale) handlePoll(
return
}
// Only send an update if there is a change
if data != nil {
_, err = writer.Write(data)
if err != nil {
logErr(err, "Could not write the map response")
@ -317,36 +428,25 @@ func (h *Headscale) handlePoll(
return
}
// See comment in keepAliveTicker
go func() {
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
return
}
}()
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey.ShortString()).
Str("machine_key", node.MachineKey.ShortString()).
Str("node", node.Hostname).
TimeDiff("timeSpent", time.Now(), now).
Msg("update sent")
}
case <-ctx.Done():
logInfo("The client has closed the connection")
go func() {
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
go h.updateNodeOnlineStatus(false, node)
return
}
}()
// Failover the node's routes if any.
go h.db.FailoverNodeRoutesWithNotify(node)
// The connection has been closed, so we can stop polling.
return
@ -359,6 +459,36 @@ func (h *Headscale) handlePoll(
}
}
// updateNodeOnlineStatus records the last seen status of a node and notifies peers
// about changes to their online/offline status.
// It takes a bool reporting whether the node is online and sends a StatePeerChangedPatch update to its peers.
func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
now := time.Now()
node.LastSeen = &now
statusUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
Online: &online,
LastSeen: &now,
},
},
}
if statusUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String())
}
err := h.db.UpdateLastSeen(node)
if err != nil {
log.Error().Err(err).Msg("Cannot update node LastSeen")
return
}
}
func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) {
log.Trace().
Str("handler", "PollNetMap").
@ -378,8 +508,6 @@ func (h *Headscale) handleLiteRequest(
mapp := mapper.NewMapper(
node,
// TODO(kradalby): It might not be acceptable to send
// an empty peer list here.
types.Nodes{},
h.DERPMap,
h.cfg.BaseDomain,
@ -405,3 +533,38 @@ func (h *Headscale) handleLiteRequest(
logErr(err, "Failed to write response")
}
}
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
trace := log.Trace().Str("node_id", change.NodeID.String()).Str("hostname", hostname)
if change.Key != nil {
trace = trace.Str("node_key", change.Key.ShortString())
}
if change.DiscoKey != nil {
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
}
if change.Online != nil {
trace = trace.Bool("online", *change.Online)
}
if change.Endpoints != nil {
eps := make([]string, len(change.Endpoints))
for idx, ep := range change.Endpoints {
eps[idx] = ep.String()
}
trace = trace.Strs("endpoints", eps)
}
if hostinfoChange {
trace = trace.Bool("hostinfo_changed", hostinfoChange)
}
if change.DERPRegion != 0 {
trace = trace.Int("derp_region", change.DERPRegion)
}
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
}

View file

@ -84,20 +84,31 @@ type StateUpdateType int
const (
StateFullUpdate StateUpdateType = iota
// StatePeerChanged is used for updates that need
// to be calculated with all peers and all policy rules.
// This would typically be things that include tags, routes
// and similar.
StatePeerChanged
StatePeerChangedPatch
StatePeerRemoved
StateDERPUpdated
)
// StateUpdate is an internal message containing information about
// a state change that has happened to the network.
// If type is StateFullUpdate, all fields are ignored.
type StateUpdate struct {
// The type of update
Type StateUpdateType
// Changed must be set when Type is StatePeerChanged and
// contain the Node IDs of nodes that have changed.
Changed Nodes
// ChangeNodes must be set when Type is StatePeerChanged
// and contains the full node objects for the nodes
// that have changed.
ChangeNodes Nodes
// ChangePatches must be set when Type is StatePeerChangedPatch
// and contains a populated PeerChange object.
ChangePatches []*tailcfg.PeerChange
// Removed must be set when Type is StatePeerRemoved and
// contain a list of the nodes that have been removed from
@ -106,5 +117,36 @@ type StateUpdate struct {
// DERPMap must be set when Type is StateDERPUpdated and
// contain the new DERP Map.
DERPMap tailcfg.DERPMap
DERPMap *tailcfg.DERPMap
// Additional message for tracking origin or what is being
// updated, useful for ambiguous updates like StatePeerChanged.
Message string
}
// Valid reports if a StateUpdate is correctly filled and
// panics if the mandatory fields for a type are not
// filled.
// Reports true if valid.
func (su *StateUpdate) Valid() bool {
switch su.Type {
case StatePeerChanged:
if su.ChangeNodes == nil {
panic("Mandatory field ChangeNodes is not set on StatePeerChanged update")
}
case StatePeerChangedPatch:
if su.ChangePatches == nil {
panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update")
}
case StatePeerRemoved:
if su.Removed == nil {
panic("Mandatory field Removed is not set on StatePeerRemove update")
}
case StateDERPUpdated:
if su.DERPMap == nil {
panic("Mandatory field DERPMap is not set on StateDERPUpdated update")
}
}
return true
}
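A minimal sketch of constructing one of the update variants and validating it before fanning it out; the notifyOnlinePatch helper and the hard-coded node ID are illustrative only:

package example

import (
	"github.com/juanfont/headscale/hscontrol/notifier"
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

// notifyOnlinePatch sends a lightweight patch telling peers that a node
// has come online, without shipping the full node object.
func notifyOnlinePatch(n *notifier.Notifier, nodeID tailcfg.NodeID) {
	online := true
	update := types.StateUpdate{
		Type: types.StatePeerChangedPatch,
		ChangePatches: []*tailcfg.PeerChange{
			{
				NodeID: nodeID,
				Online: &online,
			},
		},
	}
	// Valid panics if the mandatory field for the chosen type is missing.
	if update.Valid() {
		n.NotifyAll(update)
	}
}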

View file

@ -21,7 +21,9 @@ import (
var (
ErrNodeAddressesInvalid = errors.New("failed to parse node addresses")
ErrHostnameTooLong = errors.New("hostname too long")
ErrHostnameTooLong = errors.New("hostname too long, cannot exceed 255 ASCII chars")
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
)
// Node is a Headscale client.
@ -95,22 +97,14 @@ type Node struct {
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
IsOnline *bool `gorm:"-"`
}
type (
Nodes []*Node
)
func (nodes Nodes) OnlineNodeMap() map[tailcfg.NodeID]bool {
ret := make(map[tailcfg.NodeID]bool)
for _, node := range nodes {
ret[tailcfg.NodeID(node.ID)] = node.IsOnline()
}
return ret
}
type NodeAddresses []netip.Addr
func (na NodeAddresses) Sort() {
@ -206,21 +200,6 @@ func (node Node) IsExpired() bool {
return time.Now().UTC().After(*node.Expiry)
}
// IsOnline returns if the node is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
func (node *Node) IsOnline() bool {
if node.LastSeen == nil {
return false
}
if node.IsExpired() {
return false
}
return node.LastSeen.After(time.Now().Add(-KeepAliveInterval))
}
// IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (node *Node) IsEphemeral() bool {
@ -339,7 +318,6 @@ func (node *Node) Proto() *v1.Node {
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
Online: node.IsOnline(),
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
@ -365,6 +343,14 @@ func (node *Node) Proto() *v1.Node {
func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) {
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf(
"%s.%s.%s",
node.GivenName,
@ -373,7 +359,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
)
if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf(
"hostname %q is too long it cannot except 255 ASCII chars: %w",
"failed to create valid FQDN (%s): %w",
hostname,
ErrHostnameTooLong,
)
@ -385,8 +371,98 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
return hostname, nil
}
func (node Node) String() string {
return node.Hostname
// func (node *Node) String() string {
// return node.Hostname
// }
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to update the node and
// inform peers about smaller changes to the node.
// When a field is added to this function, remember to also add it to:
// - node.ApplyPeerChange
// - logTracePeerChange in poll.go
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
ret := tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
}
if node.NodeKey.String() != req.NodeKey.String() {
ret.Key = &req.NodeKey
}
if node.DiscoKey.String() != req.DiscoKey.String() {
ret.DiscoKey = &req.DiscoKey
}
if node.Hostinfo != nil &&
node.Hostinfo.NetInfo != nil &&
req.Hostinfo != nil &&
req.Hostinfo.NetInfo != nil &&
node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP {
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
}
if req.Hostinfo != nil && req.Hostinfo.NetInfo != nil {
// If there is no stored Hostinfo or NetInfo, use
// the new PreferredDERP.
if node.Hostinfo == nil {
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
} else if node.Hostinfo.NetInfo == nil {
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
} else {
// If there is a PreferredDERP check if it has changed.
if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP {
ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP
}
}
}
// TODO(kradalby): Find a good way to compare updates
ret.Endpoints = req.Endpoints
now := time.Now()
ret.LastSeen = &now
return ret
}
// ApplyPeerChange takes a PeerChange struct and updates the node.
func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) {
if change.Key != nil {
node.NodeKey = *change.Key
}
if change.DiscoKey != nil {
node.DiscoKey = *change.DiscoKey
}
if change.Online != nil {
node.IsOnline = change.Online
}
if change.Endpoints != nil {
node.Endpoints = change.Endpoints
}
// This might technically not be useful as we replace
// the whole hostinfo blob when it has changed.
if change.DERPRegion != 0 {
if node.Hostinfo == nil {
node.Hostinfo = &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
},
}
} else if node.Hostinfo.NetInfo == nil {
node.Hostinfo.NetInfo = &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
}
} else {
node.Hostinfo.NetInfo.PreferredDERP = change.DERPRegion
}
}
node.LastSeen = change.LastSeen
}
func (nodes Nodes) String() string {

View file
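A minimal sketch of the round trip the poll handler performs with the two helpers above: compute the delta from an incoming MapRequest, stamp the online flag, then fold the delta back into the stored node; the applyMapRequest helper is illustrative only:

package example

import (
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
)

// applyMapRequest updates a stored node from a MapRequest using the
// PeerChange helpers and returns the patch that can be sent to peers.
func applyMapRequest(node *types.Node, req tailcfg.MapRequest, online bool) tailcfg.PeerChange {
	change := node.PeerChangeFromMapRequest(req)
	change.Online = &online
	node.ApplyPeerChange(&change)

	return change
}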

@ -4,7 +4,10 @@ import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func Test_NodeCanAccess(t *testing.T) {
@ -139,3 +142,227 @@ func TestNodeAddressesOrder(t *testing.T) {
}
}
}
func TestNodeFQDN(t *testing.T) {
tests := []struct {
name string
node Node
dns tailcfg.DNSConfig
domain string
want string
wantErr string
}{
{
name: "all-set",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
dns: tailcfg.DNSConfig{
Proxied: true,
},
domain: "example.com",
want: "test.user.example.com",
},
{
name: "no-given-name",
node: Node{
User: User{
Name: "user",
},
},
dns: tailcfg.DNSConfig{
Proxied: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node has no given name",
},
{
name: "no-user-name",
node: Node{
GivenName: "test",
User: User{},
},
dns: tailcfg.DNSConfig{
Proxied: true,
},
domain: "example.com",
wantErr: "failed to create valid FQDN: node user has no name",
},
{
name: "no-magic-dns",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
dns: tailcfg.DNSConfig{
Proxied: false,
},
domain: "example.com",
want: "test",
},
{
name: "no-dnsconfig",
node: Node{
GivenName: "test",
User: User{
Name: "user",
},
},
domain: "example.com",
want: "test",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.node.GetFQDN(&tc.dns, tc.domain)
if (err != nil) && (err.Error() != tc.wantErr) {
t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr)
return
}
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("GetFQDN unexpected result (-want +got):\n%s", diff)
}
})
}
}
func TestPeerChangeFromMapRequest(t *testing.T) {
nKeys := []key.NodePublic{
key.NewNode().Public(),
key.NewNode().Public(),
key.NewNode().Public(),
}
dKeys := []key.DiscoPublic{
key.NewDisco().Public(),
key.NewDisco().Public(),
key.NewDisco().Public(),
}
tests := []struct {
name string
node Node
mapReq tailcfg.MapRequest
want tailcfg.PeerChange
}{
{
name: "preferred-derp-changed",
node: Node{
ID: 1,
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Endpoints: []netip.AddrPort{},
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 998,
},
},
},
mapReq: tailcfg.MapRequest{
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 999,
},
},
},
want: tailcfg.PeerChange{
NodeID: 1,
DERPRegion: 999,
},
},
{
name: "preferred-derp-no-changed",
node: Node{
ID: 1,
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Endpoints: []netip.AddrPort{},
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 100,
},
},
},
mapReq: tailcfg.MapRequest{
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 100,
},
},
},
want: tailcfg.PeerChange{
NodeID: 1,
DERPRegion: 0,
},
},
{
name: "preferred-derp-no-mapreq-netinfo",
node: Node{
ID: 1,
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Endpoints: []netip.AddrPort{},
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 200,
},
},
},
mapReq: tailcfg.MapRequest{
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Hostinfo: &tailcfg.Hostinfo{},
},
want: tailcfg.PeerChange{
NodeID: 1,
DERPRegion: 0,
},
},
{
name: "preferred-derp-no-node-netinfo",
node: Node{
ID: 1,
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Endpoints: []netip.AddrPort{},
Hostinfo: &tailcfg.Hostinfo{},
},
mapReq: tailcfg.MapRequest{
NodeKey: nKeys[0],
DiscoKey: dKeys[0],
Hostinfo: &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: 200,
},
},
},
want: tailcfg.PeerChange{
NodeID: 1,
DERPRegion: 200,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.node.PeerChangeFromMapRequest(tc.mapReq)
if diff := cmp.Diff(tc.want, got, cmpopts.IgnoreFields(tailcfg.PeerChange{}, "LastSeen")); diff != "" {
t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -19,6 +19,8 @@ type Route struct {
NodeID uint64
Node Node
// TODO(kradalby): change this custom type to netip.Prefix
Prefix IPPrefix
Advertised bool
@ -29,13 +31,17 @@ type Route struct {
type Routes []Route
func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Node, netip.Prefix(r.Prefix).String())
return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String())
}
func (r *Route) IsExitRoute() bool {
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
}
func (r *Route) IsAnnouncable() bool {
return r.Advertised && r.Enabled
}
func (rs Routes) Prefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs {
@ -45,6 +51,32 @@ func (rs Routes) Prefixes() []netip.Prefix {
return prefixes
}
// Primaries returns Primary routes from a list of routes.
func (rs Routes) Primaries() Routes {
res := make(Routes, 0)
for _, route := range rs {
if route.IsPrimary {
res = append(res, route)
}
}
return res
}
func (rs Routes) PrefixMap() map[IPPrefix][]Route {
res := map[IPPrefix][]Route{}
for _, route := range rs {
if _, ok := res[route.Prefix]; ok {
res[route.Prefix] = append(res[route.Prefix], route)
} else {
res[route.Prefix] = []Route{route}
}
}
return res
}
func (rs Routes) Proto() []*v1.Route {
protoRoutes := []*v1.Route{}

View file
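A minimal sketch of how the new PrefixMap and Primaries helpers can be combined to find the current primary route for each announced prefix; the primaryByPrefix helper and its choice of the first primary are illustrative only:

package example

import "github.com/juanfont/headscale/hscontrol/types"

// primaryByPrefix groups routes by prefix and keeps the one currently
// marked as primary for each prefix, if any.
func primaryByPrefix(rs types.Routes) map[types.IPPrefix]types.Route {
	res := map[types.IPPrefix]types.Route{}
	for prefix, routes := range rs.PrefixMap() {
		if primaries := types.Routes(routes).Primaries(); len(primaries) > 0 {
			res[prefix] = primaries[0]
		}
	}

	return res
}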

@ -0,0 +1,94 @@
package types
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/util"
)
func TestPrefixMap(t *testing.T) {
ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) }
// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool {
return x == y
})
tests := []struct {
rs Routes
want map[IPPrefix][]Route
}{
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
want: map[IPPrefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
},
},
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
Route{
Prefix: ipp("10.0.1.0/24"),
},
},
want: map[IPPrefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
},
},
ipp("10.0.1.0/24"): Routes{
Route{
Prefix: ipp("10.0.1.0/24"),
},
},
},
},
{
rs: Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: true,
},
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: false,
},
},
want: map[IPPrefix][]Route{
ipp("10.0.0.0/24"): Routes{
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: true,
},
Route{
Prefix: ipp("10.0.0.0/24"),
Enabled: false,
},
},
},
},
}
for idx, tt := range tests {
t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) {
got := tt.rs.PrefixMap()
if diff := cmp.Diff(tt.want, got, prefixComparer, util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" {
t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
}
})
}
}

32
hscontrol/util/test.go Normal file
View file

@ -0,0 +1,32 @@
package util
import (
"net/netip"
"github.com/google/go-cmp/cmp"
"tailscale.com/types/key"
)
var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool {
return x == y
})
var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String()
})
var NkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool {
return x.String() == y.String()
})
var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String()
})
var Comparers []cmp.Option = []cmp.Option{
IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer,
}
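A minimal sketch of using the shared comparer options in a table-driven test, as the updated policy tests do; the assertNodesEqual helper is illustrative only:

package example

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
)

// assertNodesEqual fails the test with a readable diff when two node lists
// differ, using the shared key and address comparers.
func assertNodesEqual(t *testing.T, want, got types.Nodes) {
	t.Helper()
	if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
		t.Errorf("unexpected nodes (-want +got):\n%s", diff)
	}
}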

View file

@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"sort"
"strconv"
"testing"
"time"
@ -22,7 +21,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul
err = json.Unmarshal([]byte(str), result)
if err != nil {
return err
return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str)
}
return nil
@ -178,7 +177,11 @@ func TestPreAuthKeyCommand(t *testing.T) {
assert.Equal(
t,
[]string{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()},
[]string{listedPreAuthKeys[1].GetId(), listedPreAuthKeys[2].GetId(), listedPreAuthKeys[3].GetId()},
[]string{
listedPreAuthKeys[1].GetId(),
listedPreAuthKeys[2].GetId(),
listedPreAuthKeys[3].GetId(),
},
)
assert.NotEmpty(t, listedPreAuthKeys[1].GetKey())
@ -384,141 +387,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
assert.Len(t, listedPreAuthKeys, 3)
}
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "enable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
// advertise routes using the up command
for i, client := range allClients {
routeStr := fmt.Sprintf("10.0.%d.0/24", i)
command := []string{
"tailscale",
"set",
"--advertise-routes=" + routeStr,
}
_, _, err := client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 3)
for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
}
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
}
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
routeIDToBeDisabled := enablingRoutes[0].GetId()
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
strconv.Itoa(int(routeIDToBeDisabled)),
})
assertNoErr(t, err)
var disablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&disablingRoutes,
)
assertNoErr(t, err)
for _, route := range disablingRoutes {
assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeIDToBeDisabled {
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
} else {
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
}
}
func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

View file

@ -44,6 +44,9 @@ func TestDERPServerScenario(t *testing.T) {
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP"
headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478"
headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key"
// Envknob for enabling DERP debug logs
headscaleConfig["DERP_DEBUG_LOGS"] = "true"
headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true"
err = scenario.CreateHeadscaleEnv(
spec,

View file

@ -14,6 +14,8 @@ import (
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/key"
)
func TestPingAllByIP(t *testing.T) {
@ -248,9 +250,8 @@ func TestPingAllByHostname(t *testing.T) {
defer scenario.Shutdown()
spec := map[string]int{
// Omit 1.16.2 (-1) because it does not have the FQDN field
"user3": len(MustTestVersions) - 1,
"user4": len(MustTestVersions) - 1,
"user3": len(MustTestVersions),
"user4": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname"))
@ -296,8 +297,7 @@ func TestTaildrop(t *testing.T) {
defer scenario.Shutdown()
spec := map[string]int{
// Omit 1.16.2 (-1) because it does not have the FQDN field
"taildrop": len(MustTestVersions) - 1,
"taildrop": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop"))
@ -313,6 +313,42 @@ func TestTaildrop(t *testing.T) {
_, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err)
for _, client := range allClients {
if !strings.Contains(client.Hostname(), "head") {
command := []string{"apk", "add", "curl"}
_, _, err := client.Execute(command)
if err != nil {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
}
}
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
err = retry(10, 1*time.Second, func() error {
result, _, err := client.Execute(curlCommand)
if err != nil {
return err
}
var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
if err != nil {
return err
}
if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
}
return err
})
if err != nil {
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
}
}
for _, client := range allClients {
command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())}
@ -347,8 +383,9 @@ func TestTaildrop(t *testing.T) {
})
if err != nil {
t.Fatalf(
"failed to send taildrop file on %s, err: %s",
"failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(),
strings.Join(command, " "),
err,
)
}
@ -517,10 +554,139 @@ func TestExpireNode(t *testing.T) {
err = json.Unmarshal([]byte(result), &node)
assertNoErr(t, err)
var expiredNodeKey key.NodePublic
err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey()))
assertNoErr(t, err)
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
time.Sleep(30 * time.Second)
// Verify that the expired node is no longer present in the Peer list
// of connected nodes.
now := time.Now()
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
// In addition to marking nodes expired, we filter them out during the map response.
// This check ensures that the node is either not present, or that it is expired
// if it is in the map response.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
assertNotNil(t, peerStatus.Expired)
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %s should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
assert.Truef(t, peerStatus.Expired, "node %s should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
}
// TODO(kradalby): We do not propagate expiry correctly, nodes should be aware
// of their status, and this should be sent directly to the node when its
// expired. This needs a notifier that goes directly to the node (currently we only do peers)
// so fix this in a follow up PR.
// } else {
// assert.True(t, status.Self.Expired)
}
}
}
func TestNodeOnlineLastSeenStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps))
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
keepAliveInterval := 60 * time.Second
// Duration is chosen arbitrarily, 10m is reported in #1561
testDuration := 12 * time.Minute
start := time.Now()
end := start.Add(testDuration)
log.Printf("Starting online test from %v to %v", start, end)
for {
// Let the test run continuously for X minutes to verify
// all nodes stay connected and have the expected status over time.
if end.Before(time.Now()) {
return
}
result, err := headscale.Execute([]string{
"headscale", "nodes", "list", "--output", "json",
})
assertNoErr(t, err)
var nodes []*v1.Node
err = json.Unmarshal([]byte(result), &nodes)
assertNoErr(t, err)
now := time.Now()
// Threshold with some leeway
lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second))
// Verify that headscale reports the nodes as online
for _, node := range nodes {
// All nodes should be online
assert.Truef(
t,
node.GetOnline(),
"expected %s to have online status in Headscale, marked as offline %s after start",
node.GetName(),
time.Since(start),
)
lastSeen := node.GetLastSeen().AsTime()
// All nodes should have been last seen between now and the keepAliveInterval
assert.Truef(
t,
lastSeen.After(lastSeenThreshold),
"lastSeen (%v) was not %s after the threshold (%v)",
lastSeen,
keepAliveInterval,
lastSeenThreshold,
)
}
// Verify that all nodes report all nodes to be online
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
@ -528,14 +694,36 @@ func TestExpireNode(t *testing.T) {
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
peerPublicKey := strings.TrimPrefix(peerStatus.PublicKey.String(), "nodekey:")
assert.NotEqual(t, node.GetNodeKey(), peerPublicKey)
// .Online is only available from CapVer 16, which
// is not present in 1.18 which is the lowest we
// test.
if strings.Contains(client.Hostname(), "1-18") {
continue
}
if client.Hostname() != node.GetName() {
// Assert that we have the original count - self - expired node
assert.Len(t, status.Peers(), len(MustTestVersions)-2)
// All peers of this node are reporting to be
// connected to the control server
assert.Truef(
t,
peerStatus.Online,
"expected node %s to be marked as online in %s peer list, marked as offline %s after start",
peerStatus.HostName,
client.Hostname(),
time.Since(start),
)
// from docs: last seen to tailcontrol; only present if offline
// assert.Nilf(
// t,
// peerStatus.LastSeen,
// "expected node %s to not have LastSeen set, got %s",
// peerStatus.HostName,
// peerStatus.LastSeen,
// )
}
}
// Check maximum once per second
time.Sleep(time.Second)
}
}

780
integration/route_test.go Normal file
View file

@ -0,0 +1,780 @@
package integration
import (
"fmt"
"log"
"net/netip"
"sort"
"strconv"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
)
// This test is both testing the routes command and the propagation of
// routes.
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "enable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.0.0.0/24",
"2": "10.0.1.0/24",
"3": "10.0.2.0/24",
}
// advertise routes using the up command
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
command := []string{
"tailscale",
"set",
"--advertise-routes=" + expectedRoutes[string(status.Self.ID)],
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 3)
for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
}
// Verify that no routes have been sent to the client,
// they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Nil(t, peerStatus.PrimaryRoutes)
}
}
// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
}
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
time.Sleep(5 * time.Second)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.NotNil(t, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes == nil {
continue
}
pRoutes := peerStatus.PrimaryRoutes.AsSlice()
assert.Len(t, pRoutes, 1)
if len(pRoutes) > 0 {
peerRoute := peerStatus.PrimaryRoutes.AsSlice()[0]
// id starts at 1, we created routes with 0 index
assert.Equalf(
t,
expectedRoutes[string(peerStatus.ID)],
peerRoute.String(),
"expected route %s to be present on peer %s (%s) in %s (%s) status",
expectedRoutes[string(peerStatus.ID)],
peerStatus.HostName,
peerStatus.ID,
client.Hostname(),
client.ID(),
)
}
}
}
routeToBeDisabled := enablingRoutes[0]
log.Printf("preparing to disable %v", routeToBeDisabled)
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
strconv.Itoa(int(routeToBeDisabled.GetId())),
})
assertNoErr(t, err)
var disablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&disablingRoutes,
)
assertNoErr(t, err)
for _, route := range disablingRoutes {
assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeToBeDisabled.GetId() {
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
} else {
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
}
time.Sleep(5 * time.Second)
// Verify that the clients can see the new routes
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
if string(peerStatus.ID) == fmt.Sprintf("%d", routeToBeDisabled.GetNode().GetId()) {
assert.Nilf(
t,
peerStatus.PrimaryRoutes,
"expected node %s to have no routes, got primary route (%v)",
peerStatus.HostName,
peerStatus.PrimaryRoutes,
)
}
}
}
}
func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "enable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.0.0.0/24",
"2": "10.0.0.0/24",
}
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}
statusJ, err := allClients[j].Status()
if err != nil {
return false
}
return statusI.Self.ID < statusJ.Self.ID
})
subRouter1 := allClients[0]
subRouter2 := allClients[1]
client := allClients[2]
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
command := []string{
"tailscale",
"set",
"--advertise-routes=" + route,
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 2)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes have been sent to the client,
// they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Nil(t, peerStatus.PrimaryRoutes)
}
}
// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
time.Sleep(time.Second)
}
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 2)
// Node 1 is primary
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
// Node 2 is not primary
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary())
// Verify that the client has routes from the primary machine
srs1, err := subRouter1.Status()
srs2, err := subRouter2.Status()
clientStatus, err := client.Status()
assertNoErr(t, err)
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
// Take down the current primary
t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname())
err = subRouter1.Down()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterMove []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterMove,
)
assertNoErr(t, err)
assert.Len(t, routesAfterMove, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
assert.Equal(t, true, routesAfterMove[0].GetEnabled())
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
assert.Equal(t, true, routesAfterMove[1].GetEnabled())
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary())
// TODO(kradalby): Check client status
// Route is expected to be on SR2
srs2, err = subRouter2.Status()
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Take down subnet router 2, leaving none available
t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname())
err = subRouter2.Down()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterBothDown []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterBothDown,
)
assertNoErr(t, err)
assert.Len(t, routesAfterBothDown, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary())
// Node 2 is primary
// if the node goes down, but no other suitable route is
// available, keep the last known good route.
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary())
// TODO(kradalby): Check client status
// Both are expected to be down
// Verify that the route is not presented from either router
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Bring up subnet router 1, making the route available from there.
t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname())
err = subRouter1.Up()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfter1Up []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfter1Up,
)
assertNoErr(t, err)
assert.Len(t, routesAfter1Up, 2)
// Node 1 is primary
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary())
// Node 2 is not primary
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
// Bring up subnet router 2, should result in no change.
t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname())
err = subRouter2.Up()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfter2Up []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfter2Up,
)
assertNoErr(t, err)
assert.Len(t, routesAfter2Up, 2)
// Node 1 is still primary
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary())
// Node 2 is still not primary
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
// Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterDisabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDisabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary())
// Verify that the route is now announced from subnet router 2
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Enable the route of subnet router 1 again, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterEnabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterEnabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterEnabling1, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary())
// Verify that the route is still announced from subnet router 2
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Delete the route of subnet router 2, expecting failover back to 1
t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"delete",
"--route",
fmt.Sprintf("%d", routesAfterEnabling1[1].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterDeleting2 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDeleting2,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDeleting2, 1)
t.Logf("routes after deleting2 %#v", routesAfterDeleting2)
// Node 1 is primary
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled())
assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
}
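The same primary-route check is repeated after every topology change in this test. A small helper along these lines would keep the pattern in one place; this is only a sketch, and the helper name and signature are not part of the change (it assumes the test file's existing imports of net/netip, testify/assert and tailscale.com/ipn/ipnstate):

func requirePrimaryRoute(t *testing.T, active, standby *ipnstate.PeerStatus, route string) {
	t.Helper()

	// The standby subnet router must not announce any primary routes.
	assert.Nil(t, standby.PrimaryRoutes)

	// The active subnet router must announce the expected prefix.
	if active.PrimaryRoutes == nil {
		t.Fatal("expected primary routes on the active subnet router, got nil")
	}
	assert.Contains(t, active.PrimaryRoutes.AsSlice(), netip.MustParsePrefix(route))
}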

View file

@ -29,7 +29,7 @@ run_tests() {
-failfast \
-timeout 120m \
-parallel 1 \
-run "^$test_name\$" >/dev/null 2>&1
-run "^$test_name\$" >./control_logs/"$test_name"_"$i".log 2>&1
status=$?
end=$(date +%s)

View file

@ -15,6 +15,7 @@ import (
"github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
"github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"
)
@ -93,7 +94,7 @@ var (
//
// - Two unstable (HEAD and unstable)
// - Two latest versions
// - Two oldest versions.
// - Two oldest supported versions.
MustTestVersions = append(
AllVersions[0:4],
AllVersions[len(AllVersions)-2:]...,
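For illustration only (the concrete version strings below are hypothetical), the append keeps the first four entries of AllVersions plus its last two:

// AllVersions      = []string{"head", "unstable", "1.56", "1.54", "1.52", "1.50", "1.48", "1.46"}
// MustTestVersions = []string{"head", "unstable", "1.56", "1.54", "1.48", "1.46"}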
@ -296,11 +297,13 @@ func (s *Scenario) CreateTailscaleNodesInUser(
opts ...tsic.Option,
) error {
if user, ok := s.users[userStr]; ok {
var versions []string
for i := 0; i < count; i++ {
version := requestedVersion
if requestedVersion == "all" {
version = MustTestVersions[i%len(MustTestVersions)]
}
versions = append(versions, version)
headscale, err := s.Headscale()
if err != nil {
@ -350,6 +353,8 @@ func (s *Scenario) CreateTailscaleNodesInUser(
return err
}
log.Printf("testing versions %v", lo.Uniq(versions))
return nil
}
@ -403,7 +408,17 @@ func (s *Scenario) CountTailscale() int {
func (s *Scenario) WaitForTailscaleSync() error {
tsCount := s.CountTailscale()
return s.WaitForTailscaleSyncWithPeerCount(tsCount - 1)
err := s.WaitForTailscaleSyncWithPeerCount(tsCount - 1)
if err != nil {
for _, user := range s.users {
for _, client := range user.Clients {
peers, _ := client.PrettyPeers()
log.Println(peers)
}
}
}
return err
}
// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports

View file

@ -109,7 +109,7 @@ func TestSSHOneUserToAll(t *testing.T) {
},
},
},
len(MustTestVersions)-2,
len(MustTestVersions),
)
defer scenario.Shutdown()
@ -174,7 +174,7 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
},
},
},
len(MustTestVersions)-2,
len(MustTestVersions),
)
defer scenario.Shutdown()
@ -220,7 +220,7 @@ func TestSSHNoSSHConfigured(t *testing.T) {
},
SSHs: []policy.SSH{},
},
len(MustTestVersions)-2,
len(MustTestVersions),
)
defer scenario.Shutdown()
@ -269,7 +269,7 @@ func TestSSHIsBlockedInACL(t *testing.T) {
},
},
},
len(MustTestVersions)-2,
len(MustTestVersions),
)
defer scenario.Shutdown()
@ -325,7 +325,7 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
},
},
},
len(MustTestVersions)-2,
len(MustTestVersions),
)
defer scenario.Shutdown()

View file

@ -21,6 +21,8 @@ type TailscaleClient interface {
Login(loginServer, authKey string) error
LoginWithURL(loginServer string) (*url.URL, error)
Logout() error
Up() error
Down() error
IPs() ([]netip.Addr, error)
FQDN() (string, error)
Status() (*ipnstate.Status, error)
@ -30,4 +32,5 @@ type TailscaleClient interface {
Ping(hostnameOrIP string, opts ...tsic.PingOption) error
Curl(url string, opts ...tsic.CurlOption) (string, error)
ID() string
PrettyPeers() (string, error)
}

View file

@ -285,6 +285,15 @@ func (t *TailscaleInContainer) hasTLS() bool {
// Shutdown stops and cleans up the Tailscale container.
func (t *TailscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
if err != nil {
log.Printf(
"Failed to save log from %s: %s",
t.hostname,
err,
)
}
return t.pool.Purge(t.container)
}
@ -417,6 +426,44 @@ func (t *TailscaleInContainer) Logout() error {
return nil
}
// Up runs `tailscale up` with no arguments.
func (t *TailscaleInContainer) Up() error {
command := []string{
"tailscale",
"up",
}
if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil {
return fmt.Errorf(
"%s failed to bring tailscale client up (%s): %w",
t.hostname,
strings.Join(command, " "),
err,
)
}
return nil
}
// Down runs `tailscale down` with no arguments.
func (t *TailscaleInContainer) Down() error {
command := []string{
"tailscale",
"down",
}
if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil {
return fmt.Errorf(
"%s failed to bring tailscale client down (%s): %w",
t.hostname,
strings.Join(command, " "),
err,
)
}
return nil
}
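Both helpers are used by the HA failover test to toggle a subnet router's connectivity. A typical call site, with variable names borrowed from that test and an illustrative settle time, looks like:

// Take the subnet router offline and give headscale time to fail over.
err = subRouter1.Down()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
// ... assert that the standby router now announces the route ...

// Bring the router back up before the next step of the test.
err = subRouter1.Up()
assertNoErr(t, err)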
// IPs returns the netip.Addr of the Tailscale instance.
func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
if t.ips != nil && len(t.ips) != 0 {
@ -486,6 +533,34 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
return status.Self.DNSName, nil
}
// PrettyPeers returns a tab-separated table of the client's peers with
// their online status and last-seen time.
func (t *TailscaleInContainer) PrettyPeers() (string, error) {
status, err := t.Status()
if err != nil {
return "", fmt.Errorf("failed to get FQDN: %w", err)
}
str := fmt.Sprintf("Peers of %s\n", t.hostname)
str += "Hostname\tOnline\tLastSeen\n"
peerCount := len(status.Peers())
onlineCount := 0
for _, peerKey := range status.Peers() {
peer := status.Peer[peerKey]
if peer.Online {
onlineCount++
}
str += fmt.Sprintf("%s\t%t\t%s\n", peer.HostName, peer.Online, peer.LastSeen)
}
str += fmt.Sprintf("Peer Count: %d, Online Count: %d\n\n", peerCount, onlineCount)
return str, nil
}
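With hypothetical peers, the returned string looks roughly like this (columns are tab-separated):

Peers of ts-subnet-router-1
Hostname	Online	LastSeen
ts-client-1	true	2023-12-09 17:02:31 +0000 UTC
ts-subnet-router-2	false	2023-12-09 16:45:10 +0000 UTC
Peer Count: 2, Online Count: 1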
// WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has
// started and needs to be logged into.
func (t *TailscaleInContainer) WaitForNeedsLogin() error {
@ -531,7 +606,7 @@ func (t *TailscaleInContainer) WaitForRunning() error {
}
// WaitForPeers blocks until N number of peers is present in the
// Peer list of the Tailscale instance.
// Peer list of the Tailscale instance and all of them report Online.
func (t *TailscaleInContainer) WaitForPeers(expected int) error {
return t.pool.Retry(func() error {
status, err := t.Status()
@ -547,6 +622,14 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
expected,
len(peers),
)
} else {
for _, peerKey := range peers {
peer := status.Peer[peerKey]
if !peer.Online {
return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
}
}
}
return nil
@ -738,3 +821,9 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *TailscaleInContainer) SaveLog(path string) error {
return dockertestutil.SaveLog(t.pool, t.container, path)
}

View file

@ -26,6 +26,13 @@ func assertNoErrf(t *testing.T, msg string, err error) {
}
}
func assertNotNil(t *testing.T, thing interface{}) {
t.Helper()
if thing == nil {
t.Fatal("got unexpected nil")
}
}
func assertNoErrHeadscaleEnv(t *testing.T, err error) {
t.Helper()
assertNoErrf(t, "failed to create headscale environment: %s", err)
@ -68,13 +75,13 @@ func assertContains(t *testing.T, str, subStr string) {
}
}
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
t.Helper()
success := 0
for _, client := range clients {
for _, addr := range addrs {
err := client.Ping(addr)
err := client.Ping(addr, opts...)
if err != nil {
t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else {