diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 65324f77..628168a0 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -21,6 +21,7 @@ jobs: - TestPolicyUpdateWhileRunningWithCLIInDatabase - TestOIDCAuthenticationPingAll - TestOIDCExpireNodesBasedOnTokenExpiry + - TestOIDC024UserCreation - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand diff --git a/CHANGELOG.md b/CHANGELOG.md index 465adc87..70a56aaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,11 @@ - Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020) - Having usernames in magic DNS is no longer possible. - Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020) - - `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC. + - `strip_email_domain` is deprecated, domain is _always_ part of the username for OIDC. + - The option is available until the migration strategy is removed. - Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated. + - By default, users are automatically migrated based on their username when logged in. + - This migration can be disabled, and should be on new installations or fully migrated installations [#2170](https://github.com/juanfont/headscale/pull/2170) - User has been extended to store username, display name, profile picture url and email. - These fields are forwarded to the client, and shows up nicely in the user switcher. - These fields can be made available via the API/CLI for non-OIDC users in the future. diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 568a2a03..309ad67d 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,8 +1,10 @@ package cli import ( + "encoding/json" "fmt" "net" + "net/http" "os" "strconv" "time" @@ -64,6 +66,19 @@ func mockOIDC() error { accessTTL = newTTL } + userStr := os.Getenv("MOCKOIDC_USERS") + if userStr == "" { + return fmt.Errorf("MOCKOIDC_USERS not defined") + } + + var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) + if err != nil { + return fmt.Errorf("unmarshalling users: %w", err) + } + + log.Info().Interface("users", users).Msg("loading users from JSON") + log.Info().Msgf("Access token TTL: %s", accessTTL) port, err := strconv.Atoi(portStr) @@ -71,7 +86,7 @@ func mockOIDC() error { return err } - mock, err := getMockOIDC(clientID, clientSecret) + mock, err := getMockOIDC(clientID, clientSecret, users) if err != nil { return err } @@ -93,12 +108,18 @@ func mockOIDC() error { return nil } -func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { +func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { keypair, err := mockoidc.NewKeypair(nil) if err != nil { return nil, err } + userQueue := mockoidc.UserQueue{} + + for _, user := range users { + userQueue.Push(&user) + } + mock := mockoidc.MockOIDC{ ClientID: clientID, ClientSecret: clientSecret, @@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro CodeChallengeMethodsSupported: []string{"plain", "S256"}, Keypair: keypair, SessionStore: mockoidc.NewSessionStore(), - UserQueue: &mockoidc.UserQueue{}, + UserQueue: &userQueue, ErrorQueue: &mockoidc.ErrorQueue{}, } + mock.AddMiddleware(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Info().Msgf("Request: %+v", r) + h.ServeHTTP(w, r) + if r.Response != nil { + log.Info().Msgf("Response: %+v", r.Response) + } + }) + }) + return &mock, nil } diff --git a/config-example.yaml b/config-example.yaml index 2632555d..c485698b 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -364,12 +364,18 @@ unix_socket_permission: "0770" # allowed_users: # - alice@example.com # -# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. -# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` -# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following -# user: `first-name.last-name.example.com` -# -# strip_email_domain: true +# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users +# # by taking the username from the legacy user and matching it with the username +# # provided by the OIDC. This is useful when migrating from legacy users to OIDC +# # to force them using the unique identifier from the OIDC and to give them a +# # proper display name and picture if available. +# # Note that this will only work if the username from the legacy user is the same +# # and ther is a posibility for account takeover should a username have changed +# # with the provider. +# # Disabling this feature will cause all new logins to be created as new users. +# # Note this option will be removed in the future and should be set to false +# # on all new installations, or when all users have logged in with OIDC once. +# map_legacy_users: true # Logtail configuration # Logtail is Tailscales logging and auditing infrastructure, it allows the control panel diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index b7661ab2..529dc696 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -474,6 +474,8 @@ func NewHeadscaleDatabase( Rollback: func(db *gorm.DB) error { return nil }, }, { + // Pick up new user fields used for OIDC and to + // populate the user with more interesting information. ID: "202407191627", Migrate: func(tx *gorm.DB) error { err := tx.AutoMigrate(&types.User{}) @@ -485,6 +487,21 @@ func NewHeadscaleDatabase( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + // The unique constraint of Name has been dropped + // in favour of a unique together of name and + // provider identity. + ID: "202408181235", + Migrate: func(tx *gorm.DB) error { + err := tx.AutoMigrate(&types.User{}) + if err != nil { + return err + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84267b41..78a7f96d 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -436,25 +436,42 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( ) (*types.User, error) { var user *types.User var err error - user, err = a.db.GetUserByOIDCIdentifier(claims.Sub) + user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, fmt.Errorf("creating or updating user: %w", err) } // This check is for legacy, if the user cannot be found by the OIDC identifier // look it up by username. This should only be needed once. - if user == nil { - user, err = a.db.GetUserByName(claims.Username) - if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, fmt.Errorf("creating or updating user: %w", err) - } + // This branch will presist for a number of versions after the OIDC migration and + // then be removed following a deprecation. + // TODO(kradalby): Remove when strip_email_domain and migration is removed + // after #2170 is cleaned up. + if a.cfg.MapLegacyUsers && user == nil { + log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username") + if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { + log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username") + user, err = a.db.GetUserByName(oldUsername) + if err != nil && !errors.Is(err, db.ErrUserNotFound) { + return nil, fmt.Errorf("getting user: %w", err) + } - // if the user is still not found, create a new empty user. - if user == nil { - user = &types.User{} + // If the user exists, but it already has a provider identifier (OIDC sub), create a new user. + // This is to prevent users that have already been migrated to the new OIDC format + // to be updated with the new OIDC identifier inexplicitly which might be the cause of an + // account takeover. + if user != nil && user.ProviderIdentifier != "" { + log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") + user = &types.User{} + } } } + // if the user is still not found, create a new empty user. + if user == nil { + user = &types.User{} + } + user.FromClaim(claims) err = a.db.DB.Save(user).Error if err != nil { @@ -502,3 +519,24 @@ func renderOIDCCallbackTemplate( return &content, nil } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +func getUserName( + claims *types.OIDCClaims, + stripEmaildomain bool, +) (string, error) { + if !claims.EmailVerified { + return "", fmt.Errorf("email not verified") + } + userName, err := util.NormalizeToFQDNRules( + claims.Email, + stripEmaildomain, + ) + if err != nil { + return "", err + } + + return userName, nil +} diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..4839319e 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -178,7 +178,12 @@ func (pol *ACLPolicy) CompileFilterRules( for srcIndex, src := range acl.Sources { srcs, err := pol.expandSource(src, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -335,12 +340,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -977,10 +991,7 @@ func FilterNodesByACL( continue } - log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname) - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { - log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname) result = append(result, peer) } } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index ec963793..e24e6a9a 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -162,8 +162,10 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string + StripEmaildomain bool Expiry time.Duration UseExpiryFromToken bool + MapLegacyUsers bool } type DERPConfig struct { @@ -272,9 +274,11 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("database.sqlite.write_ahead_log", true) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) + viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) + viper.SetDefault("oidc.map_legacy_users", true) viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) @@ -318,14 +322,18 @@ func validateServerConfig() error { depr.warn("dns_config.use_username_in_magic_dns") depr.warn("dns.use_username_in_magic_dns") - depr.fatal("oidc.strip_email_domain") + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // depr.fatal("oidc.strip_email_domain") depr.fatal("dns.use_username_in_musername_in_magic_dns") depr.fatal("dns_config.use_username_in_musername_in_magic_dns") depr.Log() for _, removed := range []string{ - "oidc.strip_email_domain", + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // "oidc.strip_email_domain", "dns_config.use_username_in_musername_in_magic_dns", } { if viper.IsSet(removed) { @@ -897,6 +905,10 @@ func LoadServerConfig() (*Config, error) { } }(), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), + MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), }, LogTail: logTailConfig, diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f983d7f5..5b27e671 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -19,10 +19,14 @@ type UserID uint64 // that contain our machines. type User struct { gorm.Model + // The index `idx_name_provider_identifier` is to enforce uniqueness + // between Name and ProviderIdentifier. This ensures that + // you can have multiple usersnames of the same name in OIDC, + // but not if you only run with CLI users. // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"unique"` + Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` // Typically the full name of the user DisplayName string @@ -34,7 +38,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"index"` + ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -50,8 +54,15 @@ type User struct { // enabled with OIDC, which means that there is a domain involved which // should be used throughout headscale, in information returned to the // user and the Policy engine. +// If the username does not contain an '@' it will be added to the end. func (u *User) Username() string { - return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + // TODO(kradalby): Wire up all of this for the future + // if !strings.Contains(username, "@") { + // username = username + "@" + // } + + return username } // DisplayNameOrUsername returns the DisplayName if it exists, otherwise @@ -116,6 +127,7 @@ func (u *User) Proto() *v1.User { type OIDCClaims struct { // Sub is the user's unique identifier at the provider. Sub string `json:"sub"` + Iss string `json:"iss"` // Name is the user's full name. Name string `json:"name,omitempty"` @@ -126,12 +138,18 @@ type OIDCClaims struct { Username string `json:"preferred_username,omitempty"` } +func (c *OIDCClaims) Identifier() string { + return c.Iss + "/" + c.Sub +} + // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Sub + u.ProviderIdentifier = claims.Identifier() u.DisplayName = claims.Name - u.Email = claims.Email + if claims.EmailVerified { + u.Email = claims.Email + } u.Name = claims.Username u.ProfilePicURL = claims.ProfilePictureURL u.Provider = util.RegisterMethodOIDC diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index f57576f4..bf43eb50 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -182,3 +182,33 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { return fqdns } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +// NormalizeToFQDNRules will replace forbidden chars in user +// it can also return an error if the user doesn't respect RFC 952 and 1123. +func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { + + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > LabelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + ErrInvalidUserName, + ) + } + } + + return name, nil +} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 6fbdd9e4..25fb358c 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -3,6 +3,7 @@ package integration import ( "context" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -10,14 +11,19 @@ import ( "net" "net/http" "net/netip" + "sort" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" + "github.com/oauth2-proxy/mockoidc" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "github.com/samber/lo" @@ -50,18 +56,32 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { } defer scenario.ShutdownAssertNoPanics(t) + // Logins to MockOIDC is served by a queue with a strict order, + // if we use more than one node per user, the order of the logins + // will not be deterministic and the test will fail. spec := map[string]int{ - "user1": len(MustTestVersions), + "user1": 1, + "user2": 1, } - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", } err = scenario.CreateHeadscaleEnv( @@ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + want := []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "", // Unverified + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } } // This test is really flaky. @@ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ - "user1": 3, + "user1": 1, + "user2": 1, } - oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) + oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, @@ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assertTailscaleNodesLogout(t, allClients) } +// TODO(kradalby): +// - Test that creates a new user when one exists when migration is turned off +// - Test that takes over a user when one exists when migration is turned on +// - But email is not verified +// - stripped email domain on/off +func TestOIDC024UserCreation(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + config map[string]string + emailVerified bool + cliUsers []string + oidcUsers []string + want func(iss string) []v1.User + }{ + { + name: "no-migration-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "no-migration-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: true, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + // Hmm I think we will have to overwrite the initial name here + // createuser with "user1.headscale.net", but oidc with "user1" + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: false, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1.headscale.net", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2.headscale.net", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{} + for _, user := range tt.cliUsers { + spec[user] = 1 + } + + var mockusers []mockoidc.MockUser + for _, user := range tt.oidcUsers { + mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + for k, v := range tt.config { + oidcMap[k] = v + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcmigration"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + ) + assertNoErrHeadscaleEnv(t, err) + + // Ensure that the nodes have logged in, this is what + // triggers user creation via OIDC. + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + want := tt.want(oidcConfig.Issuer) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users: %s", diff) + } + }) + } +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option, @@ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( } for userName, clientCount := range users { + if clientCount != 1 { + // OIDC scenario only supports one client per user. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populates it with one login request per user. + return fmt.Errorf("client count must be 1 for OIDC scenario.") + } log.Printf("creating user %s with %d clients", userName, clientCount) err = s.CreateUser(userName) if err != nil { @@ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { +func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { port, err := dockertestutil.RandomFreeHostPort() if err != nil { log.Fatalf("could not find an open port: %s", err) @@ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + usersJSON, err := json.Marshal(users) + if err != nil { + return nil, err + } + mockOidcOptions := &dockertest.RunOptions{ Name: hostname, Cmd: []string{"headscale", "mockoidc"}, @@ -219,6 +596,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), + fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), }, } @@ -310,45 +688,40 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) - if err := s.pool.Retry(func() error { - log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf( - "%s failed to login using url %s: %s", - c.Hostname(), - loginURL, - err, - ) + log.Printf("%s logging in with url", c.Hostname()) + httpClient := &http.Client{Transport: insecureTransport} + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf( + "%s failed to login using url %s: %s", + c.Hostname(), + loginURL, + err, + ) - return err - } + return err + } - if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + if resp.StatusCode != http.StatusOK { + log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) - return errStatusCodeNotOK - } + return errStatusCodeNotOK + } - defer resp.Body.Close() + defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + _, err = io.ReadAll(resp.Body) + if err != nil { + log.Printf("%s failed to read response body: %s", c.Hostname(), err) - return err - } - - return nil - }); err != nil { return err } log.Printf("Finished request for %s to join tailnet", c.Hostname()) - return nil }) @@ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { assert.Equal(t, "NeedsLogin", status.BackendState) } } + +func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: fmt.Sprintf("%s@headscale.net", username), + EmailVerified: emailVerified, + } +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 2b81e814..a6c8258e 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -212,7 +212,9 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) + tags := listedPreAuthKeys[index].GetAclTags() + sort.Strings(tags) + assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags) } // Test key expiry diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 1b41e324..9e16f366 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -74,7 +74,7 @@ func ExecuteCommand( select { case res := <-resultChan: if res.err != nil { - return stdout.String(), stderr.String(), res.err + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) } if res.exitCode != 0 { @@ -83,12 +83,12 @@ func ExecuteCommand( // log.Println("stdout: ", stdout.String()) // log.Println("stderr: ", stderr.String()) - return stdout.String(), stderr.String(), ErrDockertestCommandFailed + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) } return stdout.String(), stderr.String(), nil case <-time.After(execConfig.timeout): - return stdout.String(), stderr.String(), ErrDockertestCommandTimeout + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) } }