ensure online status and route changes are propagated (#1564)

This commit is contained in:
Kristoffer Dalby 2023-12-09 18:09:24 +01:00 committed by GitHub
parent 0153e26392
commit f65f4eca35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 3170 additions and 857 deletions

View file

@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestHASubnetRouterFailover
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestHASubnetRouterFailover:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestHASubnetRouterFailover
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestHASubnetRouterFailover$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

View file

@ -0,0 +1,67 @@
# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go
# To regenerate, run "go generate" in cmd/gh-action-integration-generator/
name: Integration Test v2 - TestNodeOnlineLastSeenStatus
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
TestNodeOnlineLastSeenStatus:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 2
- uses: DeterminateSystems/nix-installer-action@main
- uses: DeterminateSystems/magic-nix-cache-action@main
- uses: satackey/action-docker-layer-caching@main
continue-on-error: true
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v34
with:
files: |
*.nix
go.*
**/*.go
integration_test/
config-example.yaml
- name: Run TestNodeOnlineLastSeenStatus
uses: Wandalen/wretry.action@master
if: steps.changed-files.outputs.any_changed == 'true'
with:
attempt_limit: 5
command: |
nix develop --command -- docker run \
--tty --rm \
--volume ~/.cache/hs-integration-go:/go \
--name headscale-test-suite \
--volume $PWD:$PWD -w $PWD/integration \
--volume /var/run/docker.sock:/var/run/docker.sock \
--volume $PWD/control_logs:/tmp/control \
golang:1 \
go run gotest.tools/gotestsum@latest -- ./... \
-failfast \
-timeout 120m \
-parallel 1 \
-run "^TestNodeOnlineLastSeenStatus$"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: logs
path: "control_logs/*.log"
- uses: actions/upload-artifact@v3
if: always() && steps.changed-files.outputs.any_changed == 'true'
with:
name: pprof
path: "control_logs/*.pprof.tar"

1
.gitignore vendored
View file

@ -1,5 +1,6 @@
ignored/ ignored/
tailscale/ tailscale/
.vscode/
# Binaries for programs and plugins # Binaries for programs and plugins
*.exe *.exe

View file

@ -27,6 +27,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553)
- Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611)
- The latest supported client is 1.32 - The latest supported client is 1.32
- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564)
- If no DERP is configured, the server will fail to start. This can happen when it cannot load the DERPMap from a file or URL.
### Changes ### Changes

View file

@ -1,47 +0,0 @@
package main
import (
"log"
"github.com/juanfont/headscale/integration"
"github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
)
// main creates a Tailscale client container for every version listed in
// integration.AllVersions and shuts each one down again right away, so that
// the container image for each version gets created. A temporary docker
// network is set up for the containers and removed at the end. Any failure
// terminates the process through log.Fatalf.
func main() {
	log.Printf("creating docker pool")

	dockerPool, poolErr := dockertest.NewPool("")
	if poolErr != nil {
		log.Fatalf("could not connect to docker: %s", poolErr)
	}

	log.Printf("creating docker network")

	buildNet, netErr := dockerPool.CreateNetwork("docker-integration-net")
	if netErr != nil {
		log.Fatalf("failed to create or get network: %s", netErr)
	}

	for _, tsVersion := range integration.AllVersions {
		log.Printf("creating container image for Tailscale (%s)", tsVersion)

		client, createErr := tsic.New(dockerPool, tsVersion, buildNet)
		if createErr != nil {
			log.Fatalf("failed to create tailscale node: %s", createErr)
		}

		// The container is only needed long enough for the image to exist.
		if stopErr := client.Shutdown(); stopErr != nil {
			log.Fatalf("failed to shut down container: %s", stopErr)
		}
	}

	// The result of Close is not inspected; only the explicit removal
	// below is treated as fatal.
	buildNet.Close()

	if rmErr := dockerPool.RemoveNetwork(buildNet); rmErr != nil {
		log.Fatalf("failed to remove network: %s", rmErr)
	}
}

View file

@ -493,7 +493,7 @@ func nodesToPtables(
"Ephemeral", "Ephemeral",
"Last seen", "Last seen",
"Expiration", "Expiration",
"Online", "Connected",
"Expired", "Expired",
} }
if showTags { if showTags {

View file

@ -31,7 +31,7 @@
# When updating go.mod or go.sum, a new sha will need to be calculated, # When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to those files. # update this if you have a mismatch after doing a change to those files.
vendorHash = "sha256-2ci6m1rKI3QdwbkqaGQlf0R+w4PhD0lkrLAu6wKj1LE="; vendorHash = "sha256-7yqJbF0GkKa3wjiGWJ8BZSJyckrpwmCiX77/aoPGmRc=";
ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"];
}; };

2
go.mod
View file

@ -35,6 +35,7 @@ require (
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20230824141953-6213f710f925 go4.org/netipx v0.0.0-20230824141953-6213f710f925
golang.org/x/crypto v0.16.0 golang.org/x/crypto v0.16.0
golang.org/x/exp v0.0.0-20231127185646-65229373498e
golang.org/x/net v0.19.0 golang.org/x/net v0.19.0
golang.org/x/oauth2 v0.15.0 golang.org/x/oauth2 v0.15.0
golang.org/x/sync v0.5.0 golang.org/x/sync v0.5.0
@ -146,7 +147,6 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect
golang.org/x/mod v0.14.0 // indirect golang.org/x/mod v0.14.0 // indirect
golang.org/x/sys v0.15.0 // indirect golang.org/x/sys v0.15.0 // indirect
golang.org/x/term v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect

View file

@ -59,6 +59,7 @@ var (
errUnsupportedLetsEncryptChallengeType = errors.New( errUnsupportedLetsEncryptChallengeType = errors.New(
"unknown value for Lets Encrypt challenge type", "unknown value for Lets Encrypt challenge type",
) )
// errEmptyInitialDERPMap is returned when the DERPMap has no regions at
// startup; Headscale needs at least one DERP entry to serve clients.
errEmptyInitialDERPMap = errors.New("initial DERPMap is empty, Headscale requires at least one entry")
) )
const ( const (
@ -193,7 +194,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
} }
if derpServerKey.Equal(*noisePrivateKey) { if derpServerKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("DERP server private key and noise private key are the same: %w", err) return nil, fmt.Errorf(
"DERP server private key and noise private key are the same: %w",
err,
)
} }
embeddedDERPServer, err := derpServer.NewDERPServer( embeddedDERPServer, err := derpServer.NewDERPServer(
@ -259,21 +263,14 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
} }
h.nodeNotifier.NotifyAll(types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StateDERPUpdated, Type: types.StateDERPUpdated,
DERPMap: *h.DERPMap, DERPMap: h.DERPMap,
})
} }
if stateUpdate.Valid() {
h.nodeNotifier.NotifyAll(stateUpdate)
} }
} }
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
err := h.db.HandlePrimarySubnetFailover()
if err != nil {
log.Error().Err(err).Msg("failed to handle primary subnet failover")
}
} }
} }
@ -505,13 +502,15 @@ func (h *Headscale) Serve() error {
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
} }
if len(h.DERPMap.Regions) == 0 {
return errEmptyInitialDERPMap
}
// TODO(kradalby): These should have cancel channels and be cleaned // TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown. // up on shutdown.
go h.expireEphemeralNodes(updateInterval) go h.expireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval) go h.expireExpiredMachines(updateInterval)
go h.failoverSubnetRoutes(updateInterval)
if zl.GlobalLevel() == zl.TraceLevel { if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true zerolog.RespLog = true
} else { } else {

View file

@ -16,6 +16,46 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
) )
// logAuthFunc returns three logging closures bound to a single registration
// request, so that every message emitted while handling that request carries
// the same set of fields (machine key, current and old node key, hostname,
// followup URL and requested expiry).
//
// The closures are returned in this order:
//  1. an info-level logger taking the message,
//  2. a trace-level logger taking the message,
//  3. an error-level logger taking the error and the message.
func logAuthFunc(
	registerRequest tailcfg.RegisterRequest,
	machineKey key.MachinePublic,
) (func(string), func(string), func(error, string)) {
	infoLogger := func(msg string) {
		log.Info().
			Caller().
			Str("machine_key", machineKey.ShortString()).
			Str("node_key", registerRequest.NodeKey.ShortString()).
			Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
			Str("node", registerRequest.Hostinfo.Hostname).
			Str("followup", registerRequest.Followup).
			Time("expiry", registerRequest.Expiry).
			Msg(msg)
	}

	traceLogger := func(msg string) {
		log.Trace().
			Caller().
			Str("machine_key", machineKey.ShortString()).
			Str("node_key", registerRequest.NodeKey.ShortString()).
			Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
			Str("node", registerRequest.Hostinfo.Hostname).
			Str("followup", registerRequest.Followup).
			Time("expiry", registerRequest.Expiry).
			Msg(msg)
	}

	errorLogger := func(err error, msg string) {
		log.Error().
			Caller().
			Str("machine_key", machineKey.ShortString()).
			Str("node_key", registerRequest.NodeKey.ShortString()).
			Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
			Str("node", registerRequest.Hostinfo.Hostname).
			Str("followup", registerRequest.Followup).
			Time("expiry", registerRequest.Expiry).
			Err(err).
			Msg(msg)
	}

	return infoLogger, traceLogger, errorLogger
}
// handleRegister is the logic for registering a client. // handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister( func (h *Headscale) handleRegister(
writer http.ResponseWriter, writer http.ResponseWriter,
@ -23,8 +63,11 @@ func (h *Headscale) handleRegister(
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
now := time.Now().UTC() now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) node, err := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
logTrace("handleRegister database lookup has returned")
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// If the node has AuthKey set, handle registration via PreAuthKeys // If the node has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
@ -42,15 +85,9 @@ func (h *Headscale) handleRegister(
// is that the client will hammer headscale with requests until it gets a // is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse. // successful RegisterResponse.
if registerRequest.Followup != "" { if registerRequest.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok { if _, ok := h.registrationCache.Get(machineKey.String()); ok {
log.Debug(). logTrace("Node is waiting for interactive login")
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("Node is waiting for interactive login")
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
@ -63,26 +100,14 @@ func (h *Headscale) handleRegister(
} }
} }
log.Info(). logInfo("Node not found in database, creating new")
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Msg("New node not yet in the database")
givenName, err := h.db.GenerateGivenName( givenName, err := h.db.GenerateGivenName(
machineKey, machineKey,
registerRequest.Hostinfo.Hostname, registerRequest.Hostinfo.Hostname,
) )
if err != nil { if err != nil {
log.Error(). logErr(err, "Failed to generate given name for node")
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed to generate given name for node")
return return
} }
@ -101,11 +126,7 @@ func (h *Headscale) handleRegister(
} }
if !registerRequest.Expiry.IsZero() { if !registerRequest.Expiry.IsZero() {
log.Trace(). logTrace("Non-zero expiry time requested")
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested")
newNode.Expiry = &registerRequest.Expiry newNode.Expiry = &registerRequest.Expiry
} }
@ -419,13 +440,12 @@ func (h *Headscale) handleNewNode(
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The node registration is new, redirect the client to the registration URL // The node registration is new, redirect the client to the registration URL
log.Debug(). logTrace("The node seems to be new, sending auth url")
Caller().
Str("node", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url")
if h.oauth2Config != nil { if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
@ -441,10 +461,7 @@ func (h *Headscale) handleNewNode(
respBody, err := json.Marshal(resp) respBody, err := json.Marshal(resp)
if err != nil { if err != nil {
log.Error(). logErr(err, "Cannot encode message")
Caller().
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
@ -454,17 +471,10 @@ func (h *Headscale) handleNewNode(
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody) _, err = writer.Write(respBody)
if err != nil { if err != nil {
log.Error(). logErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
log.Info(). logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL))
Caller().
Str("AuthURL", resp.AuthURL).
Str("node", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url")
} }
func (h *Headscale) handleNodeLogOut( func (h *Headscale) handleNodeLogOut(
@ -490,6 +500,19 @@ func (h *Headscale) handleNodeLogOut(
return return
} }
stateUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &now,
},
},
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.NodeKeyExpired = true resp.NodeKeyExpired = true

View file

@ -171,11 +171,13 @@ func NewHeadscaleDatabase(
dKey = "discokey:" + node.DiscoKey dKey = "discokey:" + node.DiscoKey
} }
err := db.db.Exec("UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", err := db.db.Exec(
"UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id",
sql.Named("mKey", mKey), sql.Named("mKey", mKey),
sql.Named("nKey", nKey), sql.Named("nKey", nKey),
sql.Named("dKey", dKey), sql.Named("dKey", dKey),
sql.Named("id", node.ID)).Error sql.Named("id", node.ID),
).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -61,11 +61,6 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) {
sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
log.Trace().
Caller().
Str("node", node.Hostname).
Msgf("Found peers: %s", nodes.String())
return nodes, nil return nodes, nil
} }
@ -176,6 +171,12 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
hsdb.mu.RLock() hsdb.mu.RLock()
defer hsdb.mu.RUnlock() defer hsdb.mu.RUnlock()
return hsdb.getNodeByMachineKey(machineKey)
}
func (hsdb *HSDatabase) getNodeByMachineKey(
machineKey key.MachinePublic,
) (*types.Node, error) {
mach := types.Node{} mach := types.Node{}
if result := hsdb.db. if result := hsdb.db.
Preload("AuthKey"). Preload("AuthKey").
@ -252,6 +253,10 @@ func (hsdb *HSDatabase) SetTags(
hsdb.mu.Lock() hsdb.mu.Lock()
defer hsdb.mu.Unlock() defer hsdb.mu.Unlock()
if len(tags) == 0 {
return nil
}
newTags := []string{} newTags := []string{}
for _, tag := range tags { for _, tag := range tags {
if !util.StringOrPrefixListContains(newTags, tag) { if !util.StringOrPrefixListContains(newTags, tag) {
@ -265,10 +270,14 @@ func (hsdb *HSDatabase) SetTags(
return fmt.Errorf("failed to update tags for node in the database: %w", err) return fmt.Errorf("failed to update tags for node in the database: %w", err)
} }
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
Changed: types.Nodes{node}, ChangeNodes: types.Nodes{node},
}, node.MachineKey.String()) Message: "called from db.SetTags",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil return nil
} }
@ -301,10 +310,14 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
return fmt.Errorf("failed to rename node in the database: %w", err) return fmt.Errorf("failed to rename node in the database: %w", err)
} }
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
Changed: types.Nodes{node}, ChangeNodes: types.Nodes{node},
}, node.MachineKey.String()) Message: "called from db.RenameNode",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
}
return nil return nil
} }
@ -327,10 +340,18 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error
) )
} }
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChangedPatch,
Changed: types.Nodes{node}, ChangePatches: []*tailcfg.PeerChange{
}, node.MachineKey.String()) {
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: &expiry,
},
},
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil return nil
} }
@ -354,10 +375,13 @@ func (hsdb *HSDatabase) deleteNode(node *types.Node) error {
return err return err
} }
hsdb.notifier.NotifyAll(types.StateUpdate{ stateUpdate := types.StateUpdate{
Type: types.StatePeerRemoved, Type: types.StatePeerRemoved,
Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)},
}) }
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil return nil
} }
@ -629,20 +653,6 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool
return false return false
} }
func (hsdb *HSDatabase) ListOnlineNodes(
node *types.Node,
) (map[tailcfg.NodeID]bool, error) {
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()
peers, err := hsdb.listPeers(node)
if err != nil {
return nil, err
}
return peers.OnlineNodeMap(), nil
}
// enableRoutes enables new routes based on a list of new routes. // enableRoutes enables new routes based on a list of new routes.
func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs)) newRoutes := make([]netip.Prefix, len(routeStrs))
@ -694,10 +704,30 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro
} }
} }
hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ // Ensure the node has the latest routes when notifying the other
// nodes
nRoutes, err := hsdb.getNodeRoutes(node)
if err != nil {
return fmt.Errorf("failed to read back routes: %w", err)
}
node.Routes = nRoutes
log.Trace().
Caller().
Str("node", node.Hostname).
Strs("routes", routeStrs).
Msg("enabling routes")
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
Changed: types.Nodes{node}, ChangeNodes: types.Nodes{node},
}, node.MachineKey.String()) Message: "called from db.enableRoutes",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyWithIgnore(
stateUpdate, node.MachineKey.String())
}
return nil return nil
} }
@ -728,7 +758,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil return normalizedHostname, nil
} }
func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) { func (hsdb *HSDatabase) GenerateGivenName(
mkey key.MachinePublic,
suppliedName string,
) (string, error) {
hsdb.mu.RLock() hsdb.mu.RLock()
defer hsdb.mu.RUnlock() defer hsdb.mu.RUnlock()
@ -823,33 +856,34 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
// checked everything. // checked everything.
started := time.Now() started := time.Now()
users, err := hsdb.listUsers() expired := make([]*tailcfg.PeerChange, 0)
if err != nil {
log.Error().Err(err).Msg("Error listing users")
return time.Unix(0, 0) nodes, err := hsdb.listNodes()
}
for _, user := range users {
nodes, err := hsdb.listNodesByUser(user.Name)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("user", user.Name). Msg("Error listing nodes to find expired nodes")
Msg("Error listing nodes in user")
return time.Unix(0, 0) return time.Unix(0, 0)
} }
expired := make([]tailcfg.NodeID, 0)
for index, node := range nodes { for index, node := range nodes {
if node.IsExpired() && if node.IsExpired() &&
// TODO(kradalby): Replace this, it is very spammy.
// It will notify about all nodes that have been expired.
// It should only notify about nodes expired since the _last check_.
node.Expiry.After(lastCheck) { node.Expiry.After(lastCheck) {
expired = append(expired, tailcfg.NodeID(node.ID)) expired = append(expired, &tailcfg.PeerChange{
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
now := time.Now() now := time.Now()
err := hsdb.nodeSetExpiry(nodes[index], now) // Do not use setNodeExpiry as that has a notifier hook, which
if err != nil { // can cause a deadlock, we are updating all changed nodes later
// and there is no point in notifying twice.
if err := hsdb.db.Model(nodes[index]).Updates(types.Node{
Expiry: &now,
}).Error; err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("node", node.Hostname). Str("node", node.Hostname).
@ -864,12 +898,12 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time {
} }
} }
if len(expired) > 0 { stateUpdate := types.StateUpdate{
hsdb.notifier.NotifyAll(types.StateUpdate{ Type: types.StatePeerChangedPatch,
Type: types.StatePeerRemoved, ChangePatches: expired,
Removed: expired,
})
} }
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
} }
return started return started

View file

@ -603,8 +603,9 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
db.db.Save(&node) db.db.Save(&node)
err = db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
node0ByID, err := db.GetNodeByID(0) node0ByID, err := db.GetNodeByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -7,7 +7,9 @@ import (
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/types/key"
) )
var ErrRouteIsNotAvailable = errors.New("route is not available") var ErrRouteIsNotAvailable = errors.New("route is not available")
@ -21,7 +23,38 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { func (hsdb *HSDatabase) getRoutes() (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db.Preload("Node").Find(&routes).Error err := hsdb.db.
Preload("Node").
Preload("Node.User").
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) {
var routes types.Routes
err := hsdb.db.
Preload("Node").
Preload("Node.User").
Where("prefix = ?", types.IPPrefix(pref)).
Find(&routes).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -40,6 +73,7 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes,
var routes types.Routes var routes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Node"). Preload("Node").
Preload("Node.User").
Where("node_id = ? AND advertised = true", node.ID). Where("node_id = ? AND advertised = true", node.ID).
Find(&routes).Error Find(&routes).Error
if err != nil { if err != nil {
@ -60,6 +94,7 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) {
var routes types.Routes var routes types.Routes
err := hsdb.db. err := hsdb.db.
Preload("Node"). Preload("Node").
Preload("Node.User").
Where("node_id = ?", node.ID). Where("node_id = ?", node.ID).
Find(&routes).Error Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
@ -78,7 +113,10 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) {
var route types.Route var route types.Route
err := hsdb.db.Preload("Node").First(&route, id).Error err := hsdb.db.
Preload("Node").
Preload("Node.User").
First(&route, id).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -122,21 +160,26 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
return err return err
} }
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() { if !route.IsExitRoute() {
err = hsdb.failoverRouteWithNotify(route)
if err != nil {
return err
}
route.Enabled = false route.Enabled = false
route.IsPrimary = false route.IsPrimary = false
err = hsdb.db.Save(route).Error err = hsdb.db.Save(route).Error
if err != nil { if err != nil {
return err return err
} }
} else {
return hsdb.handlePrimarySubnetFailover() routes, err = hsdb.getNodeRoutes(&node)
}
routes, err := hsdb.getNodeRoutes(&route.Node)
if err != nil { if err != nil {
return err return err
} }
@ -151,8 +194,27 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
} }
} }
} }
}
return hsdb.handlePrimarySubnetFailover() if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
Message: "called from db.DisableRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
} }
func (hsdb *HSDatabase) DeleteRoute(id uint64) error { func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
@ -164,18 +226,23 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
return err return err
} }
var routes types.Routes
node := route.Node
// Tailscale requires both IPv4 and IPv6 exit routes to // Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per // be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() { if !route.IsExitRoute() {
// Fail the route over before deleting it. If the failover fails, report
// the error to the caller instead of returning nil, which would signal
// success even though the route was neither failed over nor deleted.
err := hsdb.failoverRouteWithNotify(route)
if err != nil {
	return err
}
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err return err
} }
} else {
return hsdb.handlePrimarySubnetFailover() routes, err := hsdb.getNodeRoutes(&node)
}
routes, err := hsdb.getNodeRoutes(&route.Node)
if err != nil { if err != nil {
return err return err
} }
@ -190,8 +257,27 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err return err
} }
}
return hsdb.handlePrimarySubnetFailover() if routes == nil {
routes, err = hsdb.getNodeRoutes(&node)
if err != nil {
return err
}
}
node.Routes = routes
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{&node},
Message: "called from db.DeleteRoute",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
return nil
} }
func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
@ -204,9 +290,13 @@ func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error {
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
return err return err
} }
// TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all.
hsdb.failoverRouteWithNotify(&routes[i])
} }
return hsdb.handlePrimarySubnetFailover() return nil
} }
// isUniquePrefix returns if there is another node providing the same route already. // isUniquePrefix returns if there is another node providing the same route already.
@ -259,18 +349,22 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er
// SaveNodeRoutes takes a node and updates the database with // SaveNodeRoutes takes a node and updates the database with
// the new routes. // the new routes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) error { // It returns a bool whether an update should be sent as the
// saved route impacts nodes.
func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) {
hsdb.mu.Lock() hsdb.mu.Lock()
defer hsdb.mu.Unlock() defer hsdb.mu.Unlock()
return hsdb.saveNodeRoutes(node) return hsdb.saveNodeRoutes(node)
} }
func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) {
sendUpdate := false
currentRoutes := types.Routes{} currentRoutes := types.Routes{}
err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error err := hsdb.db.Where("node_id = ?", node.ID).Find(&currentRoutes).Error
if err != nil { if err != nil {
return err return sendUpdate, err
} }
advertisedRoutes := map[netip.Prefix]bool{} advertisedRoutes := map[netip.Prefix]bool{}
@ -290,7 +384,14 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
currentRoutes[pos].Advertised = true currentRoutes[pos].Advertised = true
err := hsdb.db.Save(&currentRoutes[pos]).Error err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil { if err != nil {
return err return sendUpdate, err
}
// If a route that is newly "saved" is already
// enabled, set sendUpdate to true as it is now
// available.
if route.Enabled {
sendUpdate = true
} }
} }
advertisedRoutes[netip.Prefix(route.Prefix)] = true advertisedRoutes[netip.Prefix(route.Prefix)] = true
@ -299,7 +400,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
currentRoutes[pos].Enabled = false currentRoutes[pos].Enabled = false
err := hsdb.db.Save(&currentRoutes[pos]).Error err := hsdb.db.Save(&currentRoutes[pos]).Error
if err != nil { if err != nil {
return err return sendUpdate, err
} }
} }
} }
@ -313,142 +414,224 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
Enabled: false, Enabled: false,
} }
err := hsdb.db.Create(&route).Error err := hsdb.db.Create(&route).Error
if err != nil {
return sendUpdate, err
}
}
}
return sendUpdate, nil
}
// EnsureFailoverRouteIsAvailable takes a node and checks if the node's routes
// currently have a functioning host that exposes the network.
func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error {
nodeRoutes, err := hsdb.getNodeRoutes(node)
if err != nil {
return nil
}
for _, nodeRoute := range nodeRoutes {
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix))
if err != nil { if err != nil {
return err return err
} }
for _, route := range routes {
if route.IsPrimary {
// if we have a primary route, and the node is connected
// nothing needs to be done.
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
continue
}
// if not, we need to failover the route
err := hsdb.failoverRouteWithNotify(&route)
if err != nil {
return err
}
}
} }
} }
return nil return nil
} }
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error {
hsdb.mu.Lock() routes, err := hsdb.getNodeRoutes(node)
defer hsdb.mu.Unlock()
return hsdb.handlePrimarySubnetFailover()
}
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
// first, get all the enabled routes
var routes types.Routes
err := hsdb.db.
Preload("Node").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error getting routes")
}
changedNodes := make(types.Nodes, 0)
for pos, route := range routes {
if route.IsExitRoute() {
continue
}
node := &route.Node
if !route.IsPrimary {
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("node", route.Node.GivenName).
Msg("Setting primary route")
routes[pos].IsPrimary = true
err := hsdb.db.Save(&routes[pos]).Error
if err != nil { if err != nil {
log.Error().Err(err).Msg("error marking route as primary") return nil
}
var changedKeys []key.MachinePublic
for _, route := range routes {
changed, err := hsdb.failoverRoute(&route)
if err != nil {
return err return err
} }
changedNodes = append(changedNodes, node) changedKeys = append(changedKeys, changed...)
continue
}
} }
if route.IsPrimary { changedKeys = lo.Uniq(changedKeys)
if route.Node.IsOnline() {
continue
}
// node offline, find a new primary var nodes types.Nodes
log.Info().
Str("node", route.Node.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Msgf("node offline, finding a new primary subnet")
// find a new primary route
var newPrimaryRoutes types.Routes
err := hsdb.db.
Preload("Node").
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.NodeID,
true, true).
Find(&newPrimaryRoutes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error finding new primary route")
for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
if err != nil {
return err return err
} }
var newPrimaryRoute *types.Route nodes = append(nodes, node)
for pos, r := range newPrimaryRoutes { }
if r.Node.IsOnline() {
newPrimaryRoute = &newPrimaryRoutes[pos]
if nodes != nil {
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: nodes,
Message: "called from db.FailoverNodeRoutesWithNotify",
}
if stateUpdate.Valid() {
hsdb.notifier.NotifyAll(stateUpdate)
}
}
return nil
}
// failoverRouteWithNotify runs failoverRoute for r and, when that reports
// changed machine keys, loads the matching nodes from the database and
// broadcasts them to all peers as a StatePeerChanged update.
//
// It returns nil without notifying anyone when failoverRoute reports no
// changed nodes. Errors from failoverRoute or from loading a node are
// returned unchanged.
func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {
	machineKeys, err := hsdb.failoverRoute(r)
	switch {
	case err != nil:
		return err
	case len(machineKeys) == 0:
		return nil
	}

	// Every trace line below carries the hostname of the route's node.
	hostname := r.Node.Hostname
	trace := func(msg string) {
		log.Trace().
			Str("hostname", hostname).
			Msg(msg)
	}

	trace("loading machines with new primary routes from db")

	var changed types.Nodes
	for idx := range machineKeys {
		loaded, err := hsdb.getNodeByMachineKey(machineKeys[idx])
		if err != nil {
			return err
		}

		changed = append(changed, loaded)
	}

	trace("notifying peers about primary route change")

	update := types.StateUpdate{
		Type:        types.StatePeerChanged,
		ChangeNodes: changed,
		Message:     "called from db.failoverRouteWithNotify",
	}
	if changed != nil && update.Valid() {
		hsdb.notifier.NotifyAll(update)
	}

	trace("notified peers about primary route change")

	return nil
}
// failoverRoute takes a route that is no longer available,
// this can be either from:
// - being disabled
// - being deleted
// - host going offline
//
// and tries to find a new route to take over its place.
// If the given route was not primary, it returns early.
func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) {
if r == nil {
return nil, nil
}
// This route is not a primary route, and it isn't
// being served to nodes.
if !r.IsPrimary {
return nil, nil
}
// We do not have to failover exit nodes
if r.IsExitRoute() {
return nil, nil
}
routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix))
if err != nil {
return nil, err
}
var newPrimary *types.Route
// Find a new suitable route
for idx, route := range routes {
if r.ID == route.ID {
continue
}
if hsdb.notifier.IsConnected(route.Node.MachineKey) {
newPrimary = &routes[idx]
break break
} }
} }
if newPrimaryRoute == nil { // If a new route was not found/available,
log.Warn(). // return with an error.
Str("node", route.Node.Hostname). // We do not want to update the database as
Str("prefix", netip.Prefix(route.Prefix).String()). // the one currently marked as primary is the
Msgf("no alternative primary route found") // best we got.
if newPrimary == nil {
continue return nil, nil
} }
log.Info(). log.Trace().
Str("old_node", route.Node.Hostname). Str("hostname", newPrimary.Node.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()). Msg("found new primary, updating db")
Str("new_node", newPrimaryRoute.Node.Hostname).
Msgf("found new primary route")
// disable the old primary route // Remove primary from the old route
routes[pos].IsPrimary = false r.IsPrimary = false
err = hsdb.db.Save(&routes[pos]).Error err = hsdb.db.Save(&r).Error
if err != nil { if err != nil {
log.Error().Err(err).Msg("error disabling old primary route") log.Error().Err(err).Msg("error disabling new primary route")
return err return nil, err
} }
// enable the new primary route log.Trace().
newPrimaryRoute.IsPrimary = true Str("hostname", newPrimary.Node.Hostname).
err = hsdb.db.Save(&newPrimaryRoute).Error Msg("removed primary from old route")
// Set primary for the new primary
newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error
if err != nil { if err != nil {
log.Error().Err(err).Msg("error enabling new primary route") log.Error().Err(err).Msg("error enabling new primary route")
return err return nil, err
} }
changedNodes = append(changedNodes, node) log.Trace().
} Str("hostname", newPrimary.Node.Hostname).
} Msg("set primary to new route")
if len(changedNodes) > 0 { // Return a list of the machinekeys of the changed nodes.
hsdb.notifier.NotifyAll(types.StateUpdate{ return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil
Type: types.StatePeerChanged,
Changed: changedNodes,
})
}
return nil
} }
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.

View file

@ -2,12 +2,19 @@ package db
import ( import (
"net/netip" "net/netip"
"os"
"testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func (s *Suite) TestGetRoutes(c *check.C) { func (s *Suite) TestGetRoutes(c *check.C) {
@ -37,8 +44,9 @@ func (s *Suite) TestGetRoutes(c *check.C) {
} }
db.db.Save(&node) db.db.Save(&node)
err = db.SaveNodeRoutes(&node) su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(su, check.Equals, false)
advertisedRoutes, err := db.GetAdvertisedRoutes(&node) advertisedRoutes, err := db.GetAdvertisedRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -85,8 +93,9 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
} }
db.db.Save(&node) db.db.Save(&node)
err = db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
availableRoutes, err := db.GetAdvertisedRoutes(&node) availableRoutes, err := db.GetAdvertisedRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -156,8 +165,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
} }
db.db.Save(&node1) db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, route.String()) err = db.enableRoutes(&node1, route.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -178,8 +188,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
} }
db.db.Save(&node2) db.db.Save(&node2)
err = db.SaveNodeRoutes(&node2) sendUpdate, err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node2, route2.String()) err = db.enableRoutes(&node2, route2.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -201,142 +212,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
c.Assert(len(routes), check.Equals, 0) c.Assert(len(routes), check.Equals, 0)
} }
func (s *Suite) TestSubnetFailover(c *check.C) {
user, err := db.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
prefix2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
now := time.Now()
node1 := types.Node{
ID: 1,
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node1, prefix2.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err := db.GetEnabledRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
route, err := db.getPrimaryRoute(prefix)
c.Assert(err, check.IsNil)
c.Assert(route.NodeID, check.Equals, node1.ID)
hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix2},
}
node2 := types.Node{
ID: 2,
Hostname: "test_enable_route_node",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
Hostinfo: &hostInfo2,
LastSeen: &now,
}
db.db.Save(&node2)
err = db.saveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node2, prefix2.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err = db.GetEnabledRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := db.GetEnabledRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
// lets make node1 lastseen 10 mins ago
before := now.Add(-10 * time.Minute)
node1.LastSeen = &before
err = db.db.Save(&node1).Error
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
node2.Hostinfo = &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
err = db.db.Save(&node2).Error
c.Assert(err, check.IsNil)
err = db.SaveNodeRoutes(&node2)
c.Assert(err, check.IsNil)
err = db.enableRoutes(&node2, prefix.String())
c.Assert(err, check.IsNil)
err = db.HandlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = db.GetNodePrimaryRoutes(&node1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
routes, err = db.GetNodePrimaryRoutes(&node2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
}
func (s *Suite) TestDeleteRoutes(c *check.C) { func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -373,8 +248,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
} }
db.db.Save(&node1) db.db.Save(&node1)
err = db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(sendUpdate, check.Equals, false)
err = db.enableRoutes(&node1, prefix.String()) err = db.enableRoutes(&node1, prefix.String())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -392,3 +268,362 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes1), check.Equals, 1)
} }
// TestFailoverRoute exercises failoverRoute against a fresh SQLite database
// for every case. Machine keys 0-2 are registered with the notifier, so they
// count as connected; key 3 is never registered and acts as the offline node.
func TestFailoverRoute(t *testing.T) {
	ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }

	// TODO(kradalby): Count/verify updates
	//
	// The sink has to be an allocated channel: receiving from a nil channel
	// blocks forever, which would leak the drain goroutine and stall any
	// sender. The done channel stops the goroutine once the test has finished.
	sink := make(chan types.StateUpdate)
	done := make(chan struct{})
	t.Cleanup(func() { close(done) })

	go func() {
		for {
			select {
			case <-sink:
			case <-done:
				return
			}
		}
	}()

	machineKeys := []key.MachinePublic{
		key.NewMachine().Public(),
		key.NewMachine().Public(),
		key.NewMachine().Public(),
		key.NewMachine().Public(),
	}

	// mkRoute builds a route row for the 10.0.0.0/24-style fixtures below.
	mkRoute := func(id uint, prefix string, machineKey key.MachinePublic, primary bool) types.Route {
		return types.Route{
			Model:     gorm.Model{ID: id},
			Prefix:    ipp(prefix),
			Node:      types.Node{MachineKey: machineKey},
			IsPrimary: primary,
		}
	}

	tests := []struct {
		name         string
		failingRoute types.Route
		routes       types.Routes
		want         []key.MachinePublic
		wantErr      bool
	}{
		{
			name:         "no-route",
			failingRoute: types.Route{},
			routes:       types.Routes{},
			want:         nil,
			wantErr:      false,
		},
		{
			name:         "no-prime",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], false),
			routes:       types.Routes{},
			want:         nil,
			wantErr:      false,
		},
		{
			name:         "exit-node",
			failingRoute: mkRoute(1, "0.0.0.0/0", machineKeys[0], true),
			routes:       types.Routes{},
			want:         nil,
			wantErr:      false,
		},
		{
			name:         "no-failover-single-route",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
			},
			want:    nil,
			wantErr: false,
		},
		{
			name:         "failover-primary",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
				mkRoute(2, "10.0.0.0/24", machineKeys[1], false),
			},
			want: []key.MachinePublic{
				machineKeys[0],
				machineKeys[1],
			},
			wantErr: false,
		},
		{
			name:         "failover-none-primary",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], false),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
				mkRoute(2, "10.0.0.0/24", machineKeys[1], false),
			},
			want:    nil,
			wantErr: false,
		},
		{
			name:         "failover-primary-multi-route",
			failingRoute: mkRoute(2, "10.0.0.0/24", machineKeys[1], true),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], false),
				mkRoute(2, "10.0.0.0/24", machineKeys[1], true),
				mkRoute(3, "10.0.0.0/24", machineKeys[2], false),
			},
			want: []key.MachinePublic{
				machineKeys[1],
				machineKeys[0],
			},
			wantErr: false,
		},
		{
			name:         "failover-primary-no-online",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
				// Offline
				mkRoute(2, "10.0.0.0/24", machineKeys[3], false),
			},
			want:    nil,
			wantErr: false,
		},
		{
			name:         "failover-primary-one-not-online",
			failingRoute: mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
			routes: types.Routes{
				mkRoute(1, "10.0.0.0/24", machineKeys[0], true),
				// Offline
				mkRoute(2, "10.0.0.0/24", machineKeys[3], false),
				mkRoute(3, "10.0.0.0/24", machineKeys[1], true),
			},
			want: []key.MachinePublic{
				machineKeys[0],
				machineKeys[1],
			},
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Setup failures abort the case immediately; continuing would
			// only dereference a nil database further down.
			tmpDir, err := os.MkdirTemp("", "failover-db-test")
			if !assert.NoError(t, err) {
				t.FailNow()
			}
			// Best-effort removal of the scratch directory; a failure here
			// says nothing about the code under test, so it is ignored.
			t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })

			notif := notifier.NewNotifier()

			// Use a database local to this case rather than overwriting the
			// package-level one shared with the other tests.
			hsdb, err := NewHeadscaleDatabase(
				"sqlite3",
				tmpDir+"/headscale_test.db",
				false,
				notif,
				[]netip.Prefix{
					netip.MustParsePrefix("10.27.0.0/23"),
				},
				"",
			)
			if !assert.NoError(t, err) {
				t.FailNow()
			}

			// Pretend that all the nodes are connected to control
			for idx, machineKey := range machineKeys {
				// Pretend one node is offline
				if idx == 3 {
					continue
				}

				notif.AddNode(machineKey, sink)
			}

			for _, route := range tt.routes {
				if err := hsdb.db.Save(&route).Error; err != nil {
					t.Fatalf("failed to create route: %s", err)
				}
			}

			got, err := hsdb.failoverRoute(&tt.failingRoute)

			if (err != nil) != tt.wantErr {
				t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr)

				return
			}

			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
				t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}

View file

@ -18,6 +18,7 @@ import (
"tailscale.com/net/stun" "tailscale.com/net/stun"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger"
) )
// fastStartHeader is the header (with value "1") that signals to the HTTP // fastStartHeader is the header (with value "1") that signals to the HTTP
@ -33,13 +34,19 @@ type DERPServer struct {
tailscaleDERP *derp.Server tailscaleDERP *derp.Server
} }
// derpLogf returns a printf-style logging function that writes every message
// at debug level, with caller information attached, through the package
// logger.
func derpLogf() logger.Logf {
	var logf logger.Logf = func(format string, args ...any) {
		log.Debug().Caller().Msgf(format, args...)
	}

	return logf
}
func NewDERPServer( func NewDERPServer(
serverURL string, serverURL string,
derpKey key.NodePrivate, derpKey key.NodePrivate,
cfg *types.DERPConfig, cfg *types.DERPConfig,
) (*DERPServer, error) { ) (*DERPServer, error) {
log.Trace().Caller().Msg("Creating new embedded DERP server") log.Trace().Caller().Msg("Creating new embedded DERP server")
server := derp.NewServer(derpKey, log.Debug().Msgf) // nolint // zerolinter complains server := derp.NewServer(derpKey, derpLogf()) // nolint // zerolinter complains
return &DERPServer{ return &DERPServer{
serverURL: serverURL, serverURL: serverURL,

View file

@ -204,7 +204,13 @@ func (api headscaleV1APIServer) GetNode(
return nil, err return nil, err
} }
return &v1.GetNodeResponse{Node: node.Proto()}, nil resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
return &v1.GetNodeResponse{Node: resp}, nil
} }
func (api headscaleV1APIServer) SetTags( func (api headscaleV1APIServer) SetTags(
@ -333,7 +339,13 @@ func (api headscaleV1APIServer) ListNodes(
response := make([]*v1.Node, len(nodes)) response := make([]*v1.Node, len(nodes))
for index, node := range nodes { for index, node := range nodes {
response[index] = node.Proto() resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
response[index] = resp
} }
return &v1.ListNodesResponse{Nodes: response}, nil return &v1.ListNodesResponse{Nodes: response}, nil
@ -346,13 +358,18 @@ func (api headscaleV1APIServer) ListNodes(
response := make([]*v1.Node, len(nodes)) response := make([]*v1.Node, len(nodes))
for index, node := range nodes { for index, node := range nodes {
m := node.Proto() resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey)
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
&node, &node,
) )
m.InvalidTags = invalidTags resp.InvalidTags = invalidTags
m.ValidTags = validTags resp.ValidTags = validTags
response[index] = m response[index] = resp
} }
return &v1.ListNodesResponse{Nodes: response}, nil return &v1.ListNodesResponse{Nodes: response}, nil

View file

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"os" "os"
"path" "path"
"slices"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -21,6 +22,7 @@ import (
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"golang.org/x/exp/maps"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/smallzstd" "tailscale.com/smallzstd"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -45,6 +47,7 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Keep information about the previous mapresponse so we can send a diff // - Keep information about the previous mapresponse so we can send a diff
// - Store hashes // - Store hashes
// - Create a "minifier" that removes info not needed for the node // - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct { type Mapper struct {
// Configuration // Configuration
@ -63,6 +66,12 @@ type Mapper struct {
// only one func is accessing it over time. // only one func is accessing it over time.
mu sync.Mutex mu sync.Mutex
peers map[uint64]*types.Node peers map[uint64]*types.Node
patches map[uint64][]patch
}
type patch struct {
timestamp time.Time
change *tailcfg.PeerChange
} }
func NewMapper( func NewMapper(
@ -94,6 +103,7 @@ func NewMapper(
// TODO: populate // TODO: populate
peers: peers.IDMap(), peers: peers.IDMap(),
patches: make(map[uint64][]patch),
} }
} }
@ -235,6 +245,19 @@ func (m *Mapper) FullMapResponse(
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
peers := maps.Keys(m.peers)
peersWithPatches := maps.Keys(m.patches)
slices.Sort(peers)
slices.Sort(peersWithPatches)
if len(peersWithPatches) > 0 {
log.Debug().
Str("node", node.Hostname).
Uints64("peers", peers).
Uints64("pending_patches", peersWithPatches).
Msgf("node requested full map response, but has pending patches")
}
resp, err := m.fullMapResponse(node, pol, mapRequest.Version) resp, err := m.fullMapResponse(node, pol, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
@ -272,10 +295,12 @@ func (m *Mapper) KeepAliveResponse(
func (m *Mapper) DERPMapResponse( func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
derpMap tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
) ([]byte, error) { ) ([]byte, error) {
m.derpMap = derpMap
resp := m.baseMapResponse() resp := m.baseMapResponse()
resp.DERPMap = &derpMap resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
} }
@ -285,18 +310,29 @@ func (m *Mapper) PeerChangedResponse(
node *types.Node, node *types.Node,
changed types.Nodes, changed types.Nodes,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) { ) ([]byte, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
lastSeen := make(map[tailcfg.NodeID]bool)
// Update our internal map. // Update our internal map.
for _, node := range changed { for _, node := range changed {
m.peers[node.ID] = node if patches, ok := m.patches[node.ID]; ok {
// preserve online status in case the patch has an outdated one
online := node.IsOnline
// We have just seen the node, let the peers update their list. for _, p := range patches {
lastSeen[tailcfg.NodeID(node.ID)] = true // TODO(kradalby): Figure if this needs to be sorted by timestamp
node.ApplyPeerChange(p.change)
}
// Ensure the patches are not applied again later
delete(m.patches, node.ID)
node.IsOnline = online
}
m.peers[node.ID] = node
} }
resp := m.baseMapResponse() resp := m.baseMapResponse()
@ -316,11 +352,55 @@ func (m *Mapper) PeerChangedResponse(
return nil, err return nil, err
} }
// resp.PeerSeenChange = lastSeen return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
) ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
sendUpdate := false
// patch the internal map
for _, change := range changed {
if peer, ok := m.peers[uint64(change.NodeID)]; ok {
peer.ApplyPeerChange(change)
sendUpdate = true
} else {
log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname)
p := patch{
timestamp: time.Now(),
change: change,
}
if patches, ok := m.patches[uint64(change.NodeID)]; ok {
patches := append(patches, p)
m.patches[uint64(change.NodeID)] = patches
} else {
m.patches[uint64(change.NodeID)] = []patch{p}
}
}
}
if !sendUpdate {
return nil, nil
}
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
} }
// TODO(kradalby): We need some integration tests for this.
func (m *Mapper) PeerRemovedResponse( func (m *Mapper) PeerRemovedResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
@ -329,13 +409,23 @@ func (m *Mapper) PeerRemovedResponse(
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
// Some nodes might have been removed already,
// so we don't want to ask downstream to remove
// twice; that can cause a panic in tailscaled.
notYetRemoved := []tailcfg.NodeID{}
// remove from our internal map // remove from our internal map
for _, id := range removed { for _, id := range removed {
if _, ok := m.peers[uint64(id)]; ok {
notYetRemoved = append(notYetRemoved, id)
}
delete(m.peers, uint64(id)) delete(m.peers, uint64(id))
delete(m.patches, uint64(id))
} }
resp := m.baseMapResponse() resp := m.baseMapResponse()
resp.PeersRemoved = removed resp.PeersRemoved = notYetRemoved
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
} }
@ -345,6 +435,7 @@ func (m *Mapper) marshalMapResponse(
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
node *types.Node, node *types.Node,
compression string, compression string,
messages ...string,
) ([]byte, error) { ) ([]byte, error) {
atomic.AddUint64(&m.seq, 1) atomic.AddUint64(&m.seq, 1)
@ -358,11 +449,25 @@ func (m *Mapper) marshalMapResponse(
if debugDumpMapResponsePath != "" { if debugDumpMapResponsePath != "" {
data := map[string]interface{}{ data := map[string]interface{}{
"Messages": messages,
"MapRequest": mapRequest, "MapRequest": mapRequest,
"MapResponse": resp, "MapResponse": resp,
} }
body, err := json.Marshal(data) responseType := "keepalive"
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -381,7 +486,7 @@ func (m *Mapper) marshalMapResponse(
mapResponsePath := path.Join( mapResponsePath := path.Join(
mPath, mPath,
fmt.Sprintf("%d-%s-%d.json", now, m.uid, atomic.LoadUint64(&m.seq)), fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
) )
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -438,6 +543,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
ControlTime: &now, ControlTime: &now,
// TODO(kradalby): Implement PingRequest?
} }
return resp return resp
@ -559,8 +665,5 @@ func appendPeerChanges(
resp.UserProfiles = profiles resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy resp.SSHPolicy = sshPolicy
// TODO(kradalby): This currently does not take last seen in keepalives into account
resp.OnlineChange = peers.OnlineNodeMap()
return nil return nil
} }

View file

@ -237,7 +237,6 @@ func Test_fullMapResponse(t *testing.T) {
Tags: []string{}, Tags: []string{},
PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true, MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{ Capabilities: []tailcfg.NodeCapability{
tailcfg.CapabilityFileSharing, tailcfg.CapabilityFileSharing,
@ -293,7 +292,6 @@ func Test_fullMapResponse(t *testing.T) {
Tags: []string{}, Tags: []string{},
PrimaryRoutes: []netip.Prefix{}, PrimaryRoutes: []netip.Prefix{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true, MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{ Capabilities: []tailcfg.NodeCapability{
tailcfg.CapabilityFileSharing, tailcfg.CapabilityFileSharing,
@ -400,7 +398,6 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{}, DNSConfig: &tailcfg.DNSConfig{},
Domain: "", Domain: "",
CollectServices: "false", CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false},
PacketFilter: []tailcfg.FilterRule{}, PacketFilter: []tailcfg.FilterRule{},
UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}},
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
@ -442,10 +439,6 @@ func Test_fullMapResponse(t *testing.T) {
DNSConfig: &tailcfg.DNSConfig{}, DNSConfig: &tailcfg.DNSConfig{},
Domain: "", Domain: "",
CollectServices: "false", CollectServices: "false",
OnlineChange: map[tailcfg.NodeID]bool{
tailPeer1.ID: false,
tailcfg.NodeID(peer2.ID): false,
},
PacketFilter: []tailcfg.FilterRule{ PacketFilter: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.2/32"}, SrcIPs: []string{"100.64.0.2/32"},

View file

@ -87,11 +87,9 @@ func tailNode(
hostname, err := node.GetFQDN(dnsConfig, baseDomain) hostname, err := node.GetFQDN(dnsConfig, baseDomain)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
} }
online := node.IsOnline()
tags, _ := pol.TagsOfNode(node) tags, _ := pol.TagsOfNode(node)
tags = lo.Uniq(append(tags, node.ForcedTags...)) tags = lo.Uniq(append(tags, node.ForcedTags...))
@ -101,6 +99,7 @@ func tailNode(
strconv.FormatUint(node.ID, util.Base10), strconv.FormatUint(node.ID, util.Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent ), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname, Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID), User: tailcfg.UserID(node.UserID),
@ -116,13 +115,14 @@ func tailNode(
Hostinfo: node.Hostinfo.View(), Hostinfo: node.Hostinfo.View(),
Created: node.CreatedAt, Created: node.CreatedAt,
Online: node.IsOnline,
Tags: tags, Tags: tags,
PrimaryRoutes: primaryPrefixes, PrimaryRoutes: primaryPrefixes,
LastSeen: node.LastSeen,
Online: &online,
MachineAuthorized: !node.IsExpired(), MachineAuthorized: !node.IsExpired(),
Expired: node.IsExpired(),
} }
// - 74: 2023-09-18: Client understands NodeCapMap // - 74: 2023-09-18: Client understands NodeCapMap
@ -153,5 +153,11 @@ func tailNode(
tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrDisableUPnP) tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrDisableUPnP)
} }
if node.IsOnline == nil || !*node.IsOnline {
// LastSeen is only set when node is
// not connected to the control server.
tNode.LastSeen = node.LastSeen
}
return &tNode, nil return &tNode, nil
} }

View file

@ -68,7 +68,6 @@ func TestTailNode(t *testing.T) {
Hostinfo: hiview(tailcfg.Hostinfo{}), Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{}, Tags: []string{},
PrimaryRoutes: []netip.Prefix{}, PrimaryRoutes: []netip.Prefix{},
Online: new(bool),
MachineAuthorized: true, MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{ Capabilities: []tailcfg.NodeCapability{
"https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin", "https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin",
@ -165,7 +164,6 @@ func TestTailNode(t *testing.T) {
}, },
LastSeen: &lastSeen, LastSeen: &lastSeen,
Online: new(bool),
MachineAuthorized: true, MachineAuthorized: true,
Capabilities: []tailcfg.NodeCapability{ Capabilities: []tailcfg.NodeCapability{

View file

@ -1,6 +1,8 @@
package notifier package notifier
import ( import (
"fmt"
"strings"
"sync" "sync"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@ -56,6 +58,19 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) {
Msg("Removed channel") Msg("Removed channel")
} }
// IsConnected reports whether the node identified by machineKey has an
// entry in the notifier's node map, which is how headscale tracks nodes
// that currently have a poll session open.
func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool {
	n.l.RLock()
	defer n.l.RUnlock()

	// The stored value is irrelevant here; only presence matters.
	_, connected := n.nodes[machineKey.String()]

	return connected
}
func (n *Notifier) NotifyAll(update types.StateUpdate) { func (n *Notifier) NotifyAll(update types.StateUpdate) {
n.NotifyWithIgnore(update) n.NotifyWithIgnore(update)
} }
@ -79,3 +94,16 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string)
c <- update c <- update
} }
} }
// String renders the notifier's node map for debugging: a header line
// followed by one tab-indented "key: value" line per registered node.
// Lines follow map iteration order, so their order is not stable
// between calls.
func (n *Notifier) String() string {
	n.l.RLock()
	defer n.l.RUnlock()

	var out strings.Builder
	out.WriteString("Notifier, in map:\n")

	for nodeKey, entry := range n.nodes {
		fmt.Fprintf(&out, "\t%s: %v\n", nodeKey, entry)
	}

	return out.String()
}

View file

@ -14,29 +14,8 @@ import (
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
var ipComparer = cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
var mkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool {
return x.String() == y.String()
})
var nkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool {
return x.String() == y.String()
})
var dkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String()
})
var keyComparers []cmp.Option = []cmp.Option{
mkeyComparer, nkeyComparer, dkeyComparer,
}
func Test(t *testing.T) { func Test(t *testing.T) {
check.TestingT(t) check.TestingT(t)
} }
@ -969,7 +948,7 @@ func Test_listNodesInUser(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user) got := filterNodesByUser(test.args.nodes, test.args.user)
if diff := cmp.Diff(test.want, got, keyComparers...); diff != "" { if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
} }
}) })
@ -1733,7 +1712,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.nodes, test.args.nodes,
test.args.user, test.args.user,
) )
if diff := cmp.Diff(test.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" { if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff) t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff)
} }
}) })
@ -2085,10 +2064,6 @@ func Test_getTags(t *testing.T) {
} }
func Test_getFilteredByACLPeers(t *testing.T) { func Test_getFilteredByACLPeers(t *testing.T) {
ipComparer := cmp.Comparer(func(x, y netip.Addr) bool {
return x.Compare(y) == 0
})
type args struct { type args struct {
nodes types.Nodes nodes types.Nodes
rules []tailcfg.FilterRule rules []tailcfg.FilterRule
@ -2752,7 +2727,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
tt.args.nodes, tt.args.nodes,
tt.args.rules, tt.args.rules,
) )
if diff := cmp.Diff(tt.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" { if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
} }
}) })

View file

@ -9,6 +9,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
xslices "golang.org/x/exp/slices"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -61,7 +62,7 @@ func (h *Headscale) handlePoll(
) { ) {
logInfo, logErr := logPollFunc(mapRequest, node) logInfo, logErr := logPollFunc(mapRequest, node)
// This is the mechanism where the node gives us inforamtion about its // This is the mechanism where the node gives us information about its
// current configuration. // current configuration.
// //
// If OmitPeers is true, Stream is false, and ReadOnly is false, // If OmitPeers is true, Stream is false, and ReadOnly is false,
@ -69,6 +70,7 @@ func (h *Headscale) handlePoll(
// breaking existing long-polling (Stream == true) connections. // breaking existing long-polling (Stream == true) connections.
// In this case, the server can omit the entire response; the client // In this case, the server can omit the entire response; the client
// only checks the HTTP response status code. // only checks the HTTP response status code.
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly { if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly {
log.Info(). log.Info().
Caller(). Caller().
@ -78,14 +80,85 @@ func (h *Headscale) handlePoll(
Str("node_key", node.NodeKey.ShortString()). Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname). Str("node", node.Hostname).
Int("cap_ver", int(mapRequest.Version)). Int("cap_ver", int(mapRequest.Version)).
Msg("Received endpoint update") Msg("Received update")
change := node.PeerChangeFromMapRequest(mapRequest)
online := h.nodeNotifier.IsConnected(node.MachineKey)
change.Online = &online
node.ApplyPeerChange(&change)
hostInfoChange := node.Hostinfo.Equal(mapRequest.Hostinfo)
logTracePeerChange(node.Hostname, hostInfoChange, &change)
// Check if the Hostinfo of the node has changed.
// If it has changed, check if there has been a change to
// the routable IPs of the host and update them in
// the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the route change.
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if !hostInfoChange {
oldRoutes := node.Hostinfo.RoutableIPs
newRoutes := mapRequest.Hostinfo.RoutableIPs
oldServicesCount := len(node.Hostinfo.Services)
newServicesCount := len(mapRequest.Hostinfo.Services)
now := time.Now().UTC()
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.Hostinfo = mapRequest.Hostinfo node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.Endpoints = mapRequest.Endpoints sendUpdate := false
// Route changes come as part of Hostinfo, which means that
// when an update comes, the Node Route logic need to run.
// This will require a "change" in comparison to a "patch",
// which is more costly.
if !xslices.Equal(oldRoutes, newRoutes) {
var err error
sendUpdate, err = h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
http.Error(writer, "", http.StatusInternalServerError)
return
}
}
// Services is mostly useful for discovery and not critical,
// except for peerapi, which is how nodes talk to each other.
// If peerapi was not part of the initial mapresponse, we
// need to make sure it's sent out later as it is needed for
// Taildrop.
// TODO(kradalby): Length comparison is a bit naive, replace.
if oldServicesCount != newServicesCount {
sendUpdate = true
}
if sendUpdate {
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
stateUpdate := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: types.Nodes{node},
Message: "called from handlePoll -> update -> new hostinfo",
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(
stateUpdate,
node.MachineKey.String())
}
return
}
}
if err := h.db.NodeSave(node); err != nil { if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database") logErr(err, "Failed to persist/update node in the database")
@ -94,20 +167,15 @@ func (h *Headscale) handlePoll(
return return
} }
err := h.db.SaveNodeRoutes(node) stateUpdate := types.StateUpdate{
if err != nil { Type: types.StatePeerChangedPatch,
logErr(err, "Error processing node routes") ChangePatches: []*tailcfg.PeerChange{&change},
http.Error(writer, "", http.StatusInternalServerError)
return
} }
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore( h.nodeNotifier.NotifyWithIgnore(
types.StateUpdate{ stateUpdate,
Type: types.StatePeerChanged,
Changed: types.Nodes{node},
},
node.MachineKey.String()) node.MachineKey.String())
}
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
if f, ok := writer.(http.Flusher); ok { if f, ok := writer.(http.Flusher); ok {
@ -115,7 +183,7 @@ func (h *Headscale) handlePoll(
} }
return return
} else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly {
// ReadOnly is whether the client just wants to fetch the // ReadOnly is whether the client just wants to fetch the
// MapResponse, without updating their Endpoints. The // MapResponse, without updating their Endpoints. The
// Endpoints field will be ignored and LastSeen will not be // Endpoints field will be ignored and LastSeen will not be
@ -133,12 +201,39 @@ func (h *Headscale) handlePoll(
return return
} }
now := time.Now().UTC() change := node.PeerChangeFromMapRequest(mapRequest)
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname // A stream is being set up, the node is Online
online := true
change.Online = &online
node.ApplyPeerChange(&change)
// Only save HostInfo if changed, update routes if changed
// TODO(kradalby): Remove when capver is over 68
if !node.Hostinfo.Equal(mapRequest.Hostinfo) {
oldRoutes := node.Hostinfo.RoutableIPs
newRoutes := mapRequest.Hostinfo.RoutableIPs
node.Hostinfo = mapRequest.Hostinfo node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.Endpoints = mapRequest.Endpoints if !xslices.Equal(oldRoutes, newRoutes) {
_, err := h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
http.Error(writer, "", http.StatusInternalServerError)
return
}
}
}
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
// When a node connects to control, list the peers it has at // When a node connects to control, list the peers it has at
// that given point, further updates are kept in memory in // that given point, further updates are kept in memory in
@ -152,6 +247,11 @@ func (h *Headscale) handlePoll(
return return
} }
for _, peer := range peers {
online := h.nodeNotifier.IsConnected(peer.MachineKey)
peer.IsOnline = &online
}
mapp := mapper.NewMapper( mapp := mapper.NewMapper(
node, node,
peers, peers,
@ -162,11 +262,6 @@ func (h *Headscale) handlePoll(
h.cfg.RandomizeClientPort, h.cfg.RandomizeClientPort,
) )
err = h.db.SaveNodeRoutes(node)
if err != nil {
logErr(err, "Error processing node routes")
}
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
if h.ACLPolicy != nil { if h.ACLPolicy != nil {
// update routes with peer information // update routes with peer information
@ -176,14 +271,6 @@ func (h *Headscale) handlePoll(
} }
} }
// TODO(kradalby): Save specific stuff, not whole object.
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
logInfo("Sending initial map") logInfo("Sending initial map")
mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
@ -208,18 +295,26 @@ func (h *Headscale) handlePoll(
return return
} }
h.nodeNotifier.NotifyWithIgnore( stateUpdate := types.StateUpdate{
types.StateUpdate{
Type: types.StatePeerChanged, Type: types.StatePeerChanged,
Changed: types.Nodes{node}, ChangeNodes: types.Nodes{node},
}, Message: "called from handlePoll -> new node added",
}
if stateUpdate.Valid() {
h.nodeNotifier.NotifyWithIgnore(
stateUpdate,
node.MachineKey.String()) node.MachineKey.String())
}
// Set up the client stream // Set up the client stream
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
updateChan := make(chan types.StateUpdate) // Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we don't block the entire
// notifier.
// 12 is arbitrarily chosen.
updateChan := make(chan types.StateUpdate, 12)
defer closeChanWithLog(updateChan, node.Hostname, "updateChan") defer closeChanWithLog(updateChan, node.Hostname, "updateChan")
// Register the node's update channel // Register the node's update channel
@ -233,6 +328,10 @@ func (h *Headscale) handlePoll(
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
if len(node.Routes) > 0 {
go h.db.EnsureFailoverRouteIsAvailable(node)
}
for { for {
logInfo("Waiting for update on stream channel") logInfo("Waiting for update on stream channel")
select { select {
@ -262,14 +361,7 @@ func (h *Headscale) handlePoll(
// One alternative is to split these different channels into // One alternative is to split these different channels into
// goroutines, but then you might have a problem without a lock // goroutines, but then you might have a problem without a lock
// if a keepalive is written at the same time as an update. // if a keepalive is written at the same time as an update.
go func() { go h.updateNodeOnlineStatus(true, node)
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
return
}
}()
case update := <-updateChan: case update := <-updateChan:
logInfo("Received update") logInfo("Received update")
@ -279,18 +371,35 @@ func (h *Headscale) handlePoll(
var err error var err error
switch update.Type { switch update.Type {
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
case types.StatePeerChanged: case types.StatePeerChanged:
logInfo("Sending PeerChanged MapResponse") logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message))
data, err = mapp.PeerChangedResponse(mapRequest, node, update.Changed, h.ACLPolicy)
for _, node := range update.ChangeNodes {
// If a node is not reported to be online, it might be
// because the value is outdated, check with the notifier.
// However, if it is set to Online, and not in the notifier,
// this might be because it has announced itself, but not
// reached the stage to actually create the notifier channel.
if node.IsOnline != nil && !*node.IsOnline {
isOnline := h.nodeNotifier.IsConnected(node.MachineKey)
node.IsOnline = &isOnline
}
}
data, err = mapp.PeerChangedResponse(mapRequest, node, update.ChangeNodes, h.ACLPolicy, update.Message)
case types.StatePeerChangedPatch:
logInfo("Sending PeerChangedPatch MapResponse")
data, err = mapp.PeerChangedPatchResponse(mapRequest, node, update.ChangePatches, h.ACLPolicy)
case types.StatePeerRemoved: case types.StatePeerRemoved:
logInfo("Sending PeerRemoved MapResponse") logInfo("Sending PeerRemoved MapResponse")
data, err = mapp.PeerRemovedResponse(mapRequest, node, update.Removed) data, err = mapp.PeerRemovedResponse(mapRequest, node, update.Removed)
case types.StateDERPUpdated: case types.StateDERPUpdated:
logInfo("Sending DERPUpdate MapResponse") logInfo("Sending DERPUpdate MapResponse")
data, err = mapp.DERPMapResponse(mapRequest, node, update.DERPMap) data, err = mapp.DERPMapResponse(mapRequest, node, update.DERPMap)
case types.StateFullUpdate:
logInfo("Sending Full MapResponse")
data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy)
} }
if err != nil { if err != nil {
@ -299,6 +408,8 @@ func (h *Headscale) handlePoll(
return return
} }
// Only send update if there is change
if data != nil {
_, err = writer.Write(data) _, err = writer.Write(data)
if err != nil { if err != nil {
logErr(err, "Could not write the map response") logErr(err, "Could not write the map response")
@ -317,36 +428,25 @@ func (h *Headscale) handlePoll(
return return
} }
// See comment in keepAliveTicker
go func() {
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
return
}
}()
log.Info(). log.Info().
Caller(). Caller().
Bool("readOnly", mapRequest.ReadOnly). Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey.ShortString()). Str("node_key", node.NodeKey.ShortString()).
Str("machine_key", node.MachineKey.ShortString()).
Str("node", node.Hostname). Str("node", node.Hostname).
TimeDiff("timeSpent", time.Now(), now). TimeDiff("timeSpent", time.Now(), now).
Msg("update sent") Msg("update sent")
}
case <-ctx.Done(): case <-ctx.Done():
logInfo("The client has closed the connection") logInfo("The client has closed the connection")
go func() { go h.updateNodeOnlineStatus(false, node)
err = h.db.UpdateLastSeen(node)
if err != nil {
logErr(err, "Cannot update node LastSeen")
return // Failover the node's routes if any.
} go h.db.FailoverNodeRoutesWithNotify(node)
}()
// The connection has been closed, so we can stop polling. // The connection has been closed, so we can stop polling.
return return
@ -359,6 +459,36 @@ func (h *Headscale) handlePoll(
} }
} }
// updateNodeOnlineStatus records that node was just seen and announces its
// online state to the other nodes.
//
// It sets node.LastSeen to the current time, sends a StatePeerChangedPatch
// update carrying the node ID, the given online flag and the same timestamp
// to the notifier (ignoring the node's own machine key), and finally asks
// the database to persist LastSeen. A persistence failure is logged and
// otherwise ignored.
//
// NOTE(review): callers in this file start this method with `go`, and it
// writes node.LastSeen; confirm nothing else touches the node concurrently.
func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
	seenAt := time.Now()
	node.LastSeen = &seenAt

	patch := &tailcfg.PeerChange{
		NodeID:   tailcfg.NodeID(node.ID),
		Online:   &online,
		LastSeen: &seenAt,
	}
	update := types.StateUpdate{
		Type:          types.StatePeerChangedPatch,
		ChangePatches: []*tailcfg.PeerChange{patch},
	}

	if update.Valid() {
		h.nodeNotifier.NotifyWithIgnore(update, node.MachineKey.String())
	}

	if err := h.db.UpdateLastSeen(node); err != nil {
		log.Error().Err(err).Msg("Cannot update node LastSeen")
	}
}
func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) { func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -378,8 +508,6 @@ func (h *Headscale) handleLiteRequest(
mapp := mapper.NewMapper( mapp := mapper.NewMapper(
node, node,
// TODO(kradalby): It might not be acceptable to send
// an empty peer list here.
types.Nodes{}, types.Nodes{},
h.DERPMap, h.DERPMap,
h.cfg.BaseDomain, h.cfg.BaseDomain,
@ -405,3 +533,38 @@ func (h *Headscale) handleLiteRequest(
logErr(err, "Failed to write response") logErr(err, "Failed to write response")
} }
} }
// logTracePeerChange writes a trace-level log entry summarising a
// PeerChange for the node with the given hostname.
//
// The node ID and hostname are always logged. Every optional field of the
// change (node key, disco key, online flag, endpoints, DERP region and
// last-seen time) is logged only when it is set. hostinfoChange is logged
// as "hostinfo_changed" only when true.
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
	trace := log.Trace().Str("node_id", change.NodeID.String()).Str("hostname", hostname)

	if change.Key != nil {
		trace = trace.Str("node_key", change.Key.ShortString())
	}

	if change.DiscoKey != nil {
		trace = trace.Str("disco_key", change.DiscoKey.ShortString())
	}

	if change.Online != nil {
		trace = trace.Bool("online", *change.Online)
	}

	if change.Endpoints != nil {
		eps := make([]string, len(change.Endpoints))
		for idx, ep := range change.Endpoints {
			eps[idx] = ep.String()
		}

		trace = trace.Strs("endpoints", eps)
	}

	if hostinfoChange {
		trace = trace.Bool("hostinfo_changed", hostinfoChange)
	}

	if change.DERPRegion != 0 {
		trace = trace.Int("derp_region", change.DERPRegion)
	}

	// LastSeen is a pointer like the other optional fields. Dereferencing
	// it unguarded would panic on a PeerChange without a timestamp, and the
	// dereference is evaluated even when trace logging is disabled.
	if change.LastSeen != nil {
		trace = trace.Time("last_seen", *change.LastSeen)
	}

	trace.Msg("PeerChange received")
}

View file

@ -84,20 +84,31 @@ type StateUpdateType int
const ( const (
StateFullUpdate StateUpdateType = iota StateFullUpdate StateUpdateType = iota
// StatePeerChanged is used for updates that needs
// to be calculated with all peers and all policy rules.
// This would typically be things that include tags, routes
// and similar.
StatePeerChanged StatePeerChanged
StatePeerChangedPatch
StatePeerRemoved StatePeerRemoved
StateDERPUpdated StateDERPUpdated
) )
// StateUpdate is an internal message containing information about // StateUpdate is an internal message containing information about
// a state change that has happened to the network. // a state change that has happened to the network.
// If type is StateFullUpdate, all fields are ignored.
type StateUpdate struct { type StateUpdate struct {
// The type of update // The type of update
Type StateUpdateType Type StateUpdateType
// Changed must be set when Type is StatePeerChanged and // ChangeNodes must be set when Type is StatePeerAdded
// contain the Node IDs of nodes that have changed. // and StatePeerChanged and contains the full node
Changed Nodes // object for added nodes.
ChangeNodes Nodes
// ChangePatches must be set when Type is StatePeerChangedPatch
// and contains a populated PeerChange object.
ChangePatches []*tailcfg.PeerChange
// Removed must be set when Type is StatePeerRemoved and // Removed must be set when Type is StatePeerRemoved and
// contain a list of the nodes that has been removed from // contain a list of the nodes that has been removed from
@ -106,5 +117,36 @@ type StateUpdate struct {
// DERPMap must be set when Type is StateDERPUpdated and // DERPMap must be set when Type is StateDERPUpdated and
// contain the new DERP Map. // contain the new DERP Map.
DERPMap tailcfg.DERPMap DERPMap *tailcfg.DERPMap
// Additional message for tracking origin or what is being
// updated, useful for ambiguous updates like StatePeerChanged.
Message string
}
// Valid reports if a StateUpdate is correctly filled and
// panics if the mandatory fields for a type is not
// filled.
// Reports true if valid.
func (su *StateUpdate) Valid() bool {
switch su.Type {
case StatePeerChanged:
if su.ChangeNodes == nil {
panic("Mandatory field ChangeNodes is not set on StatePeerChanged update")
}
case StatePeerChangedPatch:
if su.ChangePatches == nil {
panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update")
}
case StatePeerRemoved:
if su.Removed == nil {
panic("Mandatory field Removed is not set on StatePeerRemove update")
}
case StateDERPUpdated:
if su.DERPMap == nil {
panic("Mandatory field DERPMap is not set on StateDERPUpdated update")
}
}
return true
} }

View file

@ -21,7 +21,9 @@ import (
var ( var (
ErrNodeAddressesInvalid = errors.New("failed to parse node addresses") ErrNodeAddressesInvalid = errors.New("failed to parse node addresses")
ErrHostnameTooLong = errors.New("hostname too long") ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars")
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
) )
// Node is a Headscale client. // Node is a Headscale client.
@ -95,22 +97,14 @@ type Node struct {
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt *time.Time DeletedAt *time.Time
IsOnline *bool `gorm:"-"`
} }
type ( type (
Nodes []*Node Nodes []*Node
) )
func (nodes Nodes) OnlineNodeMap() map[tailcfg.NodeID]bool {
ret := make(map[tailcfg.NodeID]bool)
for _, node := range nodes {
ret[tailcfg.NodeID(node.ID)] = node.IsOnline()
}
return ret
}
type NodeAddresses []netip.Addr type NodeAddresses []netip.Addr
func (na NodeAddresses) Sort() { func (na NodeAddresses) Sort() {
@ -206,21 +200,6 @@ func (node Node) IsExpired() bool {
return time.Now().UTC().After(*node.Expiry) return time.Now().UTC().After(*node.Expiry)
} }
// IsOnline returns if the node is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
func (node *Node) IsOnline() bool {
if node.LastSeen == nil {
return false
}
if node.IsExpired() {
return false
}
return node.LastSeen.After(time.Now().Add(-KeepAliveInterval))
}
// IsEphemeral returns if the node is registered as an Ephemeral node. // IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/ // https://tailscale.com/kb/1111/ephemeral-nodes/
func (node *Node) IsEphemeral() bool { func (node *Node) IsEphemeral() bool {
@ -339,7 +318,6 @@ func (node *Node) Proto() *v1.Node {
GivenName: node.GivenName, GivenName: node.GivenName,
User: node.User.Proto(), User: node.User.Proto(),
ForcedTags: node.ForcedTags, ForcedTags: node.ForcedTags,
Online: node.IsOnline(),
// TODO(kradalby): Implement register method enum converter // TODO(kradalby): Implement register method enum converter
// RegisterMethod: , // RegisterMethod: ,
@ -365,6 +343,14 @@ func (node *Node) Proto() *v1.Node {
func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) {
var hostname string var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
if node.GivenName == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName)
}
if node.User.Name == "" {
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName)
}
hostname = fmt.Sprintf( hostname = fmt.Sprintf(
"%s.%s.%s", "%s.%s.%s",
node.GivenName, node.GivenName,
@ -373,7 +359,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
) )
if len(hostname) > MaxHostnameLength { if len(hostname) > MaxHostnameLength {
return "", fmt.Errorf( return "", fmt.Errorf(
"hostname %q is too long it cannot except 255 ASCII chars: %w", "failed to create valid FQDN (%s): %w",
hostname, hostname,
ErrHostnameTooLong, ErrHostnameTooLong,
) )
@ -385,8 +371,98 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri
return hostname, nil return hostname, nil
} }
func (node Node) String() string { // func (node *Node) String() string {
return node.Hostname // return node.Hostname
// }
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to update the node and
// inform peers about smaller changes to the node.
//
// Key and DiscoKey are only set when the request's keys differ from the
// stored ones. DERPRegion is only set when the request carries a NetInfo and
// the node either has no stored NetInfo or stores a different PreferredDERP.
// Endpoints are always copied from the request and LastSeen is always set to
// the current time.
//
// When a field is added to this function, remember to also add it to:
// - node.ApplyPeerChange
// - logTracePeerChange in poll.go
func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
	ret := tailcfg.PeerChange{
		NodeID: tailcfg.NodeID(node.ID),
	}

	if node.NodeKey.String() != req.NodeKey.String() {
		ret.Key = &req.NodeKey
	}

	if node.DiscoKey.String() != req.DiscoKey.String() {
		ret.DiscoKey = &req.DiscoKey
	}

	if req.Hostinfo != nil && req.Hostinfo.NetInfo != nil {
		preferredDERP := req.Hostinfo.NetInfo.PreferredDERP

		// Report the region when there is no stored Hostinfo or NetInfo
		// to compare against, or when the stored region has changed.
		if node.Hostinfo == nil ||
			node.Hostinfo.NetInfo == nil ||
			node.Hostinfo.NetInfo.PreferredDERP != preferredDERP {
			ret.DERPRegion = preferredDERP
		}
	}

	// TODO(kradalby): Find a good way to compare updates
	ret.Endpoints = req.Endpoints

	now := time.Now()
	ret.LastSeen = &now

	return ret
}
// ApplyPeerChange takes a PeerChange struct and updates the node.
func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) {
if change.Key != nil {
node.NodeKey = *change.Key
}
if change.DiscoKey != nil {
node.DiscoKey = *change.DiscoKey
}
if change.Online != nil {
node.IsOnline = change.Online
}
if change.Endpoints != nil {
node.Endpoints = change.Endpoints
}
// This might technically not be useful as we replace
// the whole hostinfo blob when it has changed.
if change.DERPRegion != 0 {
if node.Hostinfo == nil {
node.Hostinfo = &tailcfg.Hostinfo{
NetInfo: &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
},
}
} else if node.Hostinfo.NetInfo == nil {
node.Hostinfo.NetInfo = &tailcfg.NetInfo{
PreferredDERP: change.DERPRegion,
}
} else {
node.Hostinfo.NetInfo.PreferredDERP = change.DERPRegion
}
}
node.LastSeen = change.LastSeen
} }
func (nodes Nodes) String() string { func (nodes Nodes) String() string {

View file

@ -4,7 +4,10 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func Test_NodeCanAccess(t *testing.T) { func Test_NodeCanAccess(t *testing.T) {
@ -139,3 +142,227 @@ func TestNodeAddressesOrder(t *testing.T) {
} }
} }
} }
// TestNodeFQDN checks the hostname returned by Node.GetFQDN, and the error
// text it produces, for a set of node/DNS configuration combinations.
func TestNodeFQDN(t *testing.T) {
	tests := []struct {
		name    string
		node    Node
		dns     tailcfg.DNSConfig
		domain  string
		want    string
		wantErr string
	}{
		{
			name: "all-set",
			node: Node{
				GivenName: "test",
				User: User{
					Name: "user",
				},
			},
			dns: tailcfg.DNSConfig{
				Proxied: true,
			},
			domain: "example.com",
			want:   "test.user.example.com",
		},
		{
			name: "no-given-name",
			node: Node{
				User: User{
					Name: "user",
				},
			},
			dns: tailcfg.DNSConfig{
				Proxied: true,
			},
			domain:  "example.com",
			wantErr: "failed to create valid FQDN: node has no given name",
		},
		{
			name: "no-user-name",
			node: Node{
				GivenName: "test",
				User:      User{},
			},
			dns: tailcfg.DNSConfig{
				Proxied: true,
			},
			domain:  "example.com",
			wantErr: "failed to create valid FQDN: node user has no name",
		},
		{
			name: "no-magic-dns",
			node: Node{
				GivenName: "test",
				User: User{
					Name: "user",
				},
			},
			dns: tailcfg.DNSConfig{
				Proxied: false,
			},
			domain: "example.com",
			want:   "test",
		},
		{
			// The DNS config is left at its zero value here; GetFQDN
			// still receives a non-nil pointer to it.
			name: "no-dnsconfig",
			node: Node{
				GivenName: "test",
				User: User{
					Name: "user",
				},
			},
			domain: "example.com",
			want:   "test",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := tc.node.GetFQDN(&tc.dns, tc.domain)

			// Compare the error text in both directions so that a missing
			// expected error fails the test just like an unexpected one.
			var gotErr string
			if err != nil {
				gotErr = err.Error()
			}

			if gotErr != tc.wantErr {
				t.Errorf("GetFQDN() error = %q, wantErr %q", gotErr, tc.wantErr)

				return
			}

			if diff := cmp.Diff(tc.want, got); diff != "" {
				t.Errorf("GetFQDN unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}
// TestPeerChangeFromMapRequest checks which DERPRegion ends up in the
// PeerChange produced by Node.PeerChangeFromMapRequest for different
// combinations of stored and requested NetInfo. LastSeen is ignored in the
// comparison.
func TestPeerChangeFromMapRequest(t *testing.T) {
	nKeys := []key.NodePublic{
		key.NewNode().Public(),
		key.NewNode().Public(),
		key.NewNode().Public(),
	}

	dKeys := []key.DiscoPublic{
		key.NewDisco().Public(),
		key.NewDisco().Public(),
		key.NewDisco().Public(),
	}

	// derpHostinfo builds a Hostinfo whose NetInfo only sets PreferredDERP.
	derpHostinfo := func(region int) *tailcfg.Hostinfo {
		return &tailcfg.Hostinfo{
			NetInfo: &tailcfg.NetInfo{
				PreferredDERP: region,
			},
		}
	}

	// storedNode is the node fixture shared by all cases; only the
	// Hostinfo varies.
	storedNode := func(hostinfo *tailcfg.Hostinfo) Node {
		return Node{
			ID:        1,
			NodeKey:   nKeys[0],
			DiscoKey:  dKeys[0],
			Endpoints: []netip.AddrPort{},
			Hostinfo:  hostinfo,
		}
	}

	// request is the MapRequest fixture; its keys match the stored node so
	// no key change is produced.
	request := func(hostinfo *tailcfg.Hostinfo) tailcfg.MapRequest {
		return tailcfg.MapRequest{
			NodeKey:  nKeys[0],
			DiscoKey: dKeys[0],
			Hostinfo: hostinfo,
		}
	}

	// change is the expected PeerChange for node 1 with the given region.
	change := func(region int) tailcfg.PeerChange {
		return tailcfg.PeerChange{
			NodeID:     1,
			DERPRegion: region,
		}
	}

	cases := []struct {
		name   string
		node   Node
		mapReq tailcfg.MapRequest
		want   tailcfg.PeerChange
	}{
		{
			name:   "preferred-derp-changed",
			node:   storedNode(derpHostinfo(998)),
			mapReq: request(derpHostinfo(999)),
			want:   change(999),
		},
		{
			name:   "preferred-derp-no-changed",
			node:   storedNode(derpHostinfo(100)),
			mapReq: request(derpHostinfo(100)),
			want:   change(0),
		},
		{
			name:   "preferred-derp-no-mapreq-netinfo",
			node:   storedNode(derpHostinfo(200)),
			mapReq: request(&tailcfg.Hostinfo{}),
			want:   change(0),
		},
		{
			name:   "preferred-derp-no-node-netinfo",
			node:   storedNode(&tailcfg.Hostinfo{}),
			mapReq: request(derpHostinfo(200)),
			want:   change(200),
		},
	}

	ignoreLastSeen := cmpopts.IgnoreFields(tailcfg.PeerChange{}, "LastSeen")

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.node.PeerChangeFromMapRequest(tt.mapReq)

			diff := cmp.Diff(tt.want, got, ignoreLastSeen)
			if diff != "" {
				t.Errorf("Patch unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}

View file

@ -19,6 +19,8 @@ type Route struct {
NodeID uint64 NodeID uint64
Node Node Node Node
// TODO(kradalby): change this custom type to netip.Prefix
Prefix IPPrefix Prefix IPPrefix
Advertised bool Advertised bool
@ -29,13 +31,17 @@ type Route struct {
type Routes []Route type Routes []Route
func (r *Route) String() string { func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Node, netip.Prefix(r.Prefix).String()) return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String())
} }
func (r *Route) IsExitRoute() bool { func (r *Route) IsExitRoute() bool {
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
} }
// IsAnnouncable reports whether the route is both advertised and enabled.
func (r *Route) IsAnnouncable() bool {
	if !r.Advertised {
		return false
	}

	return r.Enabled
}
func (rs Routes) Prefixes() []netip.Prefix { func (rs Routes) Prefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs)) prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs { for i, r := range rs {
@ -45,6 +51,32 @@ func (rs Routes) Prefixes() []netip.Prefix {
return prefixes return prefixes
} }
// Primaries returns the routes in rs that are marked as primary, in their
// original order. The result is an empty, non-nil list when none are.
func (rs Routes) Primaries() Routes {
	primaries := Routes{}

	for i := range rs {
		if !rs[i].IsPrimary {
			continue
		}

		primaries = append(primaries, rs[i])
	}

	return primaries
}
// PrefixMap groups the routes by their Prefix. Routes sharing a prefix are
// kept in the order in which they appear in rs. The returned map is never
// nil.
func (rs Routes) PrefixMap() map[IPPrefix][]Route {
	res := map[IPPrefix][]Route{}

	for _, route := range rs {
		// Indexing a missing key yields a nil slice, which append turns
		// into a fresh one, so no existence check is needed.
		res[route.Prefix] = append(res[route.Prefix], route)
	}

	return res
}
func (rs Routes) Proto() []*v1.Route { func (rs Routes) Proto() []*v1.Route {
protoRoutes := []*v1.Route{} protoRoutes := []*v1.Route{}

View file

@ -0,0 +1,94 @@
package types
import (
"fmt"
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/util"
)
// TestPrefixMap checks that Routes.PrefixMap groups routes under their
// prefix, covering a single route, two distinct prefixes and two routes
// sharing one prefix.
func TestPrefixMap(t *testing.T) {
	ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) }

	// route builds a Route fixture with only Prefix and Enabled set.
	route := func(prefix string, enabled bool) Route {
		return Route{
			Prefix:  ipp(prefix),
			Enabled: enabled,
		}
	}

	// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
	prefixComparer := cmp.Comparer(func(a, b IPPrefix) bool {
		return a == b
	})

	cases := []struct {
		rs   Routes
		want map[IPPrefix][]Route
	}{
		{
			rs: Routes{
				route("10.0.0.0/24", false),
			},
			want: map[IPPrefix][]Route{
				ipp("10.0.0.0/24"): {
					route("10.0.0.0/24", false),
				},
			},
		},
		{
			rs: Routes{
				route("10.0.0.0/24", false),
				route("10.0.1.0/24", false),
			},
			want: map[IPPrefix][]Route{
				ipp("10.0.0.0/24"): {
					route("10.0.0.0/24", false),
				},
				ipp("10.0.1.0/24"): {
					route("10.0.1.0/24", false),
				},
			},
		},
		{
			rs: Routes{
				route("10.0.0.0/24", true),
				route("10.0.0.0/24", false),
			},
			want: map[IPPrefix][]Route{
				ipp("10.0.0.0/24"): {
					route("10.0.0.0/24", true),
					route("10.0.0.0/24", false),
				},
			},
		},
	}

	for i, tc := range cases {
		t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) {
			got := tc.rs.PrefixMap()

			diff := cmp.Diff(
				tc.want,
				got,
				prefixComparer,
				util.MkeyComparer,
				util.NkeyComparer,
				util.DkeyComparer,
			)
			if diff != "" {
				t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}

32
hscontrol/util/test.go Normal file
View file

@ -0,0 +1,32 @@
package util
import (
"net/netip"
"github.com/google/go-cmp/cmp"
"tailscale.com/types/key"
)
// Comparison options for use with cmp.Diff / cmp.Equal in tests.
var (
	// PrefixComparer treats two netip.Prefix values as equal when they
	// are == to each other.
	PrefixComparer = cmp.Comparer(func(a, b netip.Prefix) bool {
		return a == b
	})

	// IPComparer treats two netip.Addr values as equal when Compare
	// reports 0.
	IPComparer = cmp.Comparer(func(a, b netip.Addr) bool {
		return a.Compare(b) == 0
	})

	// MkeyComparer compares machine public keys by their string form.
	MkeyComparer = cmp.Comparer(func(a, b key.MachinePublic) bool {
		return a.String() == b.String()
	})

	// NkeyComparer compares node public keys by their string form.
	NkeyComparer = cmp.Comparer(func(a, b key.NodePublic) bool {
		return a.String() == b.String()
	})

	// DkeyComparer compares disco public keys by their string form.
	DkeyComparer = cmp.Comparer(func(a, b key.DiscoPublic) bool {
		return a.String() == b.String()
	})

	// Comparers bundles all comparers declared in this block.
	Comparers = []cmp.Option{
		IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer,
	}
)

View file

@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort" "sort"
"strconv"
"testing" "testing"
"time" "time"
@ -22,7 +21,7 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul
err = json.Unmarshal([]byte(str), result) err = json.Unmarshal([]byte(str), result)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str)
} }
return nil return nil
@ -178,7 +177,11 @@ func TestPreAuthKeyCommand(t *testing.T) {
assert.Equal( assert.Equal(
t, t,
[]string{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()}, []string{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()},
[]string{listedPreAuthKeys[1].GetId(), listedPreAuthKeys[2].GetId(), listedPreAuthKeys[3].GetId()}, []string{
listedPreAuthKeys[1].GetId(),
listedPreAuthKeys[2].GetId(),
listedPreAuthKeys[3].GetId(),
},
) )
assert.NotEmpty(t, listedPreAuthKeys[1].GetKey()) assert.NotEmpty(t, listedPreAuthKeys[1].GetKey())
@ -384,141 +387,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
assert.Len(t, listedPreAuthKeys, 3) assert.Len(t, listedPreAuthKeys, 3)
} }
func TestEnablingRoutes(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "enable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
// advertise routes using the up command
for i, client := range allClients {
routeStr := fmt.Sprintf("10.0.%d.0/24", i)
command := []string{
"tailscale",
"set",
"--advertise-routes=" + routeStr,
}
_, _, err := client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 3)
for _, route := range routes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
}
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
}
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 3)
for _, route := range enablingRoutes {
assert.Equal(t, route.GetAdvertised(), true)
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
routeIDToBeDisabled := enablingRoutes[0].GetId()
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
strconv.Itoa(int(routeIDToBeDisabled)),
})
assertNoErr(t, err)
var disablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&disablingRoutes,
)
assertNoErr(t, err)
for _, route := range disablingRoutes {
assert.Equal(t, true, route.GetAdvertised())
if route.GetId() == routeIDToBeDisabled {
assert.Equal(t, route.GetEnabled(), false)
assert.Equal(t, route.GetIsPrimary(), false)
} else {
assert.Equal(t, route.GetEnabled(), true)
assert.Equal(t, route.GetIsPrimary(), true)
}
}
}
func TestApiKeyCommand(t *testing.T) { func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()

View file

@ -44,6 +44,9 @@ func TestDERPServerScenario(t *testing.T) {
headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP"
headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478"
headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key"
// Envknob for enabling DERP debug logs
headscaleConfig["DERP_DEBUG_LOGS"] = "true"
headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true"
err = scenario.CreateHeadscaleEnv( err = scenario.CreateHeadscaleEnv(
spec, spec,

View file

@ -14,6 +14,8 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/types/key"
) )
func TestPingAllByIP(t *testing.T) { func TestPingAllByIP(t *testing.T) {
@ -248,9 +250,8 @@ func TestPingAllByHostname(t *testing.T) {
defer scenario.Shutdown() defer scenario.Shutdown()
spec := map[string]int{ spec := map[string]int{
// Omit 1.16.2 (-1) because it does not have the FQDN field "user3": len(MustTestVersions),
"user3": len(MustTestVersions) - 1, "user4": len(MustTestVersions),
"user4": len(MustTestVersions) - 1,
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname"))
@ -296,8 +297,7 @@ func TestTaildrop(t *testing.T) {
defer scenario.Shutdown() defer scenario.Shutdown()
spec := map[string]int{ spec := map[string]int{
// Omit 1.16.2 (-1) because it does not have the FQDN field "taildrop": len(MustTestVersions),
"taildrop": len(MustTestVersions) - 1,
} }
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop")) err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop"))
@ -313,6 +313,42 @@ func TestTaildrop(t *testing.T) {
_, err = scenario.ListTailscaleClientsFQDNs() _, err = scenario.ListTailscaleClientsFQDNs()
assertNoErrListFQDN(t, err) assertNoErrListFQDN(t, err)
for _, client := range allClients {
if !strings.Contains(client.Hostname(), "head") {
command := []string{"apk", "add", "curl"}
_, _, err := client.Execute(command)
if err != nil {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
}
}
curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"}
err = retry(10, 1*time.Second, func() error {
result, _, err := client.Execute(curlCommand)
if err != nil {
return err
}
var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
if err != nil {
return err
}
if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr)
}
return err
})
if err != nil {
t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err)
}
}
for _, client := range allClients { for _, client := range allClients {
command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())} command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())}
@ -347,8 +383,9 @@ func TestTaildrop(t *testing.T) {
}) })
if err != nil { if err != nil {
t.Fatalf( t.Fatalf(
"failed to send taildrop file on %s, err: %s", "failed to send taildrop file on %s with command %q, err: %s",
client.Hostname(), client.Hostname(),
strings.Join(command, " "),
err, err,
) )
} }
@ -517,10 +554,139 @@ func TestExpireNode(t *testing.T) {
err = json.Unmarshal([]byte(result), &node) err = json.Unmarshal([]byte(result), &node)
assertNoErr(t, err) assertNoErr(t, err)
var expiredNodeKey key.NodePublic
err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey()))
assertNoErr(t, err)
t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String())
time.Sleep(30 * time.Second) time.Sleep(30 * time.Second)
// Verify that the expired not is no longer present in the Peer list now := time.Now()
// of connected nodes.
// Verify that the expired node has been marked in all peers list.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if client.Hostname() != node.GetName() {
t.Logf("available peers of %s: %v", client.Hostname(), status.Peers())
// In addition to marking nodes expired, we filter them out during the map response
// this check ensures that the node is either not present, or that it is expired
// if it is in the map response.
if peerStatus, ok := status.Peer[expiredNodeKey]; ok {
assertNotNil(t, peerStatus.Expired)
assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %s should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry)
assert.Truef(t, peerStatus.Expired, "node %s should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired)
}
// TODO(kradalby): We do not propagate expiry correctly, nodes should be aware
// of their status, and this should be sent directly to the node when its
// expired. This needs a notifier that goes directly to the node (currently we only do peers)
// so fix this in a follow up PR.
// } else {
// assert.True(t, status.Self.Expired)
}
}
}
func TestNodeOnlineLastSeenStatus(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario()
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
"user1": len(MustTestVersions),
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps))
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
// Assert that we have the original count - self
assert.Len(t, status.Peers(), len(MustTestVersions)-1)
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
keepAliveInterval := 60 * time.Second
// Duration is chosen arbitrarily, 10m is reported in #1561
testDuration := 12 * time.Minute
start := time.Now()
end := start.Add(testDuration)
log.Printf("Starting online test from %v to %v", start, end)
for {
// Let the test run continuously for X minutes to verify
// all nodes stay connected and has the expected status over time.
if end.Before(time.Now()) {
return
}
result, err := headscale.Execute([]string{
"headscale", "nodes", "list", "--output", "json",
})
assertNoErr(t, err)
var nodes []*v1.Node
err = json.Unmarshal([]byte(result), &nodes)
assertNoErr(t, err)
now := time.Now()
// Threshold with some leeway
lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second))
// Verify that headscale reports the nodes as online
for _, node := range nodes {
// All nodes should be online
assert.Truef(
t,
node.GetOnline(),
"expected %s to have online status in Headscale, marked as offline %s after start",
node.GetName(),
time.Since(start),
)
lastSeen := node.GetLastSeen().AsTime()
// All nodes should have been last seen between now and the keepAliveInterval
assert.Truef(
t,
lastSeen.After(lastSeenThreshold),
"lastSeen (%v) was not %s after the threshold (%v)",
lastSeen,
keepAliveInterval,
lastSeenThreshold,
)
}
// Verify that all nodes report all nodes to be online
for _, client := range allClients { for _, client := range allClients {
status, err := client.Status() status, err := client.Status()
assertNoErr(t, err) assertNoErr(t, err)
@ -528,14 +694,36 @@ func TestExpireNode(t *testing.T) {
for _, peerKey := range status.Peers() { for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey] peerStatus := status.Peer[peerKey]
peerPublicKey := strings.TrimPrefix(peerStatus.PublicKey.String(), "nodekey:") // .Online is only available from CapVer 16, which
// is not present in 1.18 which is the lowest we
assert.NotEqual(t, node.GetNodeKey(), peerPublicKey) // test.
if strings.Contains(client.Hostname(), "1-18") {
continue
} }
if client.Hostname() != node.GetName() { // All peers of this nodess are reporting to be
// Assert that we have the original count - self - expired node // connected to the control server
assert.Len(t, status.Peers(), len(MustTestVersions)-2) assert.Truef(
t,
peerStatus.Online,
"expected node %s to be marked as online in %s peer list, marked as offline %s after start",
peerStatus.HostName,
client.Hostname(),
time.Since(start),
)
// from docs: last seen to tailcontrol; only present if offline
// assert.Nilf(
// t,
// peerStatus.LastSeen,
// "expected node %s to not have LastSeen set, got %s",
// peerStatus.HostName,
// peerStatus.LastSeen,
// )
} }
} }
// Check maximum once per second
time.Sleep(time.Second)
}
} }

780
integration/route_test.go Normal file
View file

@ -0,0 +1,780 @@
package integration
import (
"fmt"
"log"
"net/netip"
"sort"
"strconv"
"testing"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
)
// TestEnablingRoutes tests both the `headscale routes` command and the
// propagation of routes to the clients.
//
// Three clients each advertise one route. The test then checks that:
//   - the routes are listed as advertised but neither enabled nor primary,
//     and no client sees primary routes on its peers,
//   - after enabling, all routes are enabled and primary and every client
//     sees the expected primary route on each peer,
//   - after disabling the first listed route, it is no longer enabled or
//     primary and the peer owning it has no primary routes.
func TestEnablingRoutes(t *testing.T) {
	IntegrationSkip(t)
	t.Parallel()

	user := "enable-routing"

	scenario, err := NewScenario()
	assertNoErrf(t, "failed to create scenario: %s", err)
	defer scenario.Shutdown()

	spec := map[string]int{
		user: 3,
	}

	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
	assertNoErrHeadscaleEnv(t, err)

	allClients, err := scenario.ListTailscaleClients()
	assertNoErrListClients(t, err)

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)

	headscale, err := scenario.Headscale()
	assertNoErrGetHeadscale(t, err)

	// Route each node advertises, keyed by the ID reported in its status.
	expectedRoutes := map[string]string{
		"1": "10.0.0.0/24",
		"2": "10.0.1.0/24",
		"3": "10.0.2.0/24",
	}

	// listRoutes returns the routes reported by `headscale routes list`.
	listRoutes := func() []*v1.Route {
		t.Helper()

		var listed []*v1.Route
		err := executeAndUnmarshal(
			headscale,
			[]string{
				"headscale",
				"routes",
				"list",
				"--output",
				"json",
			},
			&listed,
		)
		assertNoErr(t, err)

		return listed
	}

	// Advertise one route per client using `tailscale set`.
	for _, client := range allClients {
		status, err := client.Status()
		assertNoErr(t, err)

		command := []string{
			"tailscale",
			"set",
			"--advertise-routes=" + expectedRoutes[string(status.Self.ID)],
		}
		_, _, err = client.Execute(command)
		assertNoErrf(t, "failed to advertise route: %s", err)
	}

	err = scenario.WaitForTailscaleSync()
	assertNoErrSync(t, err)

	routes := listRoutes()
	assert.Len(t, routes, 3)

	for _, route := range routes {
		assert.Equal(t, true, route.GetAdvertised())
		assert.Equal(t, false, route.GetEnabled())
		assert.Equal(t, false, route.GetIsPrimary())
	}

	// Verify that no routes have been sent to the clients,
	// they are not yet enabled.
	for _, client := range allClients {
		status, err := client.Status()
		assertNoErr(t, err)

		for _, peerKey := range status.Peers() {
			peerStatus := status.Peer[peerKey]

			assert.Nil(t, peerStatus.PrimaryRoutes)
		}
	}

	// Enable all routes
	for _, route := range routes {
		_, err = headscale.Execute(
			[]string{
				"headscale",
				"routes",
				"enable",
				"--route",
				strconv.Itoa(int(route.GetId())),
			})
		assertNoErr(t, err)
	}

	enablingRoutes := listRoutes()
	assert.Len(t, enablingRoutes, 3)

	for _, route := range enablingRoutes {
		assert.Equal(t, true, route.GetAdvertised())
		assert.Equal(t, true, route.GetEnabled())
		assert.Equal(t, true, route.GetIsPrimary())
	}

	// Wait before checking the clients' view of their peers.
	time.Sleep(5 * time.Second)

	// Verify that the clients can see the new routes
	for _, client := range allClients {
		status, err := client.Status()
		assertNoErr(t, err)

		for _, peerKey := range status.Peers() {
			peerStatus := status.Peer[peerKey]

			assert.NotNil(t, peerStatus.PrimaryRoutes)
			if peerStatus.PrimaryRoutes == nil {
				continue
			}

			pRoutes := peerStatus.PrimaryRoutes.AsSlice()

			assert.Len(t, pRoutes, 1)

			if len(pRoutes) > 0 {
				peerRoute := pRoutes[0]

				// The expected route is looked up by the peer's ID.
				assert.Equalf(
					t,
					expectedRoutes[string(peerStatus.ID)],
					peerRoute.String(),
					"expected route %s to be present on peer %s (%s) in %s (%s) status",
					expectedRoutes[string(peerStatus.ID)],
					peerStatus.HostName,
					peerStatus.ID,
					client.Hostname(),
					client.ID(),
				)
			}
		}
	}

	routeToBeDisabled := enablingRoutes[0]
	log.Printf("preparing to disable %v", routeToBeDisabled)

	_, err = headscale.Execute(
		[]string{
			"headscale",
			"routes",
			"disable",
			"--route",
			strconv.Itoa(int(routeToBeDisabled.GetId())),
		})
	assertNoErr(t, err)

	disablingRoutes := listRoutes()

	for _, route := range disablingRoutes {
		assert.Equal(t, true, route.GetAdvertised())

		if route.GetId() == routeToBeDisabled.GetId() {
			assert.Equal(t, false, route.GetEnabled())
			assert.Equal(t, false, route.GetIsPrimary())
		} else {
			assert.Equal(t, true, route.GetEnabled())
			assert.Equal(t, true, route.GetIsPrimary())
		}
	}

	// Wait before checking the clients' view of their peers.
	time.Sleep(5 * time.Second)

	// Verify that the clients no longer see the disabled route
	for _, client := range allClients {
		status, err := client.Status()
		assertNoErr(t, err)

		for _, peerKey := range status.Peers() {
			peerStatus := status.Peer[peerKey]

			if string(peerStatus.ID) == fmt.Sprintf("%d", routeToBeDisabled.GetNode().GetId()) {
				assert.Nilf(
					t,
					peerStatus.PrimaryRoutes,
					"expected node %s to have no routes, got primary route (%v)",
					peerStatus.HostName,
					peerStatus.PrimaryRoutes,
				)
			}
		}
	}
}
func TestHASubnetRouterFailover(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user := "enable-routing"
scenario, err := NewScenario()
assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown()
spec := map[string]int{
user: 3,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
headscale, err := scenario.Headscale()
assertNoErrGetHeadscale(t, err)
expectedRoutes := map[string]string{
"1": "10.0.0.0/24",
"2": "10.0.0.0/24",
}
// Sort nodes by ID
sort.SliceStable(allClients, func(i, j int) bool {
statusI, err := allClients[i].Status()
if err != nil {
return false
}
statusJ, err := allClients[j].Status()
if err != nil {
return false
}
return statusI.Self.ID < statusJ.Self.ID
})
subRouter1 := allClients[0]
subRouter2 := allClients[1]
client := allClients[2]
// advertise HA route on node 1 and 2
// ID 1 will be primary
// ID 2 will be secondary
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if route, ok := expectedRoutes[string(status.Self.ID)]; ok {
command := []string{
"tailscale",
"set",
"--advertise-routes=" + route,
}
_, _, err = client.Execute(command)
assertNoErrf(t, "failed to advertise route: %s", err)
}
}
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
var routes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routes,
)
assertNoErr(t, err)
assert.Len(t, routes, 2)
for _, route := range routes {
assert.Equal(t, true, route.GetAdvertised())
assert.Equal(t, false, route.GetEnabled())
assert.Equal(t, false, route.GetIsPrimary())
}
// Verify that no routes has been sent to the client,
// they are not yet enabled.
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
assert.Nil(t, peerStatus.PrimaryRoutes)
}
}
// Enable all routes
for _, route := range routes {
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
strconv.Itoa(int(route.GetId())),
})
assertNoErr(t, err)
time.Sleep(time.Second)
}
var enablingRoutes []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&enablingRoutes,
)
assertNoErr(t, err)
assert.Len(t, enablingRoutes, 2)
// Node 1 is primary
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
// Node 2 is not primary
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary())
// Verify that the client has routes from the primary machine
srs1, err := subRouter1.Status()
srs2, err := subRouter2.Status()
clientStatus, err := client.Status()
assertNoErr(t, err)
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
// Take down the current primary
t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname())
err = subRouter1.Down()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterMove []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterMove,
)
assertNoErr(t, err)
assert.Len(t, routesAfterMove, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
assert.Equal(t, true, routesAfterMove[0].GetEnabled())
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
assert.Equal(t, true, routesAfterMove[1].GetEnabled())
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary())
// TODO(kradalby): Check client status
// Route is expected to be on SR2
srs2, err = subRouter2.Status()
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Take down subnet router 2, leaving none available
t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname())
err = subRouter2.Down()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterBothDown []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterBothDown,
)
assertNoErr(t, err)
assert.Len(t, routesAfterBothDown, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary())
// Node 2 is primary
// if the node goes down, but no other suitable route is
// available, keep the last known good route.
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary())
// TODO(kradalby): Check client status
// Both are expected to be down
// Verify that the client still sees the last known good route on subnet router 2
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// Bring up subnet router 1, making the route available from there.
t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname())
err = subRouter1.Up()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfter1Up []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfter1Up,
)
assertNoErr(t, err)
assert.Len(t, routesAfter1Up, 2)
// Node 1 is primary
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary())
// Node 2 is not primary
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
// Bring up subnet router 2, should result in no change.
t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname())
err = subRouter2.Up()
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfter2Up []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfter2Up,
)
assertNoErr(t, err)
assert.Len(t, routesAfter2Up, 2)
// Node 1 is primary
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary())
// Node 2 is not primary
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
// Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterDisabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDisabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary())
// Verify that the route is announced from subnet router 2
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// enable the route of subnet router 1, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterEnabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterEnabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterEnabling1, 2)
// Node 1 is not primary
assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled())
assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary())
// Node 2 is primary
assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised())
assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled())
assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary())
// Verify that the route is announced from subnet router 2
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}
// delete the route of subnet router 2, failover to one expected
t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"delete",
"--route",
fmt.Sprintf("%d", routesAfterEnabling1[1].GetId()),
})
assertNoErr(t, err)
time.Sleep(5 * time.Second)
var routesAfterDeleting2 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDeleting2,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDeleting2, 1)
t.Logf("routes after deleting2 %#v", routesAfterDeleting2)
// Node 1 is primary
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled())
assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary())
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}
}

View file

@ -29,7 +29,7 @@ run_tests() {
-failfast \ -failfast \
-timeout 120m \ -timeout 120m \
-parallel 1 \ -parallel 1 \
-run "^$test_name\$" >/dev/null 2>&1 -run "^$test_name\$" >./control_logs/"$test_name"_"$i".log 2>&1
status=$? status=$?
end=$(date +%s) end=$(date +%s)

View file

@ -15,6 +15,7 @@ import (
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/samber/lo"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -93,7 +94,7 @@ var (
// //
// - Two unstable (HEAD and unstable) // - Two unstable (HEAD and unstable)
// - Two latest versions // - Two latest versions
// - Two oldest versions. // - Two oldest supported version.
MustTestVersions = append( MustTestVersions = append(
AllVersions[0:4], AllVersions[0:4],
AllVersions[len(AllVersions)-2:]..., AllVersions[len(AllVersions)-2:]...,
@ -296,11 +297,13 @@ func (s *Scenario) CreateTailscaleNodesInUser(
opts ...tsic.Option, opts ...tsic.Option,
) error { ) error {
if user, ok := s.users[userStr]; ok { if user, ok := s.users[userStr]; ok {
var versions []string
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
version := requestedVersion version := requestedVersion
if requestedVersion == "all" { if requestedVersion == "all" {
version = MustTestVersions[i%len(MustTestVersions)] version = MustTestVersions[i%len(MustTestVersions)]
} }
versions = append(versions, version)
headscale, err := s.Headscale() headscale, err := s.Headscale()
if err != nil { if err != nil {
@ -350,6 +353,8 @@ func (s *Scenario) CreateTailscaleNodesInUser(
return err return err
} }
log.Printf("testing versions %v", lo.Uniq(versions))
return nil return nil
} }
@ -403,7 +408,17 @@ func (s *Scenario) CountTailscale() int {
func (s *Scenario) WaitForTailscaleSync() error { func (s *Scenario) WaitForTailscaleSync() error {
tsCount := s.CountTailscale() tsCount := s.CountTailscale()
return s.WaitForTailscaleSyncWithPeerCount(tsCount - 1) err := s.WaitForTailscaleSyncWithPeerCount(tsCount - 1)
if err != nil {
for _, user := range s.users {
for _, client := range user.Clients {
peers, _ := client.PrettyPeers()
log.Println(peers)
}
}
}
return err
} }
// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports // WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports

View file

@ -109,7 +109,7 @@ func TestSSHOneUserToAll(t *testing.T) {
}, },
}, },
}, },
len(MustTestVersions)-2, len(MustTestVersions),
) )
defer scenario.Shutdown() defer scenario.Shutdown()
@ -174,7 +174,7 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
}, },
}, },
}, },
len(MustTestVersions)-2, len(MustTestVersions),
) )
defer scenario.Shutdown() defer scenario.Shutdown()
@ -220,7 +220,7 @@ func TestSSHNoSSHConfigured(t *testing.T) {
}, },
SSHs: []policy.SSH{}, SSHs: []policy.SSH{},
}, },
len(MustTestVersions)-2, len(MustTestVersions),
) )
defer scenario.Shutdown() defer scenario.Shutdown()
@ -269,7 +269,7 @@ func TestSSHIsBlockedInACL(t *testing.T) {
}, },
}, },
}, },
len(MustTestVersions)-2, len(MustTestVersions),
) )
defer scenario.Shutdown() defer scenario.Shutdown()
@ -325,7 +325,7 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}, },
}, },
}, },
len(MustTestVersions)-2, len(MustTestVersions),
) )
defer scenario.Shutdown() defer scenario.Shutdown()

View file

@ -21,6 +21,8 @@ type TailscaleClient interface {
Login(loginServer, authKey string) error Login(loginServer, authKey string) error
LoginWithURL(loginServer string) (*url.URL, error) LoginWithURL(loginServer string) (*url.URL, error)
Logout() error Logout() error
Up() error
Down() error
IPs() ([]netip.Addr, error) IPs() ([]netip.Addr, error)
FQDN() (string, error) FQDN() (string, error)
Status() (*ipnstate.Status, error) Status() (*ipnstate.Status, error)
@ -30,4 +32,5 @@ type TailscaleClient interface {
Ping(hostnameOrIP string, opts ...tsic.PingOption) error Ping(hostnameOrIP string, opts ...tsic.PingOption) error
Curl(url string, opts ...tsic.CurlOption) (string, error) Curl(url string, opts ...tsic.CurlOption) (string, error)
ID() string ID() string
PrettyPeers() (string, error)
} }

View file

@ -285,6 +285,15 @@ func (t *TailscaleInContainer) hasTLS() bool {
// Shutdown stops and cleans up the Tailscale container. // Shutdown stops and cleans up the Tailscale container.
func (t *TailscaleInContainer) Shutdown() error { func (t *TailscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
if err != nil {
log.Printf(
"Failed to save log from %s: %s",
t.hostname,
fmt.Errorf("failed to save log: %w", err),
)
}
return t.pool.Purge(t.container) return t.pool.Purge(t.container)
} }
@ -417,6 +426,44 @@ func (t *TailscaleInContainer) Logout() error {
return nil return nil
} }
// Up runs `tailscale up` with no additional arguments through t.Execute,
// using dockerExecuteTimeout as the command timeout.
//
// On failure the returned error wraps the underlying error and is annotated
// with the client hostname and the exact command that was attempted.
func (t *TailscaleInContainer) Up() error {
	upCmd := []string{"tailscale", "up"}

	_, _, err := t.Execute(
		upCmd,
		dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout),
	)
	if err == nil {
		return nil
	}

	return fmt.Errorf(
		"%s failed to bring tailscale client up (%s): %w",
		t.hostname,
		strings.Join(upCmd, " "),
		err,
	)
}
// Down runs `tailscale down` with no additional arguments through t.Execute,
// using dockerExecuteTimeout as the command timeout.
//
// On failure the returned error wraps the underlying error and is annotated
// with the client hostname and the exact command that was attempted.
func (t *TailscaleInContainer) Down() error {
	downCmd := []string{"tailscale", "down"}

	_, _, err := t.Execute(
		downCmd,
		dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout),
	)
	if err == nil {
		return nil
	}

	return fmt.Errorf(
		"%s failed to bring tailscale client down (%s): %w",
		t.hostname,
		strings.Join(downCmd, " "),
		err,
	)
}
// IPs returns the netip.Addr of the Tailscale instance. // IPs returns the netip.Addr of the Tailscale instance.
func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
if t.ips != nil && len(t.ips) != 0 { if t.ips != nil && len(t.ips) != 0 {
@ -486,6 +533,34 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
return status.Self.DNSName, nil return status.Self.DNSName, nil
} }
// PrettyPeers returns a formatted-ish, tab-separated table of the peers
// reported by the client's status.
//
// The table has a "Peers of <hostname>" title, a header row, one row per peer
// (hostname, online flag, last-seen time) and a trailing summary line with the
// total and online peer counts. It is intended for human-readable debug
// logging, not for parsing.
//
// An error is returned, wrapping the underlying cause, if the client status
// cannot be fetched.
func (t *TailscaleInContainer) PrettyPeers() (string, error) {
	status, err := t.Status()
	if err != nil {
		// The failing call is Status(), so say so; the earlier message
		// mentioned the FQDN, which this method never looks up.
		return "", fmt.Errorf("failed to get status: %w", err)
	}

	// Fetch the peer keys once and reuse them for both the rows and the count.
	peerKeys := status.Peers()

	// A Builder avoids re-copying the accumulated string on every row.
	var sb strings.Builder
	fmt.Fprintf(&sb, "Peers of %s\n", t.hostname)
	sb.WriteString("Hostname\tOnline\tLastSeen\n")

	onlineCount := 0

	for _, peerKey := range peerKeys {
		peer := status.Peer[peerKey]

		if peer.Online {
			onlineCount++
		}

		fmt.Fprintf(&sb, "%s\t%t\t%s\n", peer.HostName, peer.Online, peer.LastSeen)
	}

	fmt.Fprintf(&sb, "Peer Count: %d, Online Count: %d\n\n", len(peerKeys), onlineCount)

	return sb.String(), nil
}
// WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has // WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has
// started and needs to be logged into. // started and needs to be logged into.
func (t *TailscaleInContainer) WaitForNeedsLogin() error { func (t *TailscaleInContainer) WaitForNeedsLogin() error {
@ -531,7 +606,7 @@ func (t *TailscaleInContainer) WaitForRunning() error {
} }
// WaitForPeers blocks until N number of peers is present in the // WaitForPeers blocks until N number of peers is present in the
// Peer list of the Tailscale instance. // Peer list of the Tailscale instance and is reporting Online.
func (t *TailscaleInContainer) WaitForPeers(expected int) error { func (t *TailscaleInContainer) WaitForPeers(expected int) error {
return t.pool.Retry(func() error { return t.pool.Retry(func() error {
status, err := t.Status() status, err := t.Status()
@ -547,6 +622,14 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
expected, expected,
len(peers), len(peers),
) )
} else {
for _, peerKey := range peers {
peer := status.Peer[peerKey]
if !peer.Online {
return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
}
}
} }
return nil return nil
@ -738,3 +821,9 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
} }
// SaveLog stores the container's current stdout log under the given path on
// the host system. The work is delegated to dockertestutil.SaveLog using this
// instance's pool and container; its error, if any, is returned unchanged.
func (t *TailscaleInContainer) SaveLog(path string) error {
	err := dockertestutil.SaveLog(t.pool, t.container, path)

	return err
}

View file

@ -26,6 +26,13 @@ func assertNoErrf(t *testing.T, msg string, err error) {
} }
} }
// assertNotNil aborts the test with t.Fatal when thing is nil and does
// nothing otherwise. It is marked as a test helper so failures are reported
// at the caller's line.
//
// NOTE(review): only an untyped nil is detected. A nil pointer of a concrete
// type stored in the interface compares unequal to nil and therefore passes
// this check; callers passing pointer fields should confirm this is intended.
func assertNotNil(t *testing.T, thing interface{}) {
	t.Helper()

	if thing != nil {
		return
	}

	t.Fatal("got unexpected nil")
}
func assertNoErrHeadscaleEnv(t *testing.T, err error) { func assertNoErrHeadscaleEnv(t *testing.T, err error) {
t.Helper() t.Helper()
assertNoErrf(t, "failed to create headscale environment: %s", err) assertNoErrf(t, "failed to create headscale environment: %s", err)
@ -68,13 +75,13 @@ func assertContains(t *testing.T, str, subStr string) {
} }
} }
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
t.Helper() t.Helper()
success := 0 success := 0
for _, client := range clients { for _, client := range clients {
for _, addr := range addrs { for _, addr := range addrs {
err := client.Ping(addr) err := client.Ping(addr, opts...)
if err != nil { if err != nil {
t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err)
} else { } else {