mirror of
https://github.com/juanfont/headscale.git
synced 2025-02-08 10:18:01 +09:00
Rewrite authentication flow (#2374)
This commit is contained in:
parent
e172c29360
commit
d57a55c024
20 changed files with 848 additions and 996 deletions
2
.github/workflows/check-tests.yaml
vendored
2
.github/workflows/check-tests.yaml
vendored
|
@ -32,7 +32,7 @@ jobs:
|
||||||
- name: Generate and check integration tests
|
- name: Generate and check integration tests
|
||||||
if: steps.changed-files.outputs.files == 'true'
|
if: steps.changed-files.outputs.files == 'true'
|
||||||
run: |
|
run: |
|
||||||
nix develop --command bash -c "cd cmd/gh-action-integration-generator/ && go generate"
|
nix develop --command bash -c "cd .github/workflows && go generate"
|
||||||
git diff --exit-code .github/workflows/test-integration.yaml
|
git diff --exit-code .github/workflows/test-integration.yaml
|
||||||
|
|
||||||
- name: Show missing tests
|
- name: Show missing tests
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
//go:generate go run ./main.go
|
//go:generate go run ./gh-action-integration-generator.go
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -42,15 +42,19 @@ func updateYAML(tests []string) {
|
||||||
testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", "))
|
testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", "))
|
||||||
|
|
||||||
yqCommand := fmt.Sprintf(
|
yqCommand := fmt.Sprintf(
|
||||||
"yq eval '.jobs.integration-test.strategy.matrix.test = %s' ../../.github/workflows/test-integration.yaml -i",
|
"yq eval '.jobs.integration-test.strategy.matrix.test = %s' ./test-integration.yaml -i",
|
||||||
testsForYq,
|
testsForYq,
|
||||||
)
|
)
|
||||||
cmd := exec.Command("bash", "-c", yqCommand)
|
cmd := exec.Command("bash", "-c", yqCommand)
|
||||||
|
|
||||||
var out bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
cmd.Stdout = &out
|
var stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
err := cmd.Run()
|
err := cmd.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("stdout: %s", stdout.String())
|
||||||
|
log.Printf("stderr: %s", stderr.String())
|
||||||
log.Fatalf("failed to run yq command: %s", err)
|
log.Fatalf("failed to run yq command: %s", err)
|
||||||
}
|
}
|
||||||
|
|
4
.github/workflows/test-integration.yaml
vendored
4
.github/workflows/test-integration.yaml
vendored
|
@ -22,10 +22,13 @@ jobs:
|
||||||
- TestACLNamedHostsCanReach
|
- TestACLNamedHostsCanReach
|
||||||
- TestACLDevice1CanAccessDevice2
|
- TestACLDevice1CanAccessDevice2
|
||||||
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
||||||
|
- TestAuthKeyLogoutAndReloginSameUser
|
||||||
|
- TestAuthKeyLogoutAndReloginNewUser
|
||||||
- TestOIDCAuthenticationPingAll
|
- TestOIDCAuthenticationPingAll
|
||||||
- TestOIDCExpireNodesBasedOnTokenExpiry
|
- TestOIDCExpireNodesBasedOnTokenExpiry
|
||||||
- TestOIDC024UserCreation
|
- TestOIDC024UserCreation
|
||||||
- TestOIDCAuthenticationWithPKCE
|
- TestOIDCAuthenticationWithPKCE
|
||||||
|
- TestOIDCReloginSameNodeNewUser
|
||||||
- TestAuthWebFlowAuthenticationPingAll
|
- TestAuthWebFlowAuthenticationPingAll
|
||||||
- TestAuthWebFlowLogoutAndRelogin
|
- TestAuthWebFlowLogoutAndRelogin
|
||||||
- TestUserCommand
|
- TestUserCommand
|
||||||
|
@ -50,7 +53,6 @@ jobs:
|
||||||
- TestDERPServerWebsocketScenario
|
- TestDERPServerWebsocketScenario
|
||||||
- TestPingAllByIP
|
- TestPingAllByIP
|
||||||
- TestPingAllByIPPublicDERP
|
- TestPingAllByIPPublicDERP
|
||||||
- TestAuthKeyLogoutAndRelogin
|
|
||||||
- TestEphemeral
|
- TestEphemeral
|
||||||
- TestEphemeralInAlternateTimezone
|
- TestEphemeralInAlternateTimezone
|
||||||
- TestEphemeral2006DeletedTooQuickly
|
- TestEphemeral2006DeletedTooQuickly
|
||||||
|
|
12
CHANGELOG.md
12
CHANGELOG.md
|
@ -2,6 +2,18 @@
|
||||||
|
|
||||||
## Next
|
## Next
|
||||||
|
|
||||||
|
### BREAKING
|
||||||
|
|
||||||
|
- Authentication flow has been rewritten
|
||||||
|
[#2374](https://github.com/juanfont/headscale/pull/2374) This change should be
|
||||||
|
transparent to users with the exception of some buxfixes that has been
|
||||||
|
discovered and was fixed as part of the rewrite.
|
||||||
|
- When a node is registered with _a new user_, it will be registered as a new
|
||||||
|
node ([#2327](https://github.com/juanfont/headscale/issues/2327) and
|
||||||
|
[#1310](https://github.com/juanfont/headscale/issues/1310)).
|
||||||
|
- A logged out node logging in with the same user will replace the existing
|
||||||
|
node.
|
||||||
|
|
||||||
### Changes
|
### Changes
|
||||||
|
|
||||||
- `oidc.map_legacy_users` is now `false` by default
|
- `oidc.map_legacy_users` is now `false` by default
|
||||||
|
|
|
@ -521,25 +521,28 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
|
||||||
|
|
||||||
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||||
// Maybe we should attempt a new in memory state and not go via the DB?
|
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||||
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
// A bool is returned indicating if a full update was sent to all nodes
|
||||||
|
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) {
|
||||||
nodes, err := db.ListNodes()
|
nodes, err := db.ListNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
changed, err := polMan.SetNodes(nodes)
|
filterChanged, err := polMan.SetNodes(nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if changed {
|
if filterChanged {
|
||||||
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
||||||
notif.NotifyAll(ctx, types.StateUpdate{
|
notif.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StateFullUpdate,
|
Type: types.StateFullUpdate,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||||
|
|
|
@ -2,7 +2,6 @@ package hscontrol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -13,7 +12,6 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -25,555 +23,73 @@ type AuthProvider interface {
|
||||||
AuthURL(types.RegistrationID) string
|
AuthURL(types.RegistrationID) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func logAuthFunc(
|
|
||||||
registerRequest tailcfg.RegisterRequest,
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
registrationId types.RegistrationID,
|
|
||||||
) (func(string), func(string), func(error, string)) {
|
|
||||||
return func(msg string) {
|
|
||||||
log.Info().
|
|
||||||
Caller().
|
|
||||||
Str("registration_id", registrationId.String()).
|
|
||||||
Str("machine_key", machineKey.ShortString()).
|
|
||||||
Str("node_key", registerRequest.NodeKey.ShortString()).
|
|
||||||
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Str("followup", registerRequest.Followup).
|
|
||||||
Time("expiry", registerRequest.Expiry).
|
|
||||||
Msg(msg)
|
|
||||||
},
|
|
||||||
func(msg string) {
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("registration_id", registrationId.String()).
|
|
||||||
Str("machine_key", machineKey.ShortString()).
|
|
||||||
Str("node_key", registerRequest.NodeKey.ShortString()).
|
|
||||||
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Str("followup", registerRequest.Followup).
|
|
||||||
Time("expiry", registerRequest.Expiry).
|
|
||||||
Msg(msg)
|
|
||||||
},
|
|
||||||
func(err error, msg string) {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("registration_id", registrationId.String()).
|
|
||||||
Str("machine_key", machineKey.ShortString()).
|
|
||||||
Str("node_key", registerRequest.NodeKey.ShortString()).
|
|
||||||
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Str("followup", registerRequest.Followup).
|
|
||||||
Time("expiry", registerRequest.Expiry).
|
|
||||||
Err(err).
|
|
||||||
Msg(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) waitForFollowup(
|
|
||||||
req *http.Request,
|
|
||||||
regReq tailcfg.RegisterRequest,
|
|
||||||
logTrace func(string),
|
|
||||||
) {
|
|
||||||
logTrace("register request is a followup")
|
|
||||||
fu, err := url.Parse(regReq.Followup)
|
|
||||||
if err != nil {
|
|
||||||
logTrace("failed to parse followup URL")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
|
||||||
if err != nil {
|
|
||||||
logTrace("followup URL does not contains a valid registration ID")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))
|
|
||||||
|
|
||||||
if reg, ok := h.registrationCache.Get(followupReg); ok {
|
|
||||||
logTrace("Node is waiting for interactive login")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-req.Context().Done():
|
|
||||||
logTrace("node went away before it was registered")
|
|
||||||
return
|
|
||||||
case <-reg.Registered:
|
|
||||||
logTrace("node has successfully registered")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleRegister is the logic for registering a client.
|
|
||||||
func (h *Headscale) handleRegister(
|
func (h *Headscale) handleRegister(
|
||||||
writer http.ResponseWriter,
|
ctx context.Context,
|
||||||
req *http.Request,
|
|
||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
registrationId, err := types.NewRegistrationID()
|
node, err := h.db.GetNodeByNodeKey(regReq.NodeKey)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
log.Error().
|
return nil, fmt.Errorf("looking up node in database: %w", err)
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to generate registration ID")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
|
if node != nil {
|
||||||
now := time.Now().UTC()
|
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||||
logTrace("handleRegister called, looking up machine in DB")
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if regReq.Followup != "" {
|
||||||
|
// TODO(kradalby): Does this need to return an error of some sort?
|
||||||
|
// Maybe if the registration fails down the line it can be sent
|
||||||
|
// on the channel and returned here?
|
||||||
|
h.waitForFollowup(ctx, regReq)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
|
|
||||||
// key refreshes. This will allow us to remove the machineKey from the registration request.
|
|
||||||
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
|
|
||||||
logTrace("handleRegister database lookup has returned")
|
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
||||||
// If the node has AuthKey set, handle registration via PreAuthKeys
|
|
||||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||||
h.handleAuthKey(writer, regReq, machineKey)
|
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||||
|
if err != nil {
|
||||||
return
|
return nil, fmt.Errorf("handling register with auth key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the node is waiting for interactive login.
|
return resp, nil
|
||||||
if regReq.Followup != "" {
|
|
||||||
h.waitForFollowup(req, regReq, logTrace)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logInfo("Node not found in database, creating new")
|
resp, err := h.handleRegisterInteractive(regReq, machineKey)
|
||||||
|
if err != nil {
|
||||||
// The node did not have a key to authenticate, which means
|
return nil, fmt.Errorf("handling register interactive: %w", err)
|
||||||
// that we rely on a method that calls back some how (OpenID or CLI)
|
|
||||||
// We create the node and then keep it around until a callback
|
|
||||||
// happens
|
|
||||||
newNode := types.RegisterNode{
|
|
||||||
Node: types.Node{
|
|
||||||
MachineKey: machineKey,
|
|
||||||
Hostname: regReq.Hostinfo.Hostname,
|
|
||||||
NodeKey: regReq.NodeKey,
|
|
||||||
LastSeen: &now,
|
|
||||||
Expiry: &time.Time{},
|
|
||||||
},
|
|
||||||
Registered: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !regReq.Expiry.IsZero() {
|
return resp, nil
|
||||||
logTrace("Non-zero expiry time requested")
|
|
||||||
newNode.Node.Expiry = ®Req.Expiry
|
|
||||||
}
|
|
||||||
|
|
||||||
h.registrationCache.Set(
|
|
||||||
registrationId,
|
|
||||||
newNode,
|
|
||||||
)
|
|
||||||
|
|
||||||
h.handleNewNode(writer, regReq, registrationId)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node is already in the DB. This could mean one of the following:
|
|
||||||
// - The node is authenticated and ready to /map
|
|
||||||
// - We are doing a key refresh
|
|
||||||
// - The node is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here
|
|
||||||
if node != nil {
|
|
||||||
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
|
|
||||||
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
|
|
||||||
// So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
|
|
||||||
if err != nil || node.MachineKey.IsZero() {
|
|
||||||
if err := h.db.NodeSetMachineKey(node, machineKey); err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("func", "RegistrationHandler").
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Error saving machine key to database")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the NodeKey stored in headscale is the same as the key presented in a registration
|
|
||||||
// request, then we have a node that is either:
|
|
||||||
// - Trying to log out (sending a expiry in the past)
|
|
||||||
// - A valid, registered node, looking for /map
|
|
||||||
// - Expired node wanting to reauthenticate
|
|
||||||
if node.NodeKey.String() == regReq.NodeKey.String() {
|
|
||||||
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
|
|
||||||
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
|
|
||||||
if !regReq.Expiry.IsZero() &&
|
|
||||||
regReq.Expiry.UTC().Before(now) {
|
|
||||||
h.handleNodeLogOut(writer, *node)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If node is not expired, and it is register, we have a already accepted this node,
|
|
||||||
// let it proceed with a valid registration
|
|
||||||
if !node.IsExpired() {
|
|
||||||
h.handleNodeWithValidRegistration(writer, *node)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
|
||||||
if node.NodeKey.String() == regReq.OldNodeKey.String() &&
|
|
||||||
!node.IsExpired() {
|
|
||||||
h.handleNodeKeyRefresh(
|
|
||||||
writer,
|
|
||||||
regReq,
|
|
||||||
*node,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// When logged out and reauthenticating with OIDC, the OldNodeKey is not passed, but the NodeKey has changed
|
|
||||||
if node.NodeKey.String() != regReq.NodeKey.String() &&
|
|
||||||
regReq.OldNodeKey.IsZero() && !node.IsExpired() {
|
|
||||||
h.handleNodeKeyRefresh(
|
|
||||||
writer,
|
|
||||||
regReq,
|
|
||||||
*node,
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if regReq.Followup != "" {
|
|
||||||
h.waitForFollowup(req, regReq, logTrace)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node has expired or it is logged out
|
|
||||||
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)
|
|
||||||
|
|
||||||
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
|
|
||||||
node.Expiry = &time.Time{}
|
|
||||||
|
|
||||||
// TODO(kradalby): do we need to rethink this as part of authflow?
|
|
||||||
// If we are here it means the client needs to be reauthorized,
|
|
||||||
// we need to make sure the NodeKey matches the one in the request
|
|
||||||
// TODO(juan): What happens when using fast user switching between two
|
|
||||||
// headscale-managed tailnets?
|
|
||||||
node.NodeKey = regReq.NodeKey
|
|
||||||
h.registrationCache.Set(
|
|
||||||
registrationId,
|
|
||||||
types.RegisterNode{
|
|
||||||
Node: *node,
|
|
||||||
Registered: make(chan struct{}),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAuthKey contains the logic to manage auth key client registration
|
func (h *Headscale) handleExistingNode(
|
||||||
// When using Noise, the machineKey is Zero.
|
node *types.Node,
|
||||||
func (h *Headscale) handleAuthKey(
|
regReq tailcfg.RegisterRequest,
|
||||||
writer http.ResponseWriter,
|
|
||||||
registerRequest tailcfg.RegisterRequest,
|
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) {
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
log.Debug().
|
if node.MachineKey != machineKey {
|
||||||
Caller().
|
return nil, errors.New("node already exists with different machine key")
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
|
||||||
resp := tailcfg.RegisterResponse{}
|
|
||||||
|
|
||||||
pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed authentication via AuthKey")
|
|
||||||
resp.MachineAuthorized = false
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
expired := node.IsExpired()
|
||||||
writer.WriteHeader(http.StatusUnauthorized)
|
if !expired && !regReq.Expiry.IsZero() {
|
||||||
_, err = writer.Write(respBody)
|
requestExpiry := regReq.Expiry
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
// The client is trying to extend their key, this is not allowed.
|
||||||
Caller().
|
if requestExpiry.After(time.Now()) {
|
||||||
Err(err).
|
return nil, errors.New("extending key is not allowed")
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Msg("Failed authentication via AuthKey")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().
|
|
||||||
Caller().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Msg("Authentication key was valid, proceeding to acquire IP addresses")
|
|
||||||
|
|
||||||
nodeKey := registerRequest.NodeKey
|
|
||||||
|
|
||||||
// retrieve node information if it exist
|
|
||||||
// The error is not important, because if it does not
|
|
||||||
// exist, then this is a new node and we will move
|
|
||||||
// on to registration.
|
|
||||||
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
|
|
||||||
// key refreshes. This will allow us to remove the machineKey from the registration request.
|
|
||||||
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
|
|
||||||
if node != nil {
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("node was already registered before, refreshing with new auth key")
|
|
||||||
|
|
||||||
node.NodeKey = nodeKey
|
|
||||||
if pak.ID != 0 {
|
|
||||||
node.AuthKeyID = ptr.To(pak.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
node.Expiry = ®isterRequest.Expiry
|
|
||||||
node.User = pak.User
|
|
||||||
node.UserID = pak.UserID
|
|
||||||
err := h.db.DB.Save(node).Error
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("failed to save node after logging in with auth key")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
aclTags := pak.Proto().GetAclTags()
|
|
||||||
if len(aclTags) > 0 {
|
|
||||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
|
||||||
err = h.db.SetTags(node.ID, aclTags)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Strs("aclTags", aclTags).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to set tags after refreshing node")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
|
|
||||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}})
|
|
||||||
} else {
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
nodeToRegister := types.Node{
|
|
||||||
Hostname: registerRequest.Hostinfo.Hostname,
|
|
||||||
UserID: pak.User.ID,
|
|
||||||
User: pak.User,
|
|
||||||
MachineKey: machineKey,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
Expiry: ®isterRequest.Expiry,
|
|
||||||
NodeKey: nodeKey,
|
|
||||||
LastSeen: &now,
|
|
||||||
ForcedTags: pak.Proto().GetAclTags(),
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4, ipv6, err := h.ipAlloc.Next()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("func", "RegistrationHandler").
|
|
||||||
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("failed to allocate IP ")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pakID := uint(pak.ID)
|
|
||||||
if pakID != 0 {
|
|
||||||
nodeToRegister.AuthKeyID = ptr.To(pak.ID)
|
|
||||||
}
|
|
||||||
node, err = h.db.RegisterNode(
|
|
||||||
nodeToRegister,
|
|
||||||
ipv4, ipv6,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("could not register node")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.db.Write(func(tx *gorm.DB) error {
|
|
||||||
return db.UsePreAuthKey(tx, pak)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to use pre-auth key")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.MachineAuthorized = true
|
|
||||||
resp.User = *pak.User.TailscaleUser()
|
|
||||||
// Provide LoginName when registering with pre-auth key
|
|
||||||
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
|
||||||
resp.Login = *pak.User.TailscaleLogin()
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
_, err = writer.Write(respBody)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
|
||||||
Msg("Successfully authenticated via AuthKey")
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNewNode returns the authorisation URL to the client based on what type
|
|
||||||
// of registration headscale is configured with.
|
|
||||||
// This url is then showed to the user by the local Tailscale client.
|
|
||||||
func (h *Headscale) handleNewNode(
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
registerRequest tailcfg.RegisterRequest,
|
|
||||||
registrationId types.RegistrationID,
|
|
||||||
) {
|
|
||||||
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)
|
|
||||||
|
|
||||||
resp := tailcfg.RegisterResponse{}
|
|
||||||
|
|
||||||
// The node registration is new, redirect the client to the registration URL
|
|
||||||
logTrace("The node is new, sending auth url")
|
|
||||||
|
|
||||||
resp.AuthURL = h.authProvider.AuthURL(registrationId)
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
logErr(err, "Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
_, err = writer.Write(respBody)
|
|
||||||
if err != nil {
|
|
||||||
logErr(err, "Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) handleNodeLogOut(
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
node types.Node,
|
|
||||||
) {
|
|
||||||
resp := tailcfg.RegisterResponse{}
|
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Client requested logout")
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
err := h.db.NodeSetExpiry(node.ID, now)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to expire node")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
|
||||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, now), node.ID)
|
|
||||||
|
|
||||||
resp.AuthURL = ""
|
|
||||||
resp.MachineAuthorized = false
|
|
||||||
resp.NodeKeyExpired = true
|
|
||||||
resp.User = *node.User.TailscaleUser()
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusOK)
|
|
||||||
_, err = writer.Write(respBody)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the request expiry is in the past, we consider it a logout.
|
||||||
|
if requestExpiry.Before(time.Now()) {
|
||||||
if node.IsEphemeral() {
|
if node.IsEphemeral() {
|
||||||
changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap())
|
changedNodes, err := h.db.DeleteNode(node, h.nodeNotifier.LikelyConnectedMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||||
Err(err).
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Cannot delete ephemeral node from the database")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||||
|
@ -587,168 +103,164 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
ChangeNodes: changedNodes,
|
ChangeNodes: changedNodes,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expired = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.db.NodeSetExpiry(node.ID, requestExpiry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||||
|
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, requestExpiry), node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tailcfg.RegisterResponse{
|
||||||
|
// TODO(kradalby): Only send for user-owned nodes
|
||||||
|
// and not tagged nodes when tags is working.
|
||||||
|
User: *node.User.TailscaleUser(),
|
||||||
|
Login: *node.User.TailscaleLogin(),
|
||||||
|
NodeKeyExpired: expired,
|
||||||
|
|
||||||
|
// Headscale does not implement the concept of machine authorization
|
||||||
|
// so we always return true here.
|
||||||
|
// Revisit this if #2176 gets implemented.
|
||||||
|
MachineAuthorized: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) waitForFollowup(
|
||||||
|
ctx context.Context,
|
||||||
|
regReq tailcfg.RegisterRequest,
|
||||||
|
) {
|
||||||
|
fu, err := url.Parse(regReq.Followup)
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Successfully logged out")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) handleNodeWithValidRegistration(
|
|
||||||
writer http.ResponseWriter,
|
|
||||||
node types.Node,
|
|
||||||
) {
|
|
||||||
resp := tailcfg.RegisterResponse{}
|
|
||||||
|
|
||||||
// The node registration is valid, respond with redirect to /map
|
|
||||||
log.Debug().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
|
||||||
|
|
||||||
resp.AuthURL = ""
|
|
||||||
resp.MachineAuthorized = true
|
|
||||||
resp.User = *node.User.TailscaleUser()
|
|
||||||
resp.Login = *node.User.TailscaleLogin()
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
if reg, ok := h.registrationCache.Get(followupReg); ok {
|
||||||
writer.WriteHeader(http.StatusOK)
|
select {
|
||||||
_, err = writer.Write(respBody)
|
case <-ctx.Done():
|
||||||
if err != nil {
|
return
|
||||||
log.Error().
|
case <-reg.Registered:
|
||||||
Caller().
|
return
|
||||||
Err(err).
|
}
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Msg("Node successfully authorized")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleNodeKeyRefresh(
|
func (h *Headscale) handleRegisterWithAuthKey(
|
||||||
writer http.ResponseWriter,
|
regReq tailcfg.RegisterRequest,
|
||||||
registerRequest tailcfg.RegisterRequest,
|
machineKey key.MachinePublic,
|
||||||
node types.Node,
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
) {
|
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
|
||||||
resp := tailcfg.RegisterResponse{}
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid pre auth key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Info().
|
nodeToRegister := types.Node{
|
||||||
Caller().
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
Str("node", node.Hostname).
|
UserID: pak.User.ID,
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
User: pak.User,
|
||||||
|
MachineKey: machineKey,
|
||||||
|
NodeKey: regReq.NodeKey,
|
||||||
|
Hostinfo: regReq.Hostinfo,
|
||||||
|
LastSeen: ptr.To(time.Now()),
|
||||||
|
RegisterMethod: util.RegisterMethodAuthKey,
|
||||||
|
|
||||||
err := h.db.Write(func(tx *gorm.DB) error {
|
// TODO(kradalby): This should not be set on the node,
|
||||||
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
|
// they should be looked up through the key, which is
|
||||||
|
// attached to the node.
|
||||||
|
ForcedTags: pak.Proto().GetAclTags(),
|
||||||
|
AuthKey: pak,
|
||||||
|
AuthKeyID: &pak.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !regReq.Expiry.IsZero() {
|
||||||
|
nodeToRegister.Expiry = ®Req.Expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4, ipv6, err := h.ipAlloc.Next()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("allocating IPs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
node, err := db.Write(h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
|
node, err := db.RegisterNode(tx,
|
||||||
|
nodeToRegister,
|
||||||
|
ipv4, ipv6,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("registering node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pak.Reusable {
|
||||||
|
err = db.UsePreAuthKey(tx, pak)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("using pre auth key: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
return nil, err
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to update machine key in the database")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.AuthURL = ""
|
updateSent, err := nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
|
||||||
resp.User = *node.User.TailscaleUser()
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
return nil, fmt.Errorf("nodes changed hook: %w", err)
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
if !updateSent {
|
||||||
writer.WriteHeader(http.StatusOK)
|
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
|
||||||
_, err = writer.Write(respBody)
|
h.nodeNotifier.NotifyAll(ctx, types.StateUpdatePeerAdded(node.ID))
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
return &tailcfg.RegisterResponse{
|
||||||
Caller().
|
MachineAuthorized: true,
|
||||||
Str("node_key", registerRequest.NodeKey.ShortString()).
|
NodeKeyExpired: node.IsExpired(),
|
||||||
Str("old_node_key", registerRequest.OldNodeKey.ShortString()).
|
User: *pak.User.TailscaleUser(),
|
||||||
Str("node", node.Hostname).
|
Login: *pak.User.TailscaleLogin(),
|
||||||
Msg("Node key successfully refreshed")
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) handleNodeExpiredOrLoggedOut(
|
func (h *Headscale) handleRegisterInteractive(
|
||||||
writer http.ResponseWriter,
|
|
||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
node types.Node,
|
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
registrationId types.RegistrationID,
|
) (*tailcfg.RegisterResponse, error) {
|
||||||
) {
|
registrationId, err := types.NewRegistrationID()
|
||||||
resp := tailcfg.RegisterResponse{}
|
|
||||||
|
|
||||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
|
||||||
h.handleAuthKey(writer, regReq, machineKey)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The client has registered before, but has expired or logged out
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("node", node.Hostname).
|
|
||||||
Str("registration_id", registrationId.String()).
|
|
||||||
Str("node_key", regReq.NodeKey.ShortString()).
|
|
||||||
Str("node_key_old", regReq.OldNodeKey.ShortString()).
|
|
||||||
Msg("Node registration has expired or logged out. Sending a auth url to register")
|
|
||||||
|
|
||||||
resp.AuthURL = h.authProvider.AuthURL(registrationId)
|
|
||||||
|
|
||||||
respBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
return nil, fmt.Errorf("generating registration ID: %w", err)
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot encode message")
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
newNode := types.RegisterNode{
|
||||||
writer.WriteHeader(http.StatusOK)
|
Node: types.Node{
|
||||||
_, err = writer.Write(respBody)
|
Hostname: regReq.Hostinfo.Hostname,
|
||||||
if err != nil {
|
MachineKey: machineKey,
|
||||||
log.Error().
|
NodeKey: regReq.NodeKey,
|
||||||
Caller().
|
Hostinfo: regReq.Hostinfo,
|
||||||
Err(err).
|
LastSeen: ptr.To(time.Now()),
|
||||||
Msg("Failed to write response")
|
},
|
||||||
|
Registered: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
if !regReq.Expiry.IsZero() {
|
||||||
Caller().
|
newNode.Node.Expiry = ®Req.Expiry
|
||||||
Str("registration_id", registrationId.String()).
|
}
|
||||||
Str("node_key", regReq.NodeKey.ShortString()).
|
|
||||||
Str("node_key_old", regReq.OldNodeKey.ShortString()).
|
h.registrationCache.Set(
|
||||||
Str("node", node.Hostname).
|
registrationId,
|
||||||
Msg("Node logged out. Sent AuthURL for reauthentication")
|
newNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &tailcfg.RegisterResponse{
|
||||||
|
AuthURL: h.authProvider.AuthURL(registrationId),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -182,38 +182,6 @@ func GetNodeByNodeKey(
|
||||||
return &mach, nil
|
return &mach, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetNodeByAnyKey(
|
|
||||||
machineKey key.MachinePublic,
|
|
||||||
nodeKey key.NodePublic,
|
|
||||||
oldNodeKey key.NodePublic,
|
|
||||||
) (*types.Node, error) {
|
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
|
||||||
return GetNodeByAnyKey(rx, machineKey, nodeKey, oldNodeKey)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct.
|
|
||||||
// TODO(kradalby): see if we can remove this.
|
|
||||||
func GetNodeByAnyKey(
|
|
||||||
tx *gorm.DB,
|
|
||||||
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
|
||||||
) (*types.Node, error) {
|
|
||||||
node := types.Node{}
|
|
||||||
if result := tx.
|
|
||||||
Preload("AuthKey").
|
|
||||||
Preload("AuthKey.User").
|
|
||||||
Preload("User").
|
|
||||||
Preload("Routes").
|
|
||||||
First(&node, "machine_key = ? OR node_key = ? OR node_key = ?",
|
|
||||||
machineKey.String(),
|
|
||||||
nodeKey.String(),
|
|
||||||
oldNodeKey.String()); result.Error != nil {
|
|
||||||
return nil, result.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
return &node, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) SetTags(
|
func (hsdb *HSDatabase) SetTags(
|
||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
tags []string,
|
tags []string,
|
||||||
|
@ -437,6 +405,18 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||||
Str("user", node.User.Username()).
|
Str("user", node.User.Username()).
|
||||||
Msg("Registering node")
|
Msg("Registering node")
|
||||||
|
|
||||||
|
// If the a new node is registered with the same machine key, to the same user,
|
||||||
|
// update the existing node.
|
||||||
|
// If the same node is registered again, but to a new user, then that is considered
|
||||||
|
// a new node.
|
||||||
|
oldNode, _ := GetNodeByMachineKey(tx, node.MachineKey)
|
||||||
|
if oldNode != nil && oldNode.UserID == node.UserID {
|
||||||
|
node.ID = oldNode.ID
|
||||||
|
node.GivenName = oldNode.GivenName
|
||||||
|
ipv4 = oldNode.IPv4
|
||||||
|
ipv6 = oldNode.IPv6
|
||||||
|
}
|
||||||
|
|
||||||
// If the node exists and it already has IP(s), we just save it
|
// If the node exists and it already has IP(s), we just save it
|
||||||
// so we store the node.Expire and node.Nodekey that has been set when
|
// so we store the node.Expire and node.Nodekey that has been set when
|
||||||
// adding it to the registrationCache
|
// adding it to the registrationCache
|
||||||
|
|
|
@ -84,37 +84,6 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
|
||||||
user, err := db.CreateUser(types.User{Name: "test"})
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
|
||||||
c.Assert(err, check.NotNil)
|
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
|
||||||
oldNodeKey := key.NewNode()
|
|
||||||
|
|
||||||
machineKey := key.NewMachine()
|
|
||||||
|
|
||||||
node := types.Node{
|
|
||||||
ID: 0,
|
|
||||||
MachineKey: machineKey.Public(),
|
|
||||||
NodeKey: nodeKey.Public(),
|
|
||||||
Hostname: "testnode",
|
|
||||||
UserID: user.ID,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
|
||||||
}
|
|
||||||
trx := db.DB.Save(&node)
|
|
||||||
c.Assert(trx.Error, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
user, err := db.CreateUser(types.User{Name: "test"})
|
user, err := db.CreateUser(types.User{Name: "test"})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
|
@ -256,10 +256,17 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
|
updateSent, err := nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("updating resources using node: %w", err)
|
return nil, fmt.Errorf("updating resources using node: %w", err)
|
||||||
}
|
}
|
||||||
|
if !updateSent {
|
||||||
|
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
|
||||||
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StatePeerChanged,
|
||||||
|
ChangeNodes: []types.NodeID{node.ID},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,7 +156,12 @@ func isSupportedVersion(version tailcfg.CapabilityVersion) bool {
|
||||||
return version >= MinimumCapVersion
|
return version >= MinimumCapVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
func rejectUnsupported(writer http.ResponseWriter, version tailcfg.CapabilityVersion, mkey key.MachinePublic, nkey key.NodePublic) bool {
|
func rejectUnsupported(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
version tailcfg.CapabilityVersion,
|
||||||
|
mkey key.MachinePublic,
|
||||||
|
nkey key.NodePublic,
|
||||||
|
) bool {
|
||||||
// Reject unsupported versions
|
// Reject unsupported versions
|
||||||
if !isSupportedVersion(version) {
|
if !isSupportedVersion(version) {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -204,11 +209,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||||
|
|
||||||
ns.nodeKey = mapRequest.NodeKey
|
ns.nodeKey = mapRequest.NodeKey
|
||||||
|
|
||||||
node, err := ns.headscale.db.GetNodeByAnyKey(
|
node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||||
ns.conn.Peer(),
|
|
||||||
mapRequest.NodeKey,
|
|
||||||
key.NodePublic{},
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -234,12 +235,38 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
body, _ := io.ReadAll(req.Body)
|
registerRequest, registerResponse, err := func() (*tailcfg.RegisterRequest, []byte, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
var registerRequest tailcfg.RegisterRequest
|
var registerRequest tailcfg.RegisterRequest
|
||||||
if err := json.Unmarshal(body, ®isterRequest); err != nil {
|
if err := json.Unmarshal(body, ®isterRequest); err != nil {
|
||||||
httpError(writer, err, "Internal error", http.StatusInternalServerError)
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return
|
ns.nodeKey = registerRequest.NodeKey
|
||||||
|
|
||||||
|
resp, err := ns.headscale.handleRegister(req.Context(), registerRequest, ns.conn.Peer())
|
||||||
|
// TODO(kradalby): Here we could have two error types, one that is surfaced to the client
|
||||||
|
// and one that returns 500.
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ®isterRequest, respBody, nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Error handling registration")
|
||||||
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject unsupported versions
|
// Reject unsupported versions
|
||||||
|
@ -247,7 +274,13 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ns.nodeKey = registerRequest.NodeKey
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusOK)
|
||||||
ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer())
|
_, err = writer.Write(registerResponse)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to write response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -512,24 +512,21 @@ func (a *AuthProviderOIDC) handleRegistrationID(
|
||||||
// Send an update to all nodes if this is a new node that they need to know
|
// Send an update to all nodes if this is a new node that they need to know
|
||||||
// about.
|
// about.
|
||||||
// If this is a refresh, just send new expiry updates.
|
// If this is a refresh, just send new expiry updates.
|
||||||
if newNode {
|
updateSent, err := nodesChangedHook(a.db, a.polMan, a.notifier)
|
||||||
err = nodesChangedHook(a.db, a.polMan, a.notifier)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("updating resources using node: %w", err)
|
return false, fmt.Errorf("updating resources using node: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
|
if !updateSent {
|
||||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||||
a.notifier.NotifyByNodeID(
|
a.notifier.NotifyByNodeID(
|
||||||
ctx,
|
ctx,
|
||||||
types.StateUpdate{
|
types.StateSelf(node.ID),
|
||||||
Type: types.StateSelfUpdate,
|
|
||||||
ChangeNodes: []types.NodeID{node.ID},
|
|
||||||
},
|
|
||||||
node.ID,
|
node.ID,
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||||
a.notifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, expiry), node.ID)
|
a.notifier.NotifyWithIgnore(ctx, types.StateUpdatePeerAdded(node.ID), node.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newNode, nil
|
return newNode, nil
|
||||||
|
|
|
@ -102,6 +102,20 @@ func (su *StateUpdate) Empty() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StateSelf(nodeID NodeID) StateUpdate {
|
||||||
|
return StateUpdate{
|
||||||
|
Type: StateSelfUpdate,
|
||||||
|
ChangeNodes: []NodeID{nodeID},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StateUpdatePeerAdded(nodeIDs ...NodeID) StateUpdate {
|
||||||
|
return StateUpdate{
|
||||||
|
Type: StatePeerChanged,
|
||||||
|
ChangeNodes: nodeIDs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||||
return StateUpdate{
|
return StateUpdate{
|
||||||
Type: StatePeerChangedPatch,
|
Type: StatePeerChangedPatch,
|
||||||
|
|
230
integration/auth_key_test.go
Normal file
230
integration/auth_key_test.go
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
|
"github.com/samber/lo"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, https := range []bool{true, false} {
|
||||||
|
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
|
||||||
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
spec := map[string]int{
|
||||||
|
"user1": len(MustTestVersions),
|
||||||
|
"user2": len(MustTestVersions),
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
|
||||||
|
if https {
|
||||||
|
opts = append(opts, []hsic.Option{
|
||||||
|
hsic.WithTLS(),
|
||||||
|
}...)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
|
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||||
|
for _, client := range allClients {
|
||||||
|
ips, err := client.IPs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
clientIPs[client] = ips
|
||||||
|
}
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErrGetHeadscale(t, err)
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assert.Equal(t, len(listNodes), len(allClients))
|
||||||
|
nodeCountBeforeLogout := len(listNodes)
|
||||||
|
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
err := client.Logout()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleLogout()
|
||||||
|
assertNoErrLogout(t, err)
|
||||||
|
|
||||||
|
t.Logf("all clients logged out")
|
||||||
|
|
||||||
|
// if the server is not running with HTTPS, we have to wait a bit before
|
||||||
|
// reconnection as the newest Tailscale client has a measure that will only
|
||||||
|
// reconnect over HTTPS if they saw a noise connection previously.
|
||||||
|
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
|
||||||
|
// https://github.com/juanfont/headscale/issues/2164
|
||||||
|
if !https {
|
||||||
|
time.Sleep(5 * time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
for userName := range spec {
|
||||||
|
key, err := scenario.CreatePreAuthKey(userName, true, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodes, err = headscale.ListNodes()
|
||||||
|
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
|
||||||
|
|
||||||
|
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||||
|
assertNoErrListClientIPs(t, err)
|
||||||
|
|
||||||
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
|
return x.String()
|
||||||
|
})
|
||||||
|
|
||||||
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
ips, err := client.IPs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lets check if the IPs are the same
|
||||||
|
if len(ips) != len(clientIPs[client]) {
|
||||||
|
t.Fatalf("IPs changed for client %s", client.Hostname())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
found := false
|
||||||
|
for _, oldIP := range clientIPs[client] {
|
||||||
|
if ip == oldIP {
|
||||||
|
found = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
t.Fatalf(
|
||||||
|
"IPs changed for client %s. Used to be %v now %v",
|
||||||
|
client.Hostname(),
|
||||||
|
clientIPs[client],
|
||||||
|
ips,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This test will first log in two sets of nodes to two sets of users, then
|
||||||
|
// it will log out all users from user2 and log them in as user1.
|
||||||
|
// This should leave us with all nodes connected to user1, while user2
|
||||||
|
// still has nodes, but they are not connected.
|
||||||
|
func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
spec := map[string]int{
|
||||||
|
"user1": len(MustTestVersions),
|
||||||
|
"user2": len(MustTestVersions),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
|
||||||
|
hsic.WithTestName("keyrelognewuser"),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
// assertClientsState(t, allClients)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErrGetHeadscale(t, err)
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assert.Equal(t, len(listNodes), len(allClients))
|
||||||
|
nodeCountBeforeLogout := len(listNodes)
|
||||||
|
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
err := client.Logout()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.WaitForTailscaleLogout()
|
||||||
|
assertNoErrLogout(t, err)
|
||||||
|
|
||||||
|
t.Logf("all clients logged out")
|
||||||
|
|
||||||
|
// Create a new authkey for user1, to be used for all clients
|
||||||
|
key, err := scenario.CreatePreAuthKey("user1", true, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create pre-auth key for user1: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log in all clients as user1, iterating over the spec only returns the
|
||||||
|
// clients, not the usernames.
|
||||||
|
for userName := range spec {
|
||||||
|
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
user1Nodes, err := headscale.ListNodes("user1")
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, user1Nodes, len(allClients))
|
||||||
|
|
||||||
|
// Validate that all the old nodes are still present with user2
|
||||||
|
user2Nodes, err := headscale.ListNodes("user2")
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, user2Nodes, len(allClients)/2)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
status, err := client.Status()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get status for client %s: %s", client.Hostname(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "user1@test.no", status.User[status.Self.UserID].LoginName)
|
||||||
|
}
|
||||||
|
}
|
|
@ -116,20 +116,10 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
var listUsers []v1.User
|
listUsers, err := headscale.ListUsers()
|
||||||
err = executeAndUnmarshal(headscale,
|
|
||||||
[]string{
|
|
||||||
"headscale",
|
|
||||||
"users",
|
|
||||||
"list",
|
|
||||||
"--output",
|
|
||||||
"json",
|
|
||||||
},
|
|
||||||
&listUsers,
|
|
||||||
)
|
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
want := []v1.User{
|
want := []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
|
@ -249,7 +239,7 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified bool
|
emailVerified bool
|
||||||
cliUsers []string
|
cliUsers []string
|
||||||
oidcUsers []string
|
oidcUsers []string
|
||||||
want func(iss string) []v1.User
|
want func(iss string) []*v1.User
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "no-migration-verified-email",
|
name: "no-migration-verified-email",
|
||||||
|
@ -259,8 +249,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: true,
|
emailVerified: true,
|
||||||
cliUsers: []string{"user1", "user2"},
|
cliUsers: []string{"user1", "user2"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
|
@ -296,8 +286,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: false,
|
emailVerified: false,
|
||||||
cliUsers: []string{"user1", "user2"},
|
cliUsers: []string{"user1", "user2"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
|
@ -332,8 +322,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: true,
|
emailVerified: true,
|
||||||
cliUsers: []string{"user1", "user2"},
|
cliUsers: []string{"user1", "user2"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
|
@ -360,8 +350,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: false,
|
emailVerified: false,
|
||||||
cliUsers: []string{"user1", "user2"},
|
cliUsers: []string{"user1", "user2"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1",
|
Name: "user1",
|
||||||
|
@ -396,8 +386,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: true,
|
emailVerified: true,
|
||||||
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
// Hmm I think we will have to overwrite the initial name here
|
// Hmm I think we will have to overwrite the initial name here
|
||||||
// createuser with "user1.headscale.net", but oidc with "user1"
|
// createuser with "user1.headscale.net", but oidc with "user1"
|
||||||
{
|
{
|
||||||
|
@ -426,8 +416,8 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
emailVerified: false,
|
emailVerified: false,
|
||||||
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
oidcUsers: []string{"user1", "user2"},
|
oidcUsers: []string{"user1", "user2"},
|
||||||
want: func(iss string) []v1.User {
|
want: func(iss string) []*v1.User {
|
||||||
return []v1.User{
|
return []*v1.User{
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Name: "user1.headscale.net",
|
Name: "user1.headscale.net",
|
||||||
|
@ -509,17 +499,7 @@ func TestOIDC024UserCreation(t *testing.T) {
|
||||||
|
|
||||||
want := tt.want(oidcConfig.Issuer)
|
want := tt.want(oidcConfig.Issuer)
|
||||||
|
|
||||||
var listUsers []v1.User
|
listUsers, err := headscale.ListUsers()
|
||||||
err = executeAndUnmarshal(headscale,
|
|
||||||
[]string{
|
|
||||||
"headscale",
|
|
||||||
"users",
|
|
||||||
"list",
|
|
||||||
"--output",
|
|
||||||
"json",
|
|
||||||
},
|
|
||||||
&listUsers,
|
|
||||||
)
|
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
sort.Slice(listUsers, func(i, j int) bool {
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
@ -587,23 +567,6 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
// Verify PKCE was used in authentication
|
|
||||||
headscale, err := scenario.Headscale()
|
|
||||||
assertNoErr(t, err)
|
|
||||||
|
|
||||||
var listUsers []v1.User
|
|
||||||
err = executeAndUnmarshal(headscale,
|
|
||||||
[]string{
|
|
||||||
"headscale",
|
|
||||||
"users",
|
|
||||||
"list",
|
|
||||||
"--output",
|
|
||||||
"json",
|
|
||||||
},
|
|
||||||
&listUsers,
|
|
||||||
)
|
|
||||||
assertNoErr(t, err)
|
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
||||||
return x.String()
|
return x.String()
|
||||||
})
|
})
|
||||||
|
@ -612,6 +575,228 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
scenario := AuthOIDCScenario{
|
||||||
|
Scenario: baseScenario,
|
||||||
|
}
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
// Create no nodes and no users
|
||||||
|
spec := map[string]int{}
|
||||||
|
|
||||||
|
// First login creates the first OIDC user
|
||||||
|
// Second login logs in the same node, which creates a new node
|
||||||
|
// Third login logs in the same node back into the original user
|
||||||
|
mockusers := []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
oidcMockUser("user2", true),
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
// defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
spec,
|
||||||
|
hsic.WithTestName("oidcauthrelog"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||||
|
hsic.WithEmbeddedDERPServerOnly(),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err := headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 0)
|
||||||
|
|
||||||
|
ts, err := scenario.CreateTailscaleNode("unstable")
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
u, err := ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
_, err = doLoginURL(ts.Hostname(), u)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 1)
|
||||||
|
wantUsers := []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listNodes, 1)
|
||||||
|
|
||||||
|
// Log out user1 and log in user2, this should create a new node
|
||||||
|
// for user2, the node should have the same machine key and
|
||||||
|
// a new node key.
|
||||||
|
err = ts.Logout()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
|
||||||
|
// logs in immediately after the first logout and I cannot reproduce it
|
||||||
|
// manually.
|
||||||
|
err = ts.Logout()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
u, err = ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
_, err = doLoginURL(ts.Hostname(), u)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 2)
|
||||||
|
wantUsers = []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodesAfterNewUserLogin, err := headscale.ListNodes()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listNodesAfterNewUserLogin, 2)
|
||||||
|
|
||||||
|
// Machine key is the same as the "machine" has not changed,
|
||||||
|
// but Node key is not as it is a new node
|
||||||
|
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
|
||||||
|
assert.Equal(t, listNodesAfterNewUserLogin[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
|
||||||
|
assert.NotEqual(t, listNodesAfterNewUserLogin[0].NodeKey, listNodesAfterNewUserLogin[1].NodeKey)
|
||||||
|
|
||||||
|
// Log out user2, and log into user1, no new node should be created,
|
||||||
|
// the node should now "become" node1 again
|
||||||
|
err = ts.Logout()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
// TODO(kradalby): Not sure why we need to logout twice, but it fails and
|
||||||
|
// logs in immediately after the first logout and I cannot reproduce it
|
||||||
|
// manually.
|
||||||
|
err = ts.Logout()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
u, err = ts.LoginWithURL(headscale.GetEndpoint())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
_, err = doLoginURL(ts.Hostname(), u)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
listUsers, err = headscale.ListUsers()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listUsers, 2)
|
||||||
|
wantUsers = []*v1.User{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].GetId() < listUsers[j].GetId()
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
listNodesAfterLoggingBackIn, err := headscale.ListNodes()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
assert.Len(t, listNodesAfterLoggingBackIn, 2)
|
||||||
|
|
||||||
|
// Validate that the machine we had when we logged in the first time, has the same
|
||||||
|
// machine key, but a different ID than the newly logged in version of the same
|
||||||
|
// machine.
|
||||||
|
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[0].MachineKey)
|
||||||
|
assert.Equal(t, listNodes[0].NodeKey, listNodesAfterNewUserLogin[0].NodeKey)
|
||||||
|
assert.Equal(t, listNodes[0].Id, listNodesAfterNewUserLogin[0].Id)
|
||||||
|
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterNewUserLogin[1].MachineKey)
|
||||||
|
assert.NotEqual(t, listNodes[0].Id, listNodesAfterNewUserLogin[1].Id)
|
||||||
|
assert.NotEqual(t, listNodes[0].User.Id, listNodesAfterNewUserLogin[1].User.Id)
|
||||||
|
|
||||||
|
// Even tho we are logging in again with the same user, the previous key has been expired
|
||||||
|
// and a new one has been generated. The node entry in the database should be the same
|
||||||
|
// as the user + machinekey still matches.
|
||||||
|
assert.Equal(t, listNodes[0].MachineKey, listNodesAfterLoggingBackIn[0].MachineKey)
|
||||||
|
assert.NotEqual(t, listNodes[0].NodeKey, listNodesAfterLoggingBackIn[0].NodeKey)
|
||||||
|
assert.Equal(t, listNodes[0].Id, listNodesAfterLoggingBackIn[0].Id)
|
||||||
|
|
||||||
|
// The "logged back in" machine should have the same machinekey but a different nodekey
|
||||||
|
// than the version logged in with a different user.
|
||||||
|
assert.Equal(t, listNodesAfterLoggingBackIn[0].MachineKey, listNodesAfterLoggingBackIn[1].MachineKey)
|
||||||
|
assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
users map[string]int,
|
users map[string]int,
|
||||||
opts ...hsic.Option,
|
opts ...hsic.Option,
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
|
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errParseAuthPage = errors.New("failed to parse auth page")
|
var errParseAuthPage = errors.New("failed to parse auth page")
|
||||||
|
@ -106,6 +108,14 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErrGetHeadscale(t, err)
|
||||||
|
|
||||||
|
listNodes, err := headscale.ListNodes()
|
||||||
|
assert.Equal(t, len(listNodes), len(allClients))
|
||||||
|
nodeCountBeforeLogout := len(listNodes)
|
||||||
|
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||||
|
|
||||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
ips, err := client.IPs()
|
ips, err := client.IPs()
|
||||||
|
@ -127,9 +137,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
|
|
||||||
t.Logf("all clients logged out")
|
t.Logf("all clients logged out")
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
|
||||||
assertNoErrGetHeadscale(t, err)
|
|
||||||
|
|
||||||
for userName := range spec {
|
for userName := range spec {
|
||||||
err = scenario.runTailscaleUp(userName, headscale.GetEndpoint())
|
err = scenario.runTailscaleUp(userName, headscale.GetEndpoint())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -139,9 +146,6 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
|
|
||||||
t.Logf("all clients logged in again")
|
t.Logf("all clients logged in again")
|
||||||
|
|
||||||
allClients, err = scenario.ListTailscaleClients()
|
|
||||||
assertNoErrListClients(t, err)
|
|
||||||
|
|
||||||
allIps, err = scenario.ListTailscaleClientsIPs()
|
allIps, err = scenario.ListTailscaleClientsIPs()
|
||||||
assertNoErrListClientIPs(t, err)
|
assertNoErrListClientIPs(t, err)
|
||||||
|
|
||||||
|
@ -152,6 +156,10 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
success = pingAllHelper(t, allClients, allAddrs)
|
success = pingAllHelper(t, allClients, allAddrs)
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
listNodes, err = headscale.ListNodes()
|
||||||
|
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
|
||||||
|
t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes))
|
||||||
|
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
ips, err := client.IPs()
|
ips, err := client.IPs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -606,22 +606,12 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
|
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
var listNodes []v1.Node
|
listNodes, err := headscale.ListNodes()
|
||||||
err = executeAndUnmarshal(
|
|
||||||
headscale,
|
|
||||||
[]string{
|
|
||||||
"headscale",
|
|
||||||
"nodes",
|
|
||||||
"list",
|
|
||||||
"--output",
|
|
||||||
"json",
|
|
||||||
},
|
|
||||||
&listNodes,
|
|
||||||
)
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Len(t, listNodes, 1)
|
assert.Len(t, listNodes, 2)
|
||||||
|
|
||||||
assert.Equal(t, "user2", listNodes[0].GetUser().GetName())
|
assert.Equal(t, "user1", listNodes[0].GetUser().GetName())
|
||||||
|
assert.Equal(t, "user2", listNodes[1].GetUser().GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiKeyCommand(t *testing.T) {
|
func TestApiKeyCommand(t *testing.T) {
|
||||||
|
|
|
@ -17,7 +17,8 @@ type ControlServer interface {
|
||||||
WaitForRunning() error
|
WaitForRunning() error
|
||||||
CreateUser(user string) error
|
CreateUser(user string) error
|
||||||
CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
|
CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
|
||||||
ListNodesInUser(user string) ([]*v1.Node, error)
|
ListNodes(users ...string) ([]*v1.Node, error)
|
||||||
|
ListUsers() ([]*v1.User, error)
|
||||||
GetCert() []byte
|
GetCert() []byte
|
||||||
GetHostname() string
|
GetHostname() string
|
||||||
GetIP() string
|
GetIP() string
|
||||||
|
|
|
@ -105,137 +105,6 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
|
||||||
IntegrationSkip(t)
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
for _, https := range []bool{true, false} {
|
|
||||||
t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) {
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
|
||||||
assertNoErr(t, err)
|
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
|
||||||
|
|
||||||
spec := map[string]int{
|
|
||||||
"user1": len(MustTestVersions),
|
|
||||||
"user2": len(MustTestVersions),
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := []hsic.Option{hsic.WithTestName("pingallbyip")}
|
|
||||||
if https {
|
|
||||||
opts = append(opts, []hsic.Option{
|
|
||||||
hsic.WithTLS(),
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...)
|
|
||||||
assertNoErrHeadscaleEnv(t, err)
|
|
||||||
|
|
||||||
allClients, err := scenario.ListTailscaleClients()
|
|
||||||
assertNoErrListClients(t, err)
|
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleSync()
|
|
||||||
assertNoErrSync(t, err)
|
|
||||||
|
|
||||||
// assertClientsState(t, allClients)
|
|
||||||
|
|
||||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
|
||||||
for _, client := range allClients {
|
|
||||||
ips, err := client.IPs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
|
||||||
}
|
|
||||||
clientIPs[client] = ips
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, client := range allClients {
|
|
||||||
err := client.Logout()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleLogout()
|
|
||||||
assertNoErrLogout(t, err)
|
|
||||||
|
|
||||||
t.Logf("all clients logged out")
|
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
|
||||||
assertNoErrGetHeadscale(t, err)
|
|
||||||
|
|
||||||
// if the server is not running with HTTPS, we have to wait a bit before
|
|
||||||
// reconnection as the newest Tailscale client has a measure that will only
|
|
||||||
// reconnect over HTTPS if they saw a noise connection previously.
|
|
||||||
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
|
|
||||||
// https://github.com/juanfont/headscale/issues/2164
|
|
||||||
if !https {
|
|
||||||
time.Sleep(5 * time.Minute)
|
|
||||||
}
|
|
||||||
|
|
||||||
for userName := range spec {
|
|
||||||
key, err := scenario.CreatePreAuthKey(userName, true, false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to run tailscale up for user %s: %s", userName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleSync()
|
|
||||||
assertNoErrSync(t, err)
|
|
||||||
|
|
||||||
// assertClientsState(t, allClients)
|
|
||||||
|
|
||||||
allClients, err = scenario.ListTailscaleClients()
|
|
||||||
assertNoErrListClients(t, err)
|
|
||||||
|
|
||||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
|
||||||
assertNoErrListClientIPs(t, err)
|
|
||||||
|
|
||||||
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
|
|
||||||
return x.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
|
||||||
|
|
||||||
for _, client := range allClients {
|
|
||||||
ips, err := client.IPs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lets check if the IPs are the same
|
|
||||||
if len(ips) != len(clientIPs[client]) {
|
|
||||||
t.Fatalf("IPs changed for client %s", client.Hostname())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range ips {
|
|
||||||
found := false
|
|
||||||
for _, oldIP := range clientIPs[client] {
|
|
||||||
if ip == oldIP {
|
|
||||||
found = true
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
t.Fatalf(
|
|
||||||
"IPs changed for client %s. Used to be %v now %v",
|
|
||||||
client.Hostname(),
|
|
||||||
clientIPs[client],
|
|
||||||
ips,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEphemeral(t *testing.T) {
|
func TestEphemeral(t *testing.T) {
|
||||||
testEphemeralWithOptions(t, hsic.WithTestName("ephemeral"))
|
testEphemeralWithOptions(t, hsic.WithTestName("ephemeral"))
|
||||||
}
|
}
|
||||||
|
@ -314,21 +183,9 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||||
|
|
||||||
t.Logf("all clients logged out")
|
t.Logf("all clients logged out")
|
||||||
|
|
||||||
for userName := range spec {
|
nodes, err := headscale.ListNodes()
|
||||||
nodes, err := headscale.ListNodesInUser(userName)
|
assertNoErr(t, err)
|
||||||
if err != nil {
|
require.Len(t, nodes, 0)
|
||||||
log.Error().
|
|
||||||
Err(err).
|
|
||||||
Str("user", userName).
|
|
||||||
Msg("Error listing nodes in user")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nodes) != 0 {
|
|
||||||
t.Fatalf("expected no nodes, got %d in user %s", len(nodes), userName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not
|
// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not
|
||||||
|
@ -431,7 +288,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) {
|
||||||
time.Sleep(3 * time.Minute)
|
time.Sleep(3 * time.Minute)
|
||||||
|
|
||||||
for userName := range spec {
|
for userName := range spec {
|
||||||
nodes, err := headscale.ListNodesInUser(userName)
|
nodes, err := headscale.ListNodes(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
|
|
@ -16,7 +16,7 @@ func DefaultConfigEnv() map[string]string {
|
||||||
"HEADSCALE_POLICY_PATH": "",
|
"HEADSCALE_POLICY_PATH": "",
|
||||||
"HEADSCALE_DATABASE_TYPE": "sqlite",
|
"HEADSCALE_DATABASE_TYPE": "sqlite",
|
||||||
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
|
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
|
||||||
"HEADSCALE_DATABASE_DEBUG": "1",
|
"HEADSCALE_DATABASE_DEBUG": "0",
|
||||||
"HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1",
|
"HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1",
|
||||||
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
|
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
|
||||||
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
|
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package hsic
|
package hsic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -10,6 +11,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -744,13 +746,59 @@ func (t *HeadscaleInContainer) CreateAuthKey(
|
||||||
return &preAuthKey, nil
|
return &preAuthKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNodesInUser list the TailscaleClients (Node, Headscale internal representation)
|
// ListNodes lists the currently registered Nodes in headscale.
|
||||||
// associated with a user.
|
// Optionally a list of usernames can be passed to get users for
|
||||||
func (t *HeadscaleInContainer) ListNodesInUser(
|
// specific users.
|
||||||
user string,
|
func (t *HeadscaleInContainer) ListNodes(
|
||||||
|
users ...string,
|
||||||
) ([]*v1.Node, error) {
|
) ([]*v1.Node, error) {
|
||||||
|
var ret []*v1.Node
|
||||||
|
execUnmarshal := func(command []string) error {
|
||||||
|
result, _, err := dockertestutil.ExecuteCommand(
|
||||||
|
t.container,
|
||||||
|
command,
|
||||||
|
[]string{},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute list node command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nodes []*v1.Node
|
||||||
|
err = json.Unmarshal([]byte(result), &nodes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal nodes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = append(ret, nodes...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
err := execUnmarshal([]string{"headscale", "nodes", "list", "--output", "json"})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, user := range users {
|
||||||
command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"}
|
command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"}
|
||||||
|
|
||||||
|
err := execUnmarshal(command)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(ret, func(i, j int) bool {
|
||||||
|
return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1
|
||||||
|
})
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListUsers returns a list of users from Headscale.
|
||||||
|
func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) {
|
||||||
|
command := []string{"headscale", "users", "list", "--output", "json"}
|
||||||
|
|
||||||
result, _, err := dockertestutil.ExecuteCommand(
|
result, _, err := dockertestutil.ExecuteCommand(
|
||||||
t.container,
|
t.container,
|
||||||
command,
|
command,
|
||||||
|
@ -760,13 +808,13 @@ func (t *HeadscaleInContainer) ListNodesInUser(
|
||||||
return nil, fmt.Errorf("failed to execute list node command: %w", err)
|
return nil, fmt.Errorf("failed to execute list node command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var nodes []*v1.Node
|
var users []*v1.User
|
||||||
err = json.Unmarshal([]byte(result), &nodes)
|
err = json.Unmarshal([]byte(result), &users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteFile save file inside the Headscale container.
|
// WriteFile save file inside the Headscale container.
|
||||||
|
|
Loading…
Reference in a new issue