diff --git a/.github/workflows/check-tests.yaml b/.github/workflows/check-tests.yaml index b1b94532..486bed0b 100644 --- a/.github/workflows/check-tests.yaml +++ b/.github/workflows/check-tests.yaml @@ -32,7 +32,7 @@ jobs: - name: Generate and check integration tests if: steps.changed-files.outputs.files == 'true' run: | - nix develop --command bash -c "cd cmd/gh-action-integration-generator/ && go generate" + nix develop --command bash -c "cd .github/workflows && go generate" git diff --exit-code .github/workflows/test-integration.yaml - name: Show missing tests diff --git a/cmd/gh-action-integration-generator/main.go b/.github/workflows/gh-action-integration-generator.go similarity index 77% rename from cmd/gh-action-integration-generator/main.go rename to .github/workflows/gh-action-integration-generator.go index 35e20250..48d96716 100644 --- a/cmd/gh-action-integration-generator/main.go +++ b/.github/workflows/gh-action-integration-generator.go @@ -1,6 +1,6 @@ package main -//go:generate go run ./main.go +//go:generate go run ./gh-action-integration-generator.go import ( "bytes" @@ -42,15 +42,19 @@ func updateYAML(tests []string) { testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", ")) yqCommand := fmt.Sprintf( - "yq eval '.jobs.integration-test.strategy.matrix.test = %s' ../../.github/workflows/test-integration.yaml -i", + "yq eval '.jobs.integration-test.strategy.matrix.test = %s' ./test-integration.yaml -i", testsForYq, ) cmd := exec.Command("bash", "-c", yqCommand) - var out bytes.Buffer - cmd.Stdout = &out + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr err := cmd.Run() if err != nil { + log.Printf("stdout: %s", stdout.String()) + log.Printf("stderr: %s", stderr.String()) log.Fatalf("failed to run yq command: %s", err) } diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 83db1c33..45095e03 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -22,10 +22,13 @@ jobs: - TestACLNamedHostsCanReach - TestACLDevice1CanAccessDevice2 - TestPolicyUpdateWhileRunningWithCLIInDatabase + - TestAuthKeyLogoutAndReloginSameUser + - TestAuthKeyLogoutAndReloginNewUser - TestOIDCAuthenticationPingAll - TestOIDCExpireNodesBasedOnTokenExpiry - TestOIDC024UserCreation - TestOIDCAuthenticationWithPKCE + - TestOIDCReloginSameNodeNewUser - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand @@ -50,7 +53,6 @@ jobs: - TestDERPServerWebsocketScenario - TestPingAllByIP - TestPingAllByIPPublicDERP - - TestAuthKeyLogoutAndRelogin - TestEphemeral - TestEphemeralInAlternateTimezone - TestEphemeral2006DeletedTooQuickly diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a56a136..02602313 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,18 @@ ## Next +### BREAKING + +- Authentication flow has been rewritten + [#2374](https://github.com/juanfont/headscale/pull/2374) This change should be + transparent to users with the exception of some buxfixes that has been + discovered and was fixed as part of the rewrite. + - When a node is registered with _a new user_, it will be registered as a new + node ([#2327](https://github.com/juanfont/headscale/issues/2327) and + [#1310](https://github.com/juanfont/headscale/issues/1310)). + - A logged out node logging in with the same user will replace the existing + node. + ### Changes - `oidc.map_legacy_users` is now `false` by default diff --git a/hscontrol/app.go b/hscontrol/app.go index 36f7df5d..c25ca9fc 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -521,25 +521,28 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. // Maybe we should attempt a new in memory state and not go via the DB? -func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { +// A bool is returned indicating if a full update was sent to all nodes +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) { nodes, err := db.ListNodes() if err != nil { - return err + return false, err } - changed, err := polMan.SetNodes(nodes) + filterChanged, err := polMan.SetNodes(nodes) if err != nil { - return err + return false, err } - if changed { + if filterChanged { ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") notif.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, }) + + return true, nil } - return nil + return false, nil } // Serve launches the HTTP and gRPC server service Headscale and the API. diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 9e22660d..3fa5fa4b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -2,7 +2,6 @@ package hscontrol import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -13,7 +12,6 @@ import ( "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -25,730 +23,244 @@ type AuthProvider interface { AuthURL(types.RegistrationID) string } -func logAuthFunc( - registerRequest tailcfg.RegisterRequest, +func (h *Headscale) handleRegister( + ctx context.Context, + regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, - registrationId types.RegistrationID, -) (func(string), func(string), func(error, string)) { - return func(msg string) { - log.Info(). - Caller(). - Str("registration_id", registrationId.String()). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Msg(msg) - }, - func(msg string) { - log.Trace(). - Caller(). - Str("registration_id", registrationId.String()). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Msg(msg) - }, - func(err error, msg string) { - log.Error(). - Caller(). - Str("registration_id", registrationId.String()). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Err(err). - Msg(msg) +) (*tailcfg.RegisterResponse, error) { + node, err := h.db.GetNodeByNodeKey(regReq.NodeKey) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("looking up node in database: %w", err) + } + + if node != nil { + resp, err := h.handleExistingNode(node, regReq, machineKey) + if err != nil { + return nil, fmt.Errorf("handling existing node: %w", err) } + + return resp, nil + } + + if regReq.Followup != "" { + // TODO(kradalby): Does this need to return an error of some sort? + // Maybe if the registration fails down the line it can be sent + // on the channel and returned here? + h.waitForFollowup(ctx, regReq) + } + + if regReq.Auth != nil && regReq.Auth.AuthKey != "" { + resp, err := h.handleRegisterWithAuthKey(regReq, machineKey) + if err != nil { + return nil, fmt.Errorf("handling register with auth key: %w", err) + } + + return resp, nil + } + + resp, err := h.handleRegisterInteractive(regReq, machineKey) + if err != nil { + return nil, fmt.Errorf("handling register interactive: %w", err) + } + + return resp, nil +} + +func (h *Headscale) handleExistingNode( + node *types.Node, + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + if node.MachineKey != machineKey { + return nil, errors.New("node already exists with different machine key") + } + + expired := node.IsExpired() + if !expired && !regReq.Expiry.IsZero() { + requestExpiry := regReq.Expiry + + // The client is trying to extend their key, this is not allowed. + if requestExpiry.After(time.Now()) { + return nil, errors.New("extending key is not allowed") + } + + // If the request expiry is in the past, we consider it a logout. + if requestExpiry.Before(time.Now()) { + if node.IsEphemeral() { + changedNodes, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap()) + if err != nil { + return nil, fmt.Errorf("deleting ephemeral node: %w", err) + } + + ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []types.NodeID{node.ID}, + }) + if changedNodes != nil { + h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: changedNodes, + }) + } + } + + expired = true + } + + err := h.db.NodeSetExpiry(node.ID, requestExpiry) + if err != nil { + return nil, fmt.Errorf("setting node expiry: %w", err) + } + + ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") + h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID) + } + + return &tailcfg.RegisterResponse{ + // TODO(kradalby): Only send for user-owned nodes + // and not tagged nodes when tags is working. + User: *node.User.TailscaleUser(), + Login: *node.User.TailscaleLogin(), + NodeKeyExpired: expired, + + // Headscale does not implement the concept of machine authorization + // so we always return true here. + // Revisit this if #2176 gets implemented. + MachineAuthorized: true, + }, nil } func (h *Headscale) waitForFollowup( - req *http.Request, + ctx context.Context, regReq tailcfg.RegisterRequest, - logTrace func(string), ) { - logTrace("register request is a followup") fu, err := url.Parse(regReq.Followup) if err != nil { - logTrace("failed to parse followup URL") return } followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) if err != nil { - logTrace("followup URL does not contains a valid registration ID") return } - logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg)) - if reg, ok := h.registrationCache.Get(followupReg); ok { - logTrace("Node is waiting for interactive login") - select { - case <-req.Context().Done(): - logTrace("node went away before it was registered") + case <-ctx.Done(): return case <-reg.Registered: - logTrace("node has successfully registered") return } } } -// handleRegister is the logic for registering a client. -func (h *Headscale) handleRegister( - writer http.ResponseWriter, - req *http.Request, +func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, -) { - registrationId, err := types.NewRegistrationID() +) (*tailcfg.RegisterResponse, error) { + pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to generate registration ID") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, fmt.Errorf("invalid pre auth key: %w", err) } - logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) - now := time.Now().UTC() - logTrace("handleRegister called, looking up machine in DB") + nodeToRegister := types.Node{ + Hostname: regReq.Hostinfo.Hostname, + UserID: pak.User.ID, + User: pak.User, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + RegisterMethod: util.RegisterMethodAuthKey, - // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs - // key refreshes. This will allow us to remove the machineKey from the registration request. - node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) - logTrace("handleRegister database lookup has returned") - if errors.Is(err, gorm.ErrRecordNotFound) { - // If the node has AuthKey set, handle registration via PreAuthKeys - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - h.handleAuthKey(writer, regReq, machineKey) - - return - } - - // Check if the node is waiting for interactive login. - if regReq.Followup != "" { - h.waitForFollowup(req, regReq, logTrace) - return - } - - logInfo("Node not found in database, creating new") - - // The node did not have a key to authenticate, which means - // that we rely on a method that calls back some how (OpenID or CLI) - // We create the node and then keep it around until a callback - // happens - newNode := types.RegisterNode{ - Node: types.Node{ - MachineKey: machineKey, - Hostname: regReq.Hostinfo.Hostname, - NodeKey: regReq.NodeKey, - LastSeen: &now, - Expiry: &time.Time{}, - }, - Registered: make(chan struct{}), - } - - if !regReq.Expiry.IsZero() { - logTrace("Non-zero expiry time requested") - newNode.Node.Expiry = ®Req.Expiry - } - - h.registrationCache.Set( - registrationId, - newNode, - ) - - h.handleNewNode(writer, regReq, registrationId) - - return + // TODO(kradalby): This should not be set on the node, + // they should be looked up through the key, which is + // attached to the node. + ForcedTags: pak.Proto().GetAclTags(), + AuthKey: pak, + AuthKeyID: &pak.ID, } - // The node is already in the DB. This could mean one of the following: - // - The node is authenticated and ready to /map - // - We are doing a key refresh - // - The node is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here - if node != nil { - // (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021, - // due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054 - // So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it. - if err != nil || node.MachineKey.IsZero() { - if err := h.db.NodeSetMachineKey(node, machineKey); err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("node", node.Hostname). - Err(err). - Msg("Error saving machine key to database") - - return - } - } - - // If the NodeKey stored in headscale is the same as the key presented in a registration - // request, then we have a node that is either: - // - Trying to log out (sending a expiry in the past) - // - A valid, registered node, looking for /map - // - Expired node wanting to reauthenticate - if node.NodeKey.String() == regReq.NodeKey.String() { - // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 - if !regReq.Expiry.IsZero() && - regReq.Expiry.UTC().Before(now) { - h.handleNodeLogOut(writer, *node) - - return - } - - // If node is not expired, and it is register, we have a already accepted this node, - // let it proceed with a valid registration - if !node.IsExpired() { - h.handleNodeWithValidRegistration(writer, *node) - - return - } - } - - // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if node.NodeKey.String() == regReq.OldNodeKey.String() && - !node.IsExpired() { - h.handleNodeKeyRefresh( - writer, - regReq, - *node, - ) - - return - } - - // When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed - if node.NodeKey.String() != regReq.NodeKey.String() && - regReq.OldNodeKey.IsZero() && !node.IsExpired() { - h.handleNodeKeyRefresh( - writer, - regReq, - *node, - ) - - return - } - - if regReq.Followup != "" { - h.waitForFollowup(req, regReq, logTrace) - return - } - - // The node has expired or it is logged out - h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId) - - // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use - node.Expiry = &time.Time{} - - // TODO(kradalby): do we need to rethink this as part of authflow? - // If we are here it means the client needs to be reauthorized, - // we need to make sure the NodeKey matches the one in the request - // TODO(juan): What happens when using fast user switching between two - // headscale-managed tailnets? - node.NodeKey = regReq.NodeKey - h.registrationCache.Set( - registrationId, - types.RegisterNode{ - Node: *node, - Registered: make(chan struct{}), - }, - ) - - return + if !regReq.Expiry.IsZero() { + nodeToRegister.Expiry = ®Req.Expiry } -} -// handleAuthKey contains the logic to manage auth key client registration -// When using Noise, the machineKey is Zero. -func (h *Headscale) handleAuthKey( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - machineKey key.MachinePublic, -) { - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) - resp := tailcfg.RegisterResponse{} - - pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) + ipv4, ipv6, err := h.ipAlloc.Next() if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed authentication via AuthKey") - resp.MachineAuthorized = false - - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusUnauthorized) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("Failed authentication via AuthKey") - - return + return nil, fmt.Errorf("allocating IPs: %w", err) } - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("Authentication key was valid, proceeding to acquire IP addresses") - - nodeKey := registerRequest.NodeKey - - // retrieve node information if it exist - // The error is not important, because if it does not - // exist, then this is a new node and we will move - // on to registration. - // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs - // key refreshes. This will allow us to remove the machineKey from the registration request. - node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) - if node != nil { - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msg("node was already registered before, refreshing with new auth key") - - node.NodeKey = nodeKey - if pak.ID != 0 { - node.AuthKeyID = ptr.To(pak.ID) - } - - node.Expiry = ®isterRequest.Expiry - node.User = pak.User - node.UserID = pak.UserID - err := h.db.DB.Save(node).Error - if err != nil { - log.Error(). - Caller(). - Str("node", node.Hostname). - Err(err). - Msg("failed to save node after logging in with auth key") - - return - } - - aclTags := pak.Proto().GetAclTags() - if len(aclTags) > 0 { - // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(node.ID, aclTags) - if err != nil { - log.Error(). - Caller(). - Str("node", node.Hostname). - Strs("aclTags", aclTags). - Err(err). - Msg("Failed to set tags after refreshing node") - - return - } - } - - ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}}) - } else { - now := time.Now().UTC() - - nodeToRegister := types.Node{ - Hostname: registerRequest.Hostinfo.Hostname, - UserID: pak.User.ID, - User: pak.User, - MachineKey: machineKey, - RegisterMethod: util.RegisterMethodAuthKey, - Expiry: ®isterRequest.Expiry, - NodeKey: nodeKey, - LastSeen: &now, - ForcedTags: pak.Proto().GetAclTags(), - } - - ipv4, ipv6, err := h.ipAlloc.Next() - if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("failed to allocate IP ") - - return - } - - pakID := uint(pak.ID) - if pakID != 0 { - nodeToRegister.AuthKeyID = ptr.To(pak.ID) - } - node, err = h.db.RegisterNode( + node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) { + node, err := db.RegisterNode(tx, nodeToRegister, ipv4, ipv6, ) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not register node") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, fmt.Errorf("registering node: %w", err) } - err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) - if err != nil { - http.Error(writer, "Internal server error", http.StatusInternalServerError) - return + if !pak.Reusable { + err = db.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } } - } - err = h.db.Write(func(tx *gorm.DB) error { - return db.UsePreAuthKey(tx, pak) + return node, nil }) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to use pre-auth key") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, err } - resp.MachineAuthorized = true - resp.User = *pak.User.TailscaleUser() - // Provide LoginName when registering with pre-auth key - // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* - resp.Login = *pak.User.TailscaleLogin() - - respBody, err := json.Marshal(resp) + updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier) if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - return + return nil, fmt.Errorf("nodes changed hook: %w", err) } - log.Info(). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("Successfully authenticated via AuthKey") + if !updateSent { + ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname) + h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID)) + } + + return &tailcfg.RegisterResponse{ + MachineAuthorized: true, + NodeKeyExpired: node.IsExpired(), + User: *pak.User.TailscaleUser(), + Login: *pak.User.TailscaleLogin(), + }, nil } -// handleNewNode returns the authorisation URL to the client based on what type -// of registration headscale is configured with. -// This url is then showed to the user by the local Tailscale client. -func (h *Headscale) handleNewNode( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - registrationId types.RegistrationID, -) { - logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId) - - resp := tailcfg.RegisterResponse{} - - // The node registration is new, redirect the client to the registration URL - logTrace("The node is new, sending auth url") - - resp.AuthURL = h.authProvider.AuthURL(registrationId) - - respBody, err := json.Marshal(resp) - if err != nil { - logErr(err, "Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - logErr(err, "Failed to write response") - } - - logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL)) -} - -func (h *Headscale) handleNodeLogOut( - writer http.ResponseWriter, - node types.Node, -) { - resp := tailcfg.RegisterResponse{} - - log.Info(). - Str("node", node.Hostname). - Msg("Client requested logout") - - now := time.Now() - err := h.db.NodeSetExpiry(node.ID, now) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to expire node") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na") - h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID) - - resp.AuthURL = "" - resp.MachineAuthorized = false - resp.NodeKeyExpired = true - resp.User = *node.User.TailscaleUser() - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - - return - } - - if node.IsEphemeral() { - changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap()) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("Cannot delete ephemeral node from the database") - } - - ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na") - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []types.NodeID{node.ID}, - }) - if changedNodes != nil { - h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: changedNodes, - }) - } - - return - } - - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("Successfully logged out") -} - -func (h *Headscale) handleNodeWithValidRegistration( - writer http.ResponseWriter, - node types.Node, -) { - resp := tailcfg.RegisterResponse{} - - // The node registration is valid, respond with redirect to /map - log.Debug(). - Caller(). - Str("node", node.Hostname). - Msg("Client is registered and we have the current NodeKey. All clear to /map") - - resp.AuthURL = "" - resp.MachineAuthorized = true - resp.User = *node.User.TailscaleUser() - resp.Login = *node.User.TailscaleLogin() - - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("Node successfully authorized") -} - -func (h *Headscale) handleNodeKeyRefresh( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - node types.Node, -) { - resp := tailcfg.RegisterResponse{} - - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("We have the OldNodeKey in the database. This is a key refresh") - - err := h.db.Write(func(tx *gorm.DB) error { - return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) - }) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to update machine key in the database") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - resp.AuthURL = "" - resp.User = *node.User.TailscaleUser() - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - log.Info(). - Caller(). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("old_node_key", registerRequest.OldNodeKey.ShortString()). - Str("node", node.Hostname). - Msg("Node key successfully refreshed") -} - -func (h *Headscale) handleNodeExpiredOrLoggedOut( - writer http.ResponseWriter, +func (h *Headscale) handleRegisterInteractive( regReq tailcfg.RegisterRequest, - node types.Node, machineKey key.MachinePublic, - registrationId types.RegistrationID, -) { - resp := tailcfg.RegisterResponse{} - - if regReq.Auth != nil && regReq.Auth.AuthKey != "" { - h.handleAuthKey(writer, regReq, machineKey) - - return - } - - // The client has registered before, but has expired or logged out - log.Trace(). - Caller(). - Str("node", node.Hostname). - Str("registration_id", registrationId.String()). - Str("node_key", regReq.NodeKey.ShortString()). - Str("node_key_old", regReq.OldNodeKey.ShortString()). - Msg("Node registration has expired or logged out. Sending a auth url to register") - - resp.AuthURL = h.authProvider.AuthURL(registrationId) - - respBody, err := json.Marshal(resp) +) (*tailcfg.RegisterResponse, error) { + registrationId, err := types.NewRegistrationID() if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, fmt.Errorf("generating registration ID: %w", err) } - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + newNode := types.RegisterNode{ + Node: types.Node{ + Hostname: regReq.Hostinfo.Hostname, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + Hostinfo: regReq.Hostinfo, + LastSeen: ptr.To(time.Now()), + }, + Registered: make(chan struct{}), } - log.Trace(). - Caller(). - Str("registration_id", registrationId.String()). - Str("node_key", regReq.NodeKey.ShortString()). - Str("node_key_old", regReq.OldNodeKey.ShortString()). - Str("node", node.Hostname). - Msg("Node logged out. Sent AuthURL for reauthentication") + if !regReq.Expiry.IsZero() { + newNode.Node.Expiry = ®Req.Expiry + } + + h.registrationCache.Set( + registrationId, + newNode, + ) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(registrationId), + }, nil } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index f722d9ab..11a13056 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -182,38 +182,6 @@ func GetNodeByNodeKey( return &mach, nil } -func (hsdb *HSDatabase) GetNodeByAnyKey( - machineKey key.MachinePublic, - nodeKey key.NodePublic, - oldNodeKey key.NodePublic, -) (*types.Node, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey) - }) -} - -// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -// TODO(kradalby): see if we can remove this. -func GetNodeByAnyKey( - tx *gorm.DB, - machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, -) (*types.Node, error) { - node := types.Node{} - if result := tx. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", - machineKey.String(), - nodeKey.String(), - oldNodeKey.String()); result.Error != nil { - return nil, result.Error - } - - return &node, nil -} - func (hsdb *HSDatabase) SetTags( nodeID types.NodeID, tags []string, @@ -437,6 +405,18 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad Str("user", node.User.Username()). Msg("Registering node") + // If the a new node is registered with the same machine key, to the same user, + // update the existing node. + // If the same node is registered again, but to a new user, then that is considered + // a new node. + oldNode, _ := GetNodeByMachineKey(tx, node.MachineKey) + if oldNode != nil && oldNode.UserID == node.UserID { + node.ID = oldNode.ID + node.GivenName = oldNode.GivenName + ipv4 = oldNode.IPv4 + ipv6 = oldNode.IPv6 + } + // If the node exists and it already has IP(s), we just save it // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 270fd91b..7dc58819 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -84,37 +84,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) { c.Assert(err, check.IsNil) } -func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - oldNodeKey := key.NewNode() - - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) - c.Assert(err, check.IsNil) -} - func (s *Suite) TestHardDeleteNode(c *check.C) { user, err := db.CreateUser(types.User{Name: "test"}) c.Assert(err, check.IsNil) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 7b1c6581..51fb9869 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -256,10 +256,17 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } - err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) if err != nil { return nil, fmt.Errorf("updating resources using node: %w", err) } + if !updateSent { + ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname) + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: []types.NodeID{node.ID}, + }) + } return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index b4e90f31..318cf5e4 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -156,7 +156,12 @@ func isSupportedVersion(version tailcfg.CapabilityVersion) bool { return version >= MinimumCapVersion } -func rejectUnsupported(writer http.ResponseWriter, version tailcfg.CapabilityVersion, mkey key.MachinePublic, nkey key.NodePublic) bool { +func rejectUnsupported( + writer http.ResponseWriter, + version tailcfg.CapabilityVersion, + mkey key.MachinePublic, + nkey key.NodePublic, +) bool { // Reject unsupported versions if !isSupportedVersion(version) { log.Error(). @@ -204,11 +209,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( ns.nodeKey = mapRequest.NodeKey - node, err := ns.headscale.db.GetNodeByAnyKey( - ns.conn.Peer(), - mapRequest.NodeKey, - key.NodePublic{}, - ) + node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey) if err != nil { httpError(writer, err, "Internal error", http.StatusInternalServerError) return @@ -234,12 +235,38 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } - body, _ := io.ReadAll(req.Body) - var registerRequest tailcfg.RegisterRequest - if err := json.Unmarshal(body, ®isterRequest); err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + registerRequest, registerResponse, err := func() (*tailcfg.RegisterRequest, []byte, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, nil, err + } + var registerRequest tailcfg.RegisterRequest + if err := json.Unmarshal(body, ®isterRequest); err != nil { + return nil, nil, err + } - return + ns.nodeKey = registerRequest.NodeKey + + resp, err := ns.headscale.handleRegister(req.Context(), registerRequest, ns.conn.Peer()) + // TODO(kradalby): Here we could have two error types, one that is surfaced to the client + // and one that returns 500. + if err != nil { + return nil, nil, err + } + + respBody, err := json.Marshal(resp) + if err != nil { + return nil, nil, err + } + + return ®isterRequest, respBody, nil + }() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Error handling registration") + http.Error(writer, "Internal server error", http.StatusInternalServerError) } // Reject unsupported versions @@ -247,7 +274,13 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } - ns.nodeKey = registerRequest.NodeKey - - ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer()) + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err = writer.Write(registerResponse) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 8364dee1..42032f79 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -512,24 +512,21 @@ func (a *AuthProviderOIDC) handleRegistrationID( // Send an update to all nodes if this is a new node that they need to know // about. // If this is a refresh, just send new expiry updates. - if newNode { - err = nodesChangedHook(a.db, a.polMan, a.notifier) - if err != nil { - return false, fmt.Errorf("updating resources using node: %w", err) - } - } else { + updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return false, fmt.Errorf("updating resources using node: %w", err) + } + + if !updateSent { ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname) a.notifier.NotifyByNodeID( ctx, - types.StateUpdate{ - Type: types.StateSelfUpdate, - ChangeNodes: []types.NodeID{node.ID}, - }, + types.StateSelf(node.ID), node.ID, ) ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname) - a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID) + a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID) } return newNode, nil diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 3b6c1be1..e5cef8fd 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -102,6 +102,20 @@ func (su *StateUpdate) Empty() bool { return false } +func StateSelf(nodeID NodeID) StateUpdate { + return StateUpdate{ + Type: StateSelfUpdate, + ChangeNodes: []NodeID{nodeID}, + } +} + +func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate { + return StateUpdate{ + Type: StatePeerChanged, + ChangeNodes: nodeIDs, + } +} + func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { return StateUpdate{ Type: StatePeerChangedPatch, diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go new file mode 100644 index 00000000..d1c2c5d1 --- /dev/null +++ b/integration/auth_key_test.go @@ -0,0 +1,230 @@ +package integration + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + for _, https := range []bool{true, false} { + t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + opts := []hsic.Option{hsic.WithTestName("pingallbyip")} + if https { + opts = append(opts, []hsic.Option{ + hsic.WithTLS(), + }...) + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + // assertClientsState(t, allClients) + + clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + clientIPs[client] = ips + } + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + listNodes, err := headscale.ListNodes() + assert.Equal(t, len(listNodes), len(allClients)) + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + for _, client := range allClients { + err := client.Logout() + if err != nil { + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + err = scenario.WaitForTailscaleLogout() + assertNoErrLogout(t, err) + + t.Logf("all clients logged out") + + // if the server is not running with HTTPS, we have to wait a bit before + // reconnection as the newest Tailscale client has a measure that will only + // reconnect over HTTPS if they saw a noise connection previously. + // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 + // https://github.com/juanfont/headscale/issues/2164 + if !https { + time.Sleep(5 * time.Minute) + } + + for userName := range spec { + key, err := scenario.CreatePreAuthKey(userName, true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + listNodes, err = headscale.ListNodes() + require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + + // lets check if the IPs are the same + if len(ips) != len(clientIPs[client]) { + t.Fatalf("IPs changed for client %s", client.Hostname()) + } + + for _, ip := range ips { + found := false + for _, oldIP := range clientIPs[client] { + if ip == oldIP { + found = true + + break + } + } + + if !found { + t.Fatalf( + "IPs changed for client %s. Used to be %v now %v", + client.Hostname(), + clientIPs[client], + ips, + ) + } + } + } + }) + } +} + +// This test will first log in two sets of nodes to two sets of users, then +// it will log out all users from user2 and log them in as user1. +// This should leave us with all nodes connected to user1, while user2 +// still has nodes, but they are not connected. +func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{ + "user1": len(MustTestVersions), + "user2": len(MustTestVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, + hsic.WithTestName("keyrelognewuser"), + hsic.WithTLS(), + ) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + // assertClientsState(t, allClients) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + listNodes, err := headscale.ListNodes() + assert.Equal(t, len(listNodes), len(allClients)) + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + for _, client := range allClients { + err := client.Logout() + if err != nil { + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + err = scenario.WaitForTailscaleLogout() + assertNoErrLogout(t, err) + + t.Logf("all clients logged out") + + // Create a new authkey for user1, to be used for all clients + key, err := scenario.CreatePreAuthKey("user1", true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user1: %s", err) + } + + // Log in all clients as user1, iterating over the spec only returns the + // clients, not the usernames. + for userName := range spec { + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + user1Nodes, err := headscale.ListNodes("user1") + assertNoErr(t, err) + assert.Len(t, user1Nodes, len(allClients)) + + // Validate that all the old nodes are still present with user2 + user2Nodes, err := headscale.ListNodes("user2") + assertNoErr(t, err) + assert.Len(t, user2Nodes, len(allClients)/2) + + for _, client := range allClients { + status, err := client.Status() + if err != nil { + t.Fatalf("failed to get status for client %s: %s", client.Hostname(), err) + } + + assert.Equal(t, "user1@test.no", status.User[status.Self.UserID].LoginName) + } +} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 22790f91..f75539be 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -116,20 +116,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { headscale, err := scenario.Headscale() assertNoErr(t, err) - var listUsers []v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - }, - &listUsers, - ) + listUsers, err := headscale.ListUsers() assertNoErr(t, err) - want := []v1.User{ + want := []*v1.User{ { Id: 1, Name: "user1", @@ -249,7 +239,7 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified bool cliUsers []string oidcUsers []string - want func(iss string) []v1.User + want func(iss string) []*v1.User }{ { name: "no-migration-verified-email", @@ -259,8 +249,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: true, cliUsers: []string{"user1", "user2"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ { Id: 1, Name: "user1", @@ -296,8 +286,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: false, cliUsers: []string{"user1", "user2"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ { Id: 1, Name: "user1", @@ -332,8 +322,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: true, cliUsers: []string{"user1", "user2"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ { Id: 1, Name: "user1", @@ -360,8 +350,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: false, cliUsers: []string{"user1", "user2"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ { Id: 1, Name: "user1", @@ -396,8 +386,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: true, cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ // Hmm I think we will have to overwrite the initial name here // createuser with "user1.headscale.net", but oidc with "user1" { @@ -426,8 +416,8 @@ func TestOIDC024UserCreation(t *testing.T) { emailVerified: false, cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, oidcUsers: []string{"user1", "user2"}, - want: func(iss string) []v1.User { - return []v1.User{ + want: func(iss string) []*v1.User { + return []*v1.User{ { Id: 1, Name: "user1.headscale.net", @@ -509,17 +499,7 @@ func TestOIDC024UserCreation(t *testing.T) { want := tt.want(oidcConfig.Issuer) - var listUsers []v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - }, - &listUsers, - ) + listUsers, err := headscale.ListUsers() assertNoErr(t, err) sort.Slice(listUsers, func(i, j int) bool { @@ -587,23 +567,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { err = scenario.WaitForTailscaleSync() assertNoErrSync(t, err) - // Verify PKCE was used in authentication - headscale, err := scenario.Headscale() - assertNoErr(t, err) - - var listUsers []v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - }, - &listUsers, - ) - assertNoErr(t, err) - allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) @@ -612,6 +575,228 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } +func TestOIDCReloginSameNodeNewUser(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + // Create no nodes and no users + spec := map[string]int{} + + // First login creates the first OIDC user + // Second login logs in the same node, which creates a new node + // Third login logs in the same node back into the original user + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", true), + oidcMockUser("user1", true), + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + // defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithEmbeddedDERPServerOnly(), + ) + assertNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + listUsers, err := headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 0) + + ts, err := scenario.CreateTailscaleNode("unstable") + assertNoErr(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 1) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodes, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodes, 1) + + // Log out user1 and log in user2, this should create a new node + // for user2, the node should have the same machine key and + // a new node key. + err = ts.Logout() + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + assertNoErr(t, err) + + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 2) + wantUsers = []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodesAfterNewUserLogin, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodesAfterNewUserLogin, 2) + + // Machine key is the same as the "machine" has not changed, + // but Node key is not as it is a new node + assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) + assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) + assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey) + + // Log out user2, and log into user1, no new node should be created, + // the node should now "become" node1 again + err = ts.Logout() + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + assertNoErr(t, err) + + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + assertNoErr(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + assertNoErr(t, err) + + listUsers, err = headscale.ListUsers() + assertNoErr(t, err) + assert.Len(t, listUsers, 2) + wantUsers = []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + listNodesAfterLoggingBackIn, err := headscale.ListNodes() + assertNoErr(t, err) + assert.Len(t, listNodesAfterLoggingBackIn, 2) + + // Validate that the machine we had when we logged in the first time, has the same + // machine key, but a different ID than the newly logged in version of the same + // machine. + assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey) + assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey) + assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id) + assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey) + assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id) + assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id) + + // Even tho we are logging in again with the same user, the previous key has been expired + // and a new one has been generated. The node entry in the database should be the same + // as the user + machinekey still matches. + assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey) + assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey) + assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id) + + // The "logged back in" machine should have the same machinekey but a different nodekey + // than the version logged in with a different user. + assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey) + assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option, diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 72703e95..acc96cec 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -11,6 +11,8 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var errParseAuthPage = errors.New("failed to parse auth page") @@ -106,6 +108,14 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + listNodes, err := headscale.ListNodes() + assert.Equal(t, len(listNodes), len(allClients)) + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + clientIPs := make(map[TailscaleClient][]netip.Addr) for _, client := range allClients { ips, err := client.IPs() @@ -127,9 +137,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged out") - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - for userName := range spec { err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) if err != nil { @@ -139,9 +146,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients logged in again") - allClients, err = scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - allIps, err = scenario.ListTailscaleClientsIPs() assertNoErrListClientIPs(t, err) @@ -152,6 +156,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { success = pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + listNodes, err = headscale.ListNodes() + require.Equal(t, nodeCountBeforeLogout, len(listNodes)) + t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) + for _, client := range allClients { ips, err := client.IPs() if err != nil { diff --git a/integration/cli_test.go b/integration/cli_test.go index 59d39278..e5e93c3c 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -606,22 +606,12 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) } - var listNodes []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listNodes, - ) + listNodes, err := headscale.ListNodes() assert.Nil(t, err) - assert.Len(t, listNodes, 1) + assert.Len(t, listNodes, 2) - assert.Equal(t, "user2", listNodes[0].GetUser().GetName()) + assert.Equal(t, "user1", listNodes[0].GetUser().GetName()) + assert.Equal(t, "user2", listNodes[1].GetUser().GetName()) } func TestApiKeyCommand(t *testing.T) { diff --git a/integration/control.go b/integration/control.go index b5699577..8ec6bad6 100644 --- a/integration/control.go +++ b/integration/control.go @@ -17,7 +17,8 @@ type ControlServer interface { WaitForRunning() error CreateUser(user string) error CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) - ListNodesInUser(user string) ([]*v1.Node, error) + ListNodes(users ...string) ([]*v1.Node, error) + ListUsers() ([]*v1.User, error) GetCert() []byte GetHostname() string GetIP() string diff --git a/integration/general_test.go b/integration/general_test.go index eb26cea9..3bdce469 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -105,137 +105,6 @@ func TestPingAllByIPPublicDERP(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } -func TestAuthKeyLogoutAndRelogin(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - for _, https := range []bool{true, false} { - t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - opts := []hsic.Option{hsic.WithTestName("pingallbyip")} - if https { - opts = append(opts, []hsic.Option{ - hsic.WithTLS(), - }...) - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) - assertNoErrHeadscaleEnv(t, err) - - allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - // assertClientsState(t, allClients) - - clientIPs := make(map[TailscaleClient][]netip.Addr) - for _, client := range allClients { - ips, err := client.IPs() - if err != nil { - t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) - } - clientIPs[client] = ips - } - - for _, client := range allClients { - err := client.Logout() - if err != nil { - t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) - } - } - - err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) - - t.Logf("all clients logged out") - - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - - // if the server is not running with HTTPS, we have to wait a bit before - // reconnection as the newest Tailscale client has a measure that will only - // reconnect over HTTPS if they saw a noise connection previously. - // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 - // https://github.com/juanfont/headscale/issues/2164 - if !https { - time.Sleep(5 * time.Minute) - } - - for userName := range spec { - key, err := scenario.CreatePreAuthKey(userName, true, false) - if err != nil { - t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) - } - - err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) - if err != nil { - t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) - } - } - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - // assertClientsState(t, allClients) - - allClients, err = scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) - - allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { - return x.String() - }) - - success := pingAllHelper(t, allClients, allAddrs) - t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - - for _, client := range allClients { - ips, err := client.IPs() - if err != nil { - t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) - } - - // lets check if the IPs are the same - if len(ips) != len(clientIPs[client]) { - t.Fatalf("IPs changed for client %s", client.Hostname()) - } - - for _, ip := range ips { - found := false - for _, oldIP := range clientIPs[client] { - if ip == oldIP { - found = true - - break - } - } - - if !found { - t.Fatalf( - "IPs changed for client %s. Used to be %v now %v", - client.Hostname(), - clientIPs[client], - ips, - ) - } - } - } - }) - } -} - func TestEphemeral(t *testing.T) { testEphemeralWithOptions(t, hsic.WithTestName("ephemeral")) } @@ -314,21 +183,9 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { t.Logf("all clients logged out") - for userName := range spec { - nodes, err := headscale.ListNodesInUser(userName) - if err != nil { - log.Error(). - Err(err). - Str("user", userName). - Msg("Error listing nodes in user") - - return - } - - if len(nodes) != 0 { - t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName) - } - } + nodes, err := headscale.ListNodes() + assertNoErr(t, err) + require.Len(t, nodes, 0) } // TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not @@ -431,7 +288,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { time.Sleep(3 * time.Minute) for userName := range spec { - nodes, err := headscale.ListNodesInUser(userName) + nodes, err := headscale.ListNodes(userName) if err != nil { log.Error(). Err(err). diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 76a5176c..cf62e3a6 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -16,7 +16,7 @@ func DefaultConfigEnv() map[string]string { "HEADSCALE_POLICY_PATH": "", "HEADSCALE_DATABASE_TYPE": "sqlite", "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", - "HEADSCALE_DATABASE_DEBUG": "1", + "HEADSCALE_DATABASE_DEBUG": "0", "HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_PREFIXES_V4": "100.64.0.0/10", diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index e38abd1c..cff703ac 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -1,6 +1,7 @@ package hsic import ( + "cmp" "crypto/tls" "encoding/json" "errors" @@ -10,6 +11,7 @@ import ( "net/http" "os" "path" + "sort" "strconv" "strings" "time" @@ -744,12 +746,58 @@ func (t *HeadscaleInContainer) CreateAuthKey( return &preAuthKey, nil } -// ListNodesInUser list the TailscaleClients (Node, Headscale internal representation) -// associated with a user. -func (t *HeadscaleInContainer) ListNodesInUser( - user string, +// ListNodes lists the currently registered Nodes in headscale. +// Optionally a list of usernames can be passed to get users for +// specific users. +func (t *HeadscaleInContainer) ListNodes( + users ...string, ) ([]*v1.Node, error) { - command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} + var ret []*v1.Node + execUnmarshal := func(command []string) error { + result, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute list node command: %w", err) + } + + var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) + if err != nil { + return fmt.Errorf("failed to unmarshal nodes: %w", err) + } + + ret = append(ret, nodes...) + return nil + } + + if len(users) == 0 { + err := execUnmarshal([]string{"headscale", "nodes", "list", "--output", "json"}) + if err != nil { + return nil, err + } + } else { + for _, user := range users { + command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} + + err := execUnmarshal(command) + if err != nil { + return nil, err + } + } + } + + sort.Slice(ret, func(i, j int) bool { + return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 + }) + return ret, nil +} + +// ListUsers returns a list of users from Headscale. +func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { + command := []string{"headscale", "users", "list", "--output", "json"} result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -760,13 +808,13 @@ func (t *HeadscaleInContainer) ListNodesInUser( return nil, fmt.Errorf("failed to execute list node command: %w", err) } - var nodes []*v1.Node - err = json.Unmarshal([]byte(result), &nodes) + var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) } - return nodes, nil + return users, nil } // WriteFile save file inside the Headscale container.