From 1f73616f90992fd73e21846faed55391c8981cd6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 4 Oct 2024 12:24:35 +0200 Subject: [PATCH 01/14] Harden OIDC migration and make optional This commit hardens the migration part of the OIDC from the old username based approach to the new sub based approach and makes it possible for the operator to opt out entirely. Fixes #1990 Signed-off-by: Kristoffer Dalby --- config-example.yaml | 18 ++++++++++++------ hscontrol/oidc.go | 13 ++++++++++++- hscontrol/types/config.go | 3 +++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/config-example.yaml b/config-example.yaml index 2632555d..c485698b 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -364,12 +364,18 @@ unix_socket_permission: "0770" # allowed_users: # - alice@example.com # -# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. -# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` -# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following -# user: `first-name.last-name.example.com` -# -# strip_email_domain: true +# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users +# # by taking the username from the legacy user and matching it with the username +# # provided by the OIDC. This is useful when migrating from legacy users to OIDC +# # to force them using the unique identifier from the OIDC and to give them a +# # proper display name and picture if available. +# # Note that this will only work if the username from the legacy user is the same +# # and ther is a posibility for account takeover should a username have changed +# # with the provider. +# # Disabling this feature will cause all new logins to be created as new users. +# # Note this option will be removed in the future and should be set to false +# # on all new installations, or when all users have logged in with OIDC once. +# map_legacy_users: true # Logtail configuration # Logtail is Tailscales logging and auditing infrastructure, it allows the control panel diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 10008e67..a4775ae8 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -443,7 +443,9 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // This check is for legacy, if the user cannot be found by the OIDC identifier // look it up by username. This should only be needed once. - if user == nil { + // This branch will presist for a number of versions after the OIDC migration and + // then be removed following a deprecation. + if a.cfg.MapLegacyUsers && user == nil { user, err = a.db.GetUserByName(claims.Username) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, fmt.Errorf("creating or updating user: %w", err) @@ -453,6 +455,15 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( if user == nil { user = &types.User{} } + + // If the user exists, but it already has a provider identifier (OIDC sub), create a new user. + // This is to prevent users that have already been migrated to the new OIDC format + // to be updated with the new OIDC identifier inexplicitly which might be the cause of an + // account takeover. + if user.ProviderIdentifier != "" { + log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") + user = &types.User{} + } } user.FromClaim(claims) diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 5895ebc9..4af000a5 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -165,6 +165,7 @@ type OIDCConfig struct { AllowedGroups []string Expiry time.Duration UseExpiryFromToken bool + MapLegacyUsers bool } type DERPConfig struct { @@ -276,6 +277,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) + viper.SetDefault("oidc.map_legacy_users", true) viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) @@ -897,6 +899,7 @@ func LoadServerConfig() (*Config, error) { } }(), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), + MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), }, LogTail: logTailConfig, From d7363e1c144551315a0575e9910bce8ab5e12d53 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 4 Oct 2024 12:29:52 +0200 Subject: [PATCH 02/14] update changelog Signed-off-by: Kristoffer Dalby --- CHANGELOG.md | 78 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ca0ed05..cf6766b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,16 +2,82 @@ ## Next +### Security fix: OIDC changes in Headscale 0.24.0 + +_Headscale v0.23.0 and earlier_ identified OIDC users by the "username" part of their email address (when `strip_email_domain: true`, the default) or whole email address (when `strip_email_domain: false`). + +Depending on how Headscale and your Identity Provider (IdP) were configured, only using the `email` claim could allow a malicious user with an IdP account to take over another Headscale user's account, even when `strip_email_domain: false`. + +This would also cause a user to lose access to their Headscale account if they changed their email address. + +_Headscale v0.24.0_ now identifies OIDC users by the `iss` and `sub` claims. [These are guaranteed by the OIDC specification to be stable and unique](https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability), even if a user changes email address. A well-designed IdP will typically set `sub` to an opaque identifier like a UUID or numeric ID, which has no relation to the user's name or email address. + +This issue _only_ affects Headscale installations which authenticate with OIDC. + +Headscale v0.24.0 and later will also automatically update profile fields with OIDC data on login. This means that users can change those details in your IdP, and have it populate to Headscale automatically the next time they log in. However, this may affect the way you reference users in policies. + +#### Migrating existing installations + +Headscale v0.23.0 and earlier never recorded the `iss` and `sub` fields, so all legacy (existing) OIDC accounts from _need to be migrated_ to be properly secured. + +Headscale v0.24.0 has an automatic migration feature, which is enabled by default (`map_legacy_users: true`). **This will be disabled by default in a future version of Headscale – any unmigrated users will get new accounts.** + +Headscale v0.24.0 will ignore any `email` claim if the IdP does not provide an `email_verified` claim set to `true`. [What "verified" actually means is contextually dependent](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) – Headscale uses it as a signal that the contents of the `email` claim is reasonably trustworthy. + +Headscale v0.23.0 and earlier never checked the `email_verified` claim. This means even if an IdP explicitly indicated to Headscale that its `email` claim was untrustworthy, Headscale would have still accepted it. + +##### What does automatic migration do? + +When automatic migration is enabled (`map_legacy_users: true`), Headscale will first match an OIDC account to a Headscale account by `iss` and `sub`, and then fall back to matching OIDC users similarly to how Headscale v0.23.0 did: + +- If `strip_email_domain: true` (the default): the Headscale username matches the "username" part of their email address. +- If `strip_email_domain: false`: the Headscale username matches the _whole_ email address. + +On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.** + +Like with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login. + +A _successful_ automated migration should otherwise be transparent to users. + +Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username, but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account. + +Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers. + +Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources. + +Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`. + +##### What happens when automatic migration is disabled? + +When automatic migration is disabled (`map_legacy_users: false`), Headscale will only try to match an OIDC account to a Headscale account by `iss` and `sub`. + +If there is no match, it will get a _new_ Headscale account – even if there was a legacy account which _could_ have matched and migrated. + +We recommend new Headscale users explicitly disable automatic migration – but it should otherwise have no effect if every account has a recorded `iss` and `sub`. + +When automatic migration is disabled, the `strip_email_domain` setting will have no effect. + +Special thanks to @micolous for reviewing, proposing and working with us on these changes. + +#### Other OIDC changes + +Headscale now uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to populate and update user information every time they log in: + +| Headscale profile field | OIDC claim | Notes / examples | +| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------- | +| email address | `email` | Only used when `"email_verified": true` | +| display name | `name` | eg: `Sam Smith` | +| username | `preferred_username` | Varies depending on IdP and configuration, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` | +| profile picture | `picture` | URL to a profile picture or avatar | + +These should show up nicely in the Tailscale client. + +This will also affect the way you [reference users in policies](https://github.com/juanfont/headscale/pull/2205). + ### BREAKING - Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020) - Having usernames in magic DNS is no longer possible. -- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020) - - `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC. - - Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated. - - User has been extended to store username, display name, profile picture url and email. - - These fields are forwarded to the client, and shows up nicely in the user switcher. - - These fields can be made available via the API/CLI for non-OIDC users in the future. - Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149) - Clean up old code required by old versions From 9f56a723ef51a8f3cb80cf6d59b947b66a44b46d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 7 Oct 2024 17:41:54 +0200 Subject: [PATCH 03/14] remove log print Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 225667ec..9e1172fd 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -178,7 +178,12 @@ func (pol *ACLPolicy) CompileFilterRules( for srcIndex, src := range acl.Sources { srcs, err := pol.expandSource(src, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -335,12 +340,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -977,10 +991,7 @@ func FilterNodesByACL( continue } - log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname) - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { - log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname) result = append(result, peer) } } From 8302291e38c40cc3e8944f8ec7975d56772285c2 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 05:58:25 -0600 Subject: [PATCH 04/14] add @ to end of username if not present Signed-off-by: Kristoffer Dalby --- hscontrol/types/users.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f983d7f5..db8a50bd 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -3,6 +3,7 @@ package types import ( "cmp" "strconv" + "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -50,8 +51,14 @@ type User struct { // enabled with OIDC, which means that there is a domain involved which // should be used throughout headscale, in information returned to the // user and the Policy engine. +// If the username does not contain an '@' it will be added to the end. func (u *User) Username() string { - return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + if !strings.Contains(username, "@") { + username = username + "@" + } + + return username } // DisplayNameOrUsername returns the DisplayName if it exists, otherwise From 939c233b8d31b2611211d182ed82417543cb1282 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 05:58:44 -0600 Subject: [PATCH 05/14] add iss to identifier, only set email if verified Signed-off-by: Kristoffer Dalby --- hscontrol/db/db.go | 17 +++++++++++++++++ hscontrol/types/users.go | 15 +++++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index b7661ab2..529dc696 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -474,6 +474,8 @@ func NewHeadscaleDatabase( Rollback: func(db *gorm.DB) error { return nil }, }, { + // Pick up new user fields used for OIDC and to + // populate the user with more interesting information. ID: "202407191627", Migrate: func(tx *gorm.DB) error { err := tx.AutoMigrate(&types.User{}) @@ -485,6 +487,21 @@ func NewHeadscaleDatabase( }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + // The unique constraint of Name has been dropped + // in favour of a unique together of name and + // provider identity. + ID: "202408181235", + Migrate: func(tx *gorm.DB) error { + err := tx.AutoMigrate(&types.User{}) + if err != nil { + return err + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index db8a50bd..3ed6981e 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -20,10 +20,14 @@ type UserID uint64 // that contain our machines. type User struct { gorm.Model + // The index `idx_name_provider_identifier` is to enforce uniqueness + // between Name and ProviderIdentifier. This ensures that + // you can have multiple usersnames of the same name in OIDC, + // but not if you only run with CLI users. // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"unique"` + Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` // Typically the full name of the user DisplayName string @@ -35,7 +39,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"index"` + ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -123,6 +127,7 @@ func (u *User) Proto() *v1.User { type OIDCClaims struct { // Sub is the user's unique identifier at the provider. Sub string `json:"sub"` + Iss string `json:"iss"` // Name is the user's full name. Name string `json:"name,omitempty"` @@ -136,9 +141,11 @@ type OIDCClaims struct { // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Sub + u.ProviderIdentifier = claims.Iss + "/" + claims.Sub u.DisplayName = claims.Name - u.Email = claims.Email + if claims.EmailVerified { + u.Email = claims.Email + } u.Name = claims.Username u.ProfilePicURL = claims.ProfilePictureURL u.Provider = util.RegisterMethodOIDC From 8059d475a4b9e06905458bfa2dc5cc61a482ba6a Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Thu, 17 Oct 2024 06:22:44 -0600 Subject: [PATCH 06/14] restore strip_email_domain for migration Signed-off-by: Kristoffer Dalby --- hscontrol/oidc.go | 52 ++++++++++++++++++++++++++++----------- hscontrol/types/config.go | 10 ++++++-- hscontrol/util/dns.go | 30 ++++++++++++++++++++++ 3 files changed, 75 insertions(+), 17 deletions(-) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index a4775ae8..ad518b90 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -445,25 +445,29 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // look it up by username. This should only be needed once. // This branch will presist for a number of versions after the OIDC migration and // then be removed following a deprecation. + // TODO(kradalby): Remove when strip_email_domain and migration is removed + // after #2170 is cleaned up. if a.cfg.MapLegacyUsers && user == nil { - user, err = a.db.GetUserByName(claims.Username) - if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, fmt.Errorf("creating or updating user: %w", err) - } + if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { + user, err = a.db.GetUserByName(oldUsername) + if err != nil && !errors.Is(err, db.ErrUserNotFound) { + return nil, fmt.Errorf("creating or updating user: %w", err) + } - // if the user is still not found, create a new empty user. - if user == nil { - user = &types.User{} + // If the user exists, but it already has a provider identifier (OIDC sub), create a new user. + // This is to prevent users that have already been migrated to the new OIDC format + // to be updated with the new OIDC identifier inexplicitly which might be the cause of an + // account takeover. + if user != nil && user.ProviderIdentifier != "" { + log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") + user = &types.User{} + } } + } - // If the user exists, but it already has a provider identifier (OIDC sub), create a new user. - // This is to prevent users that have already been migrated to the new OIDC format - // to be updated with the new OIDC identifier inexplicitly which might be the cause of an - // account takeover. - if user.ProviderIdentifier != "" { - log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") - user = &types.User{} - } + // if the user is still not found, create a new empty user. + if user == nil { + user = &types.User{} } user.FromClaim(claims) @@ -513,3 +517,21 @@ func renderOIDCCallbackTemplate( return &content, nil } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +func getUserName( + claims *types.OIDCClaims, + stripEmaildomain bool, +) (string, error) { + userName, err := util.NormalizeToFQDNRules( + claims.Email, + stripEmaildomain, + ) + if err != nil { + return "", err + } + + return userName, nil +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4af000a5..bbce18b8 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -163,6 +163,7 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string + StripEmaildomain bool Expiry time.Duration UseExpiryFromToken bool MapLegacyUsers bool @@ -274,6 +275,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("database.sqlite.write_ahead_log", true) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) + viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) @@ -321,14 +323,18 @@ func validateServerConfig() error { depr.warn("dns_config.use_username_in_magic_dns") depr.warn("dns.use_username_in_magic_dns") - depr.fatal("oidc.strip_email_domain") + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // depr.fatal("oidc.strip_email_domain") depr.fatal("dns.use_username_in_musername_in_magic_dns") depr.fatal("dns_config.use_username_in_musername_in_magic_dns") depr.Log() for _, removed := range []string{ - "oidc.strip_email_domain", + // TODO(kradalby): Reintroduce when strip_email_domain is removed + // after #2170 is cleaned up + // "oidc.strip_email_domain", "dns_config.use_username_in_musername_in_magic_dns", } { if viper.IsSet(removed) { diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index f57576f4..bf43eb50 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -182,3 +182,33 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { return fqdns } + +// TODO(kradalby): Reintroduce when strip_email_domain is removed +// after #2170 is cleaned up +// DEPRECATED: DO NOT USE +// NormalizeToFQDNRules will replace forbidden chars in user +// it can also return an error if the user doesn't respect RFC 952 and 1123. +func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { + + name = strings.ToLower(name) + name = strings.ReplaceAll(name, "'", "") + atIdx := strings.Index(name, "@") + if stripEmailDomain && atIdx > 0 { + name = name[:atIdx] + } else { + name = strings.ReplaceAll(name, "@", ".") + } + name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + + for _, elt := range strings.Split(name, ".") { + if len(elt) > LabelHostnameLength { + return "", fmt.Errorf( + "label %v is more than 63 chars: %w", + elt, + ErrInvalidUserName, + ) + } + } + + return name, nil +} From 69b9abaa6c1cb193db1fbe5c44f1c18c0801db3f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 18 Oct 2024 06:59:27 -0600 Subject: [PATCH 07/14] fix oidc test, add tests for migration Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 1 + cmd/headscale/cli/mockoidc.go | 37 +- hscontrol/oidc.go | 9 +- hscontrol/types/config.go | 5 +- hscontrol/types/users.go | 16 +- integration/auth_oidc_test.go | 450 ++++++++++++++++++++++-- integration/dockertestutil/execute.go | 6 +- 7 files changed, 475 insertions(+), 49 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 7e730aa8..1e514f24 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -21,6 +21,7 @@ jobs: - TestPolicyUpdateWhileRunningWithCLIInDatabase - TestOIDCAuthenticationPingAll - TestOIDCExpireNodesBasedOnTokenExpiry + - TestOIDC024UserCreation - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 568a2a03..309ad67d 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,8 +1,10 @@ package cli import ( + "encoding/json" "fmt" "net" + "net/http" "os" "strconv" "time" @@ -64,6 +66,19 @@ func mockOIDC() error { accessTTL = newTTL } + userStr := os.Getenv("MOCKOIDC_USERS") + if userStr == "" { + return fmt.Errorf("MOCKOIDC_USERS not defined") + } + + var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) + if err != nil { + return fmt.Errorf("unmarshalling users: %w", err) + } + + log.Info().Interface("users", users).Msg("loading users from JSON") + log.Info().Msgf("Access token TTL: %s", accessTTL) port, err := strconv.Atoi(portStr) @@ -71,7 +86,7 @@ func mockOIDC() error { return err } - mock, err := getMockOIDC(clientID, clientSecret) + mock, err := getMockOIDC(clientID, clientSecret, users) if err != nil { return err } @@ -93,12 +108,18 @@ func mockOIDC() error { return nil } -func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { +func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { keypair, err := mockoidc.NewKeypair(nil) if err != nil { return nil, err } + userQueue := mockoidc.UserQueue{} + + for _, user := range users { + userQueue.Push(&user) + } + mock := mockoidc.MockOIDC{ ClientID: clientID, ClientSecret: clientSecret, @@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro CodeChallengeMethodsSupported: []string{"plain", "S256"}, Keypair: keypair, SessionStore: mockoidc.NewSessionStore(), - UserQueue: &mockoidc.UserQueue{}, + UserQueue: &userQueue, ErrorQueue: &mockoidc.ErrorQueue{}, } + mock.AddMiddleware(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Info().Msgf("Request: %+v", r) + h.ServeHTTP(w, r) + if r.Response != nil { + log.Info().Msgf("Response: %+v", r.Response) + } + }) + }) + return &mock, nil } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index ad518b90..fce7e455 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -436,7 +436,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( ) (*types.User, error) { var user *types.User var err error - user, err = a.db.GetUserByOIDCIdentifier(claims.Sub) + user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, fmt.Errorf("creating or updating user: %w", err) } @@ -448,10 +448,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // TODO(kradalby): Remove when strip_email_domain and migration is removed // after #2170 is cleaned up. if a.cfg.MapLegacyUsers && user == nil { + log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username") if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { + log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username") user, err = a.db.GetUserByName(oldUsername) if err != nil && !errors.Is(err, db.ErrUserNotFound) { - return nil, fmt.Errorf("creating or updating user: %w", err) + return nil, fmt.Errorf("getting user: %w", err) } // If the user exists, but it already has a provider identifier (OIDC sub), create a new user. @@ -525,6 +527,9 @@ func getUserName( claims *types.OIDCClaims, stripEmaildomain bool, ) (string, error) { + if !claims.EmailVerified { + return "", fmt.Errorf("email not verified") + } userName, err := util.NormalizeToFQDNRules( claims.Email, stripEmaildomain, diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index bbce18b8..1a051135 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -905,7 +905,10 @@ func LoadServerConfig() (*Config, error) { } }(), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), - MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), + MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), }, LogTail: logTailConfig, diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 3ed6981e..5b27e671 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -3,7 +3,6 @@ package types import ( "cmp" "strconv" - "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -39,7 +38,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` + ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -58,9 +57,10 @@ type User struct { // If the username does not contain an '@' it will be added to the end. func (u *User) Username() string { username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) - if !strings.Contains(username, "@") { - username = username + "@" - } + // TODO(kradalby): Wire up all of this for the future + // if !strings.Contains(username, "@") { + // username = username + "@" + // } return username } @@ -138,10 +138,14 @@ type OIDCClaims struct { Username string `json:"preferred_username,omitempty"` } +func (c *OIDCClaims) Identifier() string { + return c.Iss + "/" + c.Sub +} + // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Iss + "/" + claims.Sub + u.ProviderIdentifier = claims.Identifier() u.DisplayName = claims.Name if claims.EmailVerified { u.Email = claims.Email diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 6fbdd9e4..25fb358c 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -3,6 +3,7 @@ package integration import ( "context" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -10,14 +11,19 @@ import ( "net" "net/http" "net/netip" + "sort" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" + "github.com/oauth2-proxy/mockoidc" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "github.com/samber/lo" @@ -50,18 +56,32 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { } defer scenario.ShutdownAssertNoPanics(t) + // Logins to MockOIDC is served by a queue with a strict order, + // if we use more than one node per user, the order of the logins + // will not be deterministic and the test will fail. spec := map[string]int{ - "user1": len(MustTestVersions), + "user1": 1, + "user2": 1, } - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // TODO(kradalby): Remove when strip_email_domain is removed + // after #2170 is cleaned up + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", } err = scenario.CreateHeadscaleEnv( @@ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + want := []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "", // Unverified + Provider: "oidc", + ProviderId: oidcConfig.Issuer + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } } // This test is really flaky. @@ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { defer scenario.ShutdownAssertNoPanics(t) spec := map[string]int{ - "user1": 3, + "user1": 1, + "user2": 1, } - oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) + oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }) assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() oidcMap := map[string]string{ "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, @@ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assertTailscaleNodesLogout(t, allClients) } +// TODO(kradalby): +// - Test that creates a new user when one exists when migration is turned off +// - Test that takes over a user when one exists when migration is turned on +// - But email is not verified +// - stripped email domain on/off +func TestOIDC024UserCreation(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + config map[string]string + emailVerified bool + cliUsers []string + oidcUsers []string + want func(iss string) []v1.User + }{ + { + name: "no-migration-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "no-migration-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", + }, + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: true, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + // Hmm I think we will have to overwrite the initial name here + // createuser with "user1.headscale.net", but oidc with "user1" + { + Id: "1", + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "2", + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-not-verified-email", + config: map[string]string{ + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "1", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + }, + emailVerified: false, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []v1.User { + return []v1.User{ + { + Id: "1", + Name: "user1.headscale.net", + }, + { + Id: "2", + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: "3", + Name: "user2.headscale.net", + }, + { + Id: "4", + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + spec := map[string]int{} + for _, user := range tt.cliUsers { + spec[user] = 1 + } + + var mockusers []mockoidc.MockUser + for _, user := range tt.oidcUsers { + mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + for k, v := range tt.config { + oidcMap[k] = v + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcmigration"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + ) + assertNoErrHeadscaleEnv(t, err) + + // Ensure that the nodes have logged in, this is what + // triggers user creation via OIDC. + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + want := tt.want(oidcConfig.Issuer) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].Id < listUsers[j].Id + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users: %s", diff) + } + }) + } +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option, @@ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( } for userName, clientCount := range users { + if clientCount != 1 { + // OIDC scenario only supports one client per user. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populates it with one login request per user. + return fmt.Errorf("client count must be 1 for OIDC scenario.") + } log.Printf("creating user %s with %d clients", userName, clientCount) err = s.CreateUser(userName) if err != nil { @@ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { +func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { port, err := dockertestutil.RandomFreeHostPort() if err != nil { log.Fatalf("could not find an open port: %s", err) @@ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf hostname := fmt.Sprintf("hs-oidcmock-%s", hash) + usersJSON, err := json.Marshal(users) + if err != nil { + return nil, err + } + mockOidcOptions := &dockertest.RunOptions{ Name: hostname, Cmd: []string{"headscale", "mockoidc"}, @@ -219,6 +596,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf "MOCKOIDC_CLIENT_ID=superclient", "MOCKOIDC_CLIENT_SECRET=supersecret", fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), + fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), }, } @@ -310,45 +688,40 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) - if err := s.pool.Retry(func() error { - log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf( - "%s failed to login using url %s: %s", - c.Hostname(), - loginURL, - err, - ) + log.Printf("%s logging in with url", c.Hostname()) + httpClient := &http.Client{Transport: insecureTransport} + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf( + "%s failed to login using url %s: %s", + c.Hostname(), + loginURL, + err, + ) - return err - } + return err + } - if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + if resp.StatusCode != http.StatusOK { + log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) - return errStatusCodeNotOK - } + return errStatusCodeNotOK + } - defer resp.Body.Close() + defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + _, err = io.ReadAll(resp.Body) + if err != nil { + log.Printf("%s failed to read response body: %s", c.Hostname(), err) - return err - } - - return nil - }); err != nil { return err } log.Printf("Finished request for %s to join tailnet", c.Hostname()) - return nil }) @@ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { assert.Equal(t, "NeedsLogin", status.BackendState) } } + +func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: fmt.Sprintf("%s@headscale.net", username), + EmailVerified: emailVerified, + } +} diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 1b41e324..9e16f366 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -74,7 +74,7 @@ func ExecuteCommand( select { case res := <-resultChan: if res.err != nil { - return stdout.String(), stderr.String(), res.err + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) } if res.exitCode != 0 { @@ -83,12 +83,12 @@ func ExecuteCommand( // log.Println("stdout: ", stdout.String()) // log.Println("stderr: ", stderr.String()) - return stdout.String(), stderr.String(), ErrDockertestCommandFailed + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) } return stdout.String(), stderr.String(), nil case <-time.After(execConfig.timeout): - return stdout.String(), stderr.String(), ErrDockertestCommandTimeout + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) } } From d6fedd117eb835434f032df682cce0f68ca11d41 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 21 Oct 2024 17:30:28 -0500 Subject: [PATCH 08/14] make preauthkey tags test stable Signed-off-by: Kristoffer Dalby --- integration/cli_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integration/cli_test.go b/integration/cli_test.go index 150ebb18..2e152deb 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -213,7 +213,9 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, []string{"tag:test1", "tag:test2"}, listedPreAuthKeys[index].GetAclTags()) + tags := listedPreAuthKeys[index].GetAclTags() + sort.Strings(tags) + assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags) } // Test key expiry From 33c8bbcef84c76f008a86aa54b55c87f1a13da21 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 17 Nov 2024 19:40:06 -0700 Subject: [PATCH 09/14] use userID instead of username everywhere Signed-off-by: Kristoffer Dalby --- hscontrol/db/db_test.go | 4 +- hscontrol/db/node.go | 8 +-- hscontrol/db/node_test.go | 40 +++++------ hscontrol/db/preauth_keys.go | 20 +++--- hscontrol/db/preauth_keys_test.go | 35 +++++----- hscontrol/db/routes_test.go | 16 ++--- hscontrol/db/users.go | 110 +++++++++++++----------------- hscontrol/db/users_test.go | 42 +++++++----- hscontrol/grpcv1.go | 52 +++++++++++--- hscontrol/types/users.go | 2 +- 10 files changed, 178 insertions(+), 151 deletions(-) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index ebc37694..87f94eb9 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -121,12 +121,12 @@ func TestMigrations(t *testing.T) { dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite", wantFunc: func(t *testing.T, h *HSDatabase) { keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - kratest, err := ListPreAuthKeys(rx, "kratest") + kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest if err != nil { return nil, err } - testkra, err := ListPreAuthKeys(rx, "testkra") + testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra if err != nil { return nil, err } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 1b6e7538..1c2a165c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } -func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { +func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return getNode(rx, user, name) + return getNode(rx, uid, name) }) } // getNode finds a Node by name and user and returns the Node struct. -func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { - nodes, err := ListNodesByUser(tx, user) +func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index a81d8f0f..6c1d1099 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -30,10 +30,10 @@ func (s *Suite) TestGetNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -51,7 +51,7 @@ func (s *Suite) TestGetNode(c *check.C) { trx := db.DB.Save(node) c.Assert(trx.Error, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) } @@ -59,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -88,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -136,7 +136,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) c.Assert(err, check.IsNil) - _, err = db.getNode(user.Name, "testnode3") + _, err = db.getNode(types.UserID(user.ID), "testnode3") c.Assert(err, check.NotNil) } @@ -144,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -190,7 +190,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for _, name := range []string{"test", "admin"} { user, err := db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } @@ -282,10 +282,10 @@ func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -303,7 +303,7 @@ func (s *Suite) TestExpireNode(c *check.C) { } db.DB.Save(node) - nodeFromDB, err := db.getNode("test", "testnode") + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) @@ -313,7 +313,7 @@ func (s *Suite) TestExpireNode(c *check.C) { err = db.NodeSetExpiry(nodeFromDB.ID, now) c.Assert(err, check.IsNil) - nodeFromDB, err = db.getNode("test", "testnode") + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -323,10 +323,10 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -349,7 +349,7 @@ func (s *Suite) TestSetTags(c *check.C) { sTags := []string{"tag:test", "tag:foo"} err = db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, sTags) @@ -357,7 +357,7 @@ func (s *Suite) TestSetTags(c *check.C) { eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( node.ForcedTags, @@ -368,7 +368,7 @@ func (s *Suite) TestSetTags(c *check.C) { // test removing tags err = db.SetTags(node.ID, []string{}) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } @@ -568,7 +568,7 @@ func TestAutoApproveRoutes(t *testing.T) { user, err := adb.CreateUser("test") require.NoError(t, err) - pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) require.NoError(t, err) nodeKey := key.NewNode() @@ -700,10 +700,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser("test") require.NoError(t, err) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) require.NoError(t, err) node := types.Node{ diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 59bbdf98..aeee5b52 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -23,29 +23,27 @@ var ( ) func (hsdb *HSDatabase) CreatePreAuthKey( - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { - return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) + return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) }) } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func CreatePreAuthKey( tx *gorm.DB, - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } @@ -89,15 +87,15 @@ func CreatePreAuthKey( return &key, nil } -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - return ListPreAuthKeys(rx, userName) + return ListPreAuthKeysByUser(rx, uid) }) } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. -func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) +// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. +func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index ec3f6441..3c56a35e 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,14 +11,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) - + // ID does not exist + _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) c.Assert(err, check.NotNil) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { c.Assert(len(key.Key), check.Equals, 48) // Make sure the User association is populated - c.Assert(key.User.Name, check.Equals, user.Name) + c.Assert(key.User.ID, check.Equals, user.ID) - _, err = db.ListPreAuthKeys("bogus") + // ID does not exist + _, err = db.ListPreAuthKeys(1000000) c.Assert(err, check.NotNil) - keys, err := db.ListPreAuthKeys(user.Name) + keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) // Make sure the User association is populated - c.Assert((keys)[0].User.Name, check.Equals, user.Name) + c.Assert((keys)[0].User.ID, check.Equals, user.ID) } func (*Suite) TestExpiredPreAuthKey(c *check.C) { @@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -62,7 +63,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -74,7 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -96,7 +97,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -118,7 +119,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -130,7 +131,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) @@ -147,7 +148,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true db.DB.Save(&pak) @@ -160,15 +161,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := db.ListPreAuthKeys("test8") + listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) gotTags := listedPaks[0].Proto().GetAclTags() sort.Sort(sort.StringSlice(gotTags)) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 5071077c..7b11e136 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_get_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 135276c7..840d316d 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -40,21 +40,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) DestroyUser(name string) error { +func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return DestroyUser(tx, name) + return DestroyUser(tx, uid) }) } // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. -func DestroyUser(tx *gorm.DB, name string) error { - user, err := GetUserByUsername(tx, name) +func DestroyUser(tx *gorm.DB, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { - return ErrUserNotFound + return err } - nodes, err := ListNodesByUser(tx, name) + nodes, err := ListNodesByUser(tx, uid) if err != nil { return err } @@ -62,7 +62,7 @@ func DestroyUser(tx *gorm.DB, name string) error { return ErrUserStillHasNodes } - keys, err := ListPreAuthKeys(tx, name) + keys, err := ListPreAuthKeysByUser(tx, uid) if err != nil { return err } @@ -80,17 +80,17 @@ func DestroyUser(tx *gorm.DB, name string) error { return nil } -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { +func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error { return hsdb.Write(func(tx *gorm.DB) error { - return RenameUser(tx, oldName, newName) + return RenameUser(tx, uid, newName) }) } // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func RenameUser(tx *gorm.DB, oldName, newName string) error { +func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error - oldUser, err := GetUserByUsername(tx, oldName) + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } @@ -98,50 +98,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error { if err != nil { return err } - _, err = GetUserByUsername(tx, newName) - if err == nil { - return ErrUserExists - } - if !errors.Is(err, ErrUserNotFound) { - return err - } oldUser.Name = newName - if result := tx.Save(&oldUser); result.Error != nil { - return result.Error + if err := tx.Save(&oldUser).Error; err != nil { + return err } return nil } -func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { +func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByUsername(rx, name) + return GetUserByID(rx, uid) }) } -func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) { +func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { user := types.User{} - if result := tx.First(&user, "name = ?", name); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrUserNotFound - } - - return &user, nil -} - -func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByID(rx, id) - }) -} - -func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) { - user := types.User{} - if result := tx.First(&user, "id = ?", id); errors.Is( + if result := tx.First(&user, "id = ?", uid); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -169,54 +144,65 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { +func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { - return ListUsers(rx) + return ListUsers(rx, where...) }) } // ListUsers gets all the existing users. -func ListUsers(tx *gorm.DB) ([]types.User, error) { +func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { + if len(where) > 1 { + return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + } + + var user *types.User + if len(where) == 1 { + user = where[0] + } + users := []types.User{} - if err := tx.Find(&users).Error; err != nil { + if err := tx.Where(user).Find(&users).Error; err != nil { return nil, err } return users, nil } -// ListNodesByUser gets all the nodes in a given user. -func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { - err := util.CheckForFQDNRules(name) - if err != nil { - return nil, err - } - user, err := GetUserByUsername(tx, name) +// GetUserByName returns a user if the provided username is +// unique, and otherwise an error. +func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { + users, err := hsdb.ListUsers(&types.User{Name: name}) if err != nil { return nil, err } + if len(users) != 1 { + return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + } + + return &users[0], nil +} + +// ListNodesByUser gets all the nodes in a given user. +func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { +func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, node, username) + return AssignNodeToUser(tx, node, uid) }) } // AssignNodeToUser assigns a Node to a user. -func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { - err := util.CheckForFQDNRules(username) - if err != nil { - return err - } - user, err := GetUserByUsername(tx, username) +func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 54399664..6684989e 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,6 +1,8 @@ package db import ( + "strings" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" @@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") + _, err = db.GetUserByID(types.UserID(user.ID)) c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := db.DestroyUser("test") + err := db.DestroyUser(9998) c.Assert(err, check.Equals, ErrUserNotFound) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) @@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err = db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.Equals, ErrUserStillHasNodes) } @@ -70,24 +72,28 @@ func (s *Suite) TestRenameUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test", "test-renamed") + err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") - c.Assert(err, check.Equals, ErrUserNotFound) + users, err = db.ListUsers(&types.User{Name: "test"}) + c.Assert(err, check.Equals, nil) + c.Assert(len(users), check.Equals, 0) - _, err = db.GetUserByName("test-renamed") + users, err = db.ListUsers(&types.User{Name: "test-renamed"}) c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test-does-not-exit", "test") + err = db.RenameUser(99988, "test") c.Assert(err, check.Equals, ErrUserNotFound) userTest2, err := db.CreateUser("test2") c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") - err = db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) + err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") + if !strings.Contains(err.Error(), "UNIQUE constraint failed") { + c.Fatalf("expected failure with unique constraint, got: %s", err.Error()) + } } func (s *Suite) TestSetMachineUser(c *check.C) { @@ -97,7 +103,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { newUser, err := db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -111,15 +117,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { c.Assert(trx.Error, check.IsNil) c.Assert(node.UserID, check.Equals, oldUser.ID) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) - err = db.AssignNodeToUser(&node, "non-existing-user") + err = db.AssignNodeToUser(&node, 9584849) c.Assert(err, check.Equals, ErrUserNotFound) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..dd7ab03d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -65,24 +65,34 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) + oldUser, err := api.h.db.GetUserByName(request.GetOldName()) if err != nil { return nil, err } - user, err := api.h.db.GetUserByName(request.GetNewName()) + err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } - return &v1.RenameUserResponse{User: user.Proto()}, nil + newUser, err := api.h.db.GetUserByName(request.GetNewName()) + if err != nil { + return nil, err + } + + return &v1.RenameUserResponse{User: newUser.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.db.DestroyUser(request.GetName()) + user, err := api.h.db.GetUserByName(request.GetName()) + if err != nil { + return nil, err + } + + err = api.h.db.DestroyUser(types.UserID(user.ID)) if err != nil { return nil, err } @@ -131,8 +141,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + preAuthKey, err := api.h.db.CreatePreAuthKey( - request.GetUser(), + types.UserID(user.ID), request.GetReusable(), request.GetEphemeral(), &expiration, @@ -168,7 +183,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID)) if err != nil { return nil, err } @@ -406,10 +426,20 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + // TODO(kradalby): it looks like this can be simplified a lot, + // the filtering of nodes by user, vs nodes as a whole can + // probably be done once. + // TODO(kradalby): This should be done in one tx. + isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() if request.GetUser() != "" { + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { - return db.ListNodesByUser(rx, request.GetUser()) + return db.ListNodesByUser(rx, types.UserID(user.ID)) }) if err != nil { return nil, err @@ -465,12 +495,18 @@ func (api headscaleV1APIServer) MoveNode( ctx context.Context, request *v1.MoveNodeRequest, ) (*v1.MoveNodeResponse, error) { + // TODO(kradalby): This should be done in one tx. node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } - err = api.h.db.AssignNodeToUser(node, request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID)) if err != nil { return nil, err } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 5b27e671..9e0bfeb0 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -26,7 +26,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` + Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` // Typically the full name of the user DisplayName string From 7c92acb50ca18c6d554dc9f734490f2436f81fd1 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 17 Nov 2024 19:49:51 -0700 Subject: [PATCH 10/14] nits Signed-off-by: Kristoffer Dalby --- hscontrol/db/routes.go | 2 +- hscontrol/types/preauth_key.go | 2 +- hscontrol/types/users.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..d8fe7b3f 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -639,7 +639,7 @@ func EnableAutoApprovedRoutes( log.Trace(). Str("node", node.Hostname). - Str("user", node.User.Name). + Uint("user.id", node.User.ID). Strs("routeApprovers", routeApprovers). Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()). Msg("looking up route for autoapproving") diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index ba3b597b..0174c9e8 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -26,7 +26,7 @@ type PreAuthKey struct { func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ - User: key.User.Name, + User: key.User.Username(), Id: strconv.FormatUint(key.ID, util.Base10), Key: key.Key, Ephemeral: key.Ephemeral, diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 9e0bfeb0..8b3d2e83 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -21,7 +21,7 @@ type User struct { gorm.Model // The index `idx_name_provider_identifier` is to enforce uniqueness // between Name and ProviderIdentifier. This ensures that - // you can have multiple usersnames of the same name in OIDC, + // you can have multiple users with the same name in OIDC, // but not if you only run with CLI users. // Username for the user, is used if email is empty @@ -54,9 +54,9 @@ type User struct { // enabled with OIDC, which means that there is a domain involved which // should be used throughout headscale, in information returned to the // user and the Policy engine. -// If the username does not contain an '@' it will be added to the end. func (u *User) Username() string { username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + // TODO(kradalby): Wire up all of this for the future // if !strings.Contains(username, "@") { // username = username + "@" From 0dd5956756b0e2811e843c9ef6be7241e6a69906 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 18 Nov 2024 17:33:46 +0100 Subject: [PATCH 11/14] fix constraints Signed-off-by: Kristoffer Dalby --- hscontrol/db/db_test.go | 108 ++++++++++++++++++++++++++++++++++ hscontrol/db/users.go | 10 ++-- hscontrol/oidc.go | 2 +- hscontrol/types/users.go | 11 ++-- integration/auth_oidc_test.go | 2 +- 5 files changed, 122 insertions(+), 11 deletions(-) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 87f94eb9..a291ad7d 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "fmt" "io" "net/netip" @@ -257,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) { func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } + +func TestConstraints(t *testing.T) { + tests := []struct { + name string + run func(*testing.T, *gorm.DB) + }{ + { + name: "no-duplicate-username-if-no-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") + require.NoError(t, err) + _, err = CreateUser(db, "user1") + require.Error(t, err) + // assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username") + require.Contains(t, err.Error(), "user already exists") + }, + }, + { + name: "no-oidc-duplicate-username-and-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + require.Error(t, err) + require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + }, + }, + { + name: "no-oidc-duplicate-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1.1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + require.Error(t, err) + require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + }, + }, + { + name: "allow-duplicate-username-cli-then-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + + user := types.User{ + Name: "user1", + } + user.ProviderIdentifier.String = "http://test.com/user1" + + err = db.Save(&user).Error + require.NoError(t, err) + }, + }, + { + name: "allow-duplicate-username-oidc-then-cli", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Name: "user1", + } + user.ProviderIdentifier.String = "http://test.com/user1" + + err := db.Save(&user).Error + require.NoError(t, err) + + _, err = CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating database: %s", err) + } + + tt.run(t, db.DB) + }) + + } +} diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 840d316d..0eaa9ea3 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { if err != nil { return nil, err } - user := types.User{} - if err := tx.Where("name = ?", name).First(&user).Error; err == nil { - return nil, ErrUserExists + user := types.User{ + Name: name, } - user.Name = name if err := tx.Create(&user).Error; err != nil { return nil, fmt.Errorf("creating user: %w", err) } @@ -177,6 +175,10 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { return nil, err } + if len(users) == 0 { + return nil, ErrUserNotFound + } + if len(users) != 1 { return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index fce7e455..e8461967 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -460,7 +460,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // This is to prevent users that have already been migrated to the new OIDC format // to be updated with the new OIDC identifier inexplicitly which might be the cause of an // account takeover. - if user != nil && user.ProviderIdentifier != "" { + if user != nil && user.ProviderIdentifier.Valid { log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") user = &types.User{} } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 8b3d2e83..f36be708 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,6 +2,7 @@ package types import ( "cmp" + "database/sql" "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -26,7 +27,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` + Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"` // Typically the full name of the user DisplayName string @@ -38,7 +39,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` + ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -55,7 +56,7 @@ type User struct { // should be used throughout headscale, in information returned to the // user and the Policy engine. func (u *User) Username() string { - username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) // TODO(kradalby): Wire up all of this for the future // if !strings.Contains(username, "@") { @@ -118,7 +119,7 @@ func (u *User) Proto() *v1.User { CreatedAt: timestamppb.New(u.CreatedAt), DisplayName: u.DisplayName, Email: u.Email, - ProviderId: u.ProviderIdentifier, + ProviderId: u.ProviderIdentifier.String, Provider: u.Provider, ProfilePicUrl: u.ProfilePicURL, } @@ -145,7 +146,7 @@ func (c *OIDCClaims) Identifier() string { // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Identifier() + u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} u.DisplayName = claims.Name if claims.EmailVerified { u.Email = claims.Email diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 25fb358c..2fbfb555 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -54,7 +54,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { scenario := AuthOIDCScenario{ Scenario: baseScenario, } - defer scenario.ShutdownAssertNoPanics(t) + // defer scenario.ShutdownAssertNoPanics(t) // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins From 60b95959299ad4d2450e7ad8bd55fcd805f897a8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 22 Nov 2024 16:42:34 +0100 Subject: [PATCH 12/14] fix nil in test Signed-off-by: Kristoffer Dalby --- hscontrol/db/users_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 6684989e..06073762 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -90,9 +90,10 @@ func (s *Suite) TestRenameUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") + want := "UNIQUE constraint failed" err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") - if !strings.Contains(err.Error(), "UNIQUE constraint failed") { - c.Fatalf("expected failure with unique constraint, got: %s", err.Error()) + if err == nil || !strings.Contains(err.Error(), want) { + c.Fatalf("expected failure with unique constraint, want: %q got: %q", want, err) } } From 4f57410a5b7fcd39266b446b8602e99295d57caa Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 22 Nov 2024 17:45:46 +0100 Subject: [PATCH 13/14] fix constraints Signed-off-by: Kristoffer Dalby --- hscontrol/db/db.go | 19 +++++++++++++++++++ hscontrol/db/db_test.go | 18 +++++++++--------- hscontrol/types/users.go | 4 ++-- 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 529dc696..28681213 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -498,6 +498,25 @@ func NewHeadscaleDatabase( return err } + // Set up indexes and unique constraints outside of GORM, it does not support + // conditional unique constraints. + // This ensures the following: + // - A user name and provider_identifier is unique + // - A provider_identifier is unique + // - A user name is unique if there is no provider_identifier is not set + for _, idx := range []string{ + "DROP INDEX IF EXISTS `idx_provider_identifier`", + "DROP INDEX IF EXISTS `idx_name_provider_identifier`", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;", + } { + err = tx.Exec(idx).Error + if err != nil { + return fmt.Errorf("creating username index: %w", err) + } + } + return nil }, Rollback: func(db *gorm.DB) error { return nil }, diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index a291ad7d..34115647 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -271,8 +271,8 @@ func TestConstraints(t *testing.T) { require.NoError(t, err) _, err = CreateUser(db, "user1") require.Error(t, err) - // assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username") - require.Contains(t, err.Error(), "user already exists") + assert.Contains(t, err.Error(), "UNIQUE constraint failed:") + // require.Contains(t, err.Error(), "user already exists") }, }, { @@ -295,7 +295,7 @@ func TestConstraints(t *testing.T) { err = db.Save(&user).Error require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + require.Contains(t, err.Error(), "UNIQUE constraint failed:") }, }, { @@ -318,7 +318,7 @@ func TestConstraints(t *testing.T) { err = db.Save(&user).Error require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + require.Contains(t, err.Error(), "UNIQUE constraint failed:") }, }, { @@ -328,9 +328,9 @@ func TestConstraints(t *testing.T) { require.NoError(t, err) user := types.User{ - Name: "user1", + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, } - user.ProviderIdentifier.String = "http://test.com/user1" err = db.Save(&user).Error require.NoError(t, err) @@ -340,9 +340,9 @@ func TestConstraints(t *testing.T) { name: "allow-duplicate-username-oidc-then-cli", run: func(t *testing.T, db *gorm.DB) { user := types.User{ - Name: "user1", + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, } - user.ProviderIdentifier.String = "http://test.com/user1" err := db.Save(&user).Error require.NoError(t, err) @@ -360,7 +360,7 @@ func TestConstraints(t *testing.T) { t.Fatalf("creating database: %s", err) } - tt.run(t, db.DB) + tt.run(t, db.DB.Debug()) }) } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f36be708..8194dea6 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -27,7 +27,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"` + Name string // Typically the full name of the user DisplayName string @@ -39,7 +39,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"` + ProviderIdentifier sql.NullString // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. From 8671a80dc8bfbc546e24d72bf8488d846cc82609 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 23 Nov 2024 11:19:52 +0100 Subject: [PATCH 14/14] fix postgres constraints, add postgres testing This commit fixes the constraint syntax so it is both valid for sqlite and postgres. To validate this, I've added a new postgres testing library and a helper that will spin up local postgres, setup a db and use it in the constraints tests. This should also help testing db stuff in the future. postgres has been added to the nix dev shell and is now required for running the unit tests. Signed-off-by: Kristoffer Dalby --- flake.nix | 3 +- go.mod | 2 ++ go.sum | 3 ++ hscontrol/db/db.go | 10 +++--- hscontrol/db/db_test.go | 29 ++++++++++++------ hscontrol/db/node_test.go | 6 ++-- hscontrol/db/suite_test.go | 63 ++++++++++++++++++++++++++++++++++++-- 7 files changed, 95 insertions(+), 21 deletions(-) diff --git a/flake.nix b/flake.nix index 8faae71e..90a2aad8 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-Qoqu2k4vvnbRFLmT/v8lI+HCEWqJsHFs8uZRfNmwQpo="; + vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg="; subPackages = ["cmd/headscale"]; @@ -102,6 +102,7 @@ ko yq-go ripgrep + postgresql # 'dot' is needed for pprof graphs # go tool pprof -http=: diff --git a/go.mod b/go.mod index 7eac4652..8d51fc6a 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( gorm.io/gorm v1.25.11 tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 zgo.at/zcache/v2 v2.1.0 + zombiezen.com/go/postgrestest v1.0.1 ) require ( @@ -134,6 +135,7 @@ require ( github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index cc15ef6c..9315dbb6 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= @@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg= zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs= zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= +zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= +zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 28681213..179bfcc3 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -505,11 +505,11 @@ func NewHeadscaleDatabase( // - A provider_identifier is unique // - A user name is unique if there is no provider_identifier is not set for _, idx := range []string{ - "DROP INDEX IF EXISTS `idx_provider_identifier`", - "DROP INDEX IF EXISTS `idx_name_provider_identifier`", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;", + "DROP INDEX IF EXISTS idx_provider_identifier", + "DROP INDEX IF EXISTS idx_name_provider_identifier", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;", } { err = tx.Exec(idx).Error if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 34115647..bafe1e1b 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "slices" "sort" + "strings" "testing" "time" @@ -259,6 +260,16 @@ func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } +// requireConstraintFailed checks if the error is a constraint failure with +// either SQLite and PostgreSQL error messages. +func requireConstraintFailed(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { + require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) + } +} + func TestConstraints(t *testing.T) { tests := []struct { name string @@ -270,9 +281,7 @@ func TestConstraints(t *testing.T) { _, err := CreateUser(db, "user1") require.NoError(t, err) _, err = CreateUser(db, "user1") - require.Error(t, err) - assert.Contains(t, err.Error(), "UNIQUE constraint failed:") - // require.Contains(t, err.Error(), "user already exists") + requireConstraintFailed(t, err) }, }, { @@ -294,8 +303,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -317,8 +325,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -354,8 +361,12 @@ func TestConstraints(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, err := newTestDB() + t.Run(tt.name+"-postgres", func(t *testing.T) { + db := newPostgresTestDB(t) + tt.run(t, db.DB.Debug()) + }) + t.Run(tt.name+"-sqlite", func(t *testing.T) { + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating database: %s", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 6c1d1099..bb29b00a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -558,7 +558,7 @@ func TestAutoApproveRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - adb, err := newTestDB() + adb, err := newSQLiteTestDB() require.NoError(t, err) pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) @@ -692,7 +692,7 @@ func generateRandomNumber(t *testing.T, max int64) int64 { } func TestListEphemeralNodes(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } @@ -748,7 +748,7 @@ func TestListEphemeralNodes(t *testing.T) { } func TestRenameNode(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 6cc46d3d..fb7ce1df 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -1,12 +1,17 @@ package db import ( + "context" "log" + "net/url" "os" + "strconv" + "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" + "zombiezen.com/go/postgrestest" ) func Test(t *testing.T) { @@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - db, err = newTestDB() + db, err = newSQLiteTestDB() if err != nil { c.Fatal(err) } } -func newTestDB() (*HSDatabase, error) { +// TODO(kradalby): make this a t.Helper when we dont depend +// on check test framework. +func newSQLiteTestDB() (*HSDatabase, error) { var err error tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") if err != nil { @@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) { db, err = NewHeadscaleDatabase( types.DatabaseConfig{ - Type: "sqlite3", + Type: types.DatabaseSqlite, Sqlite: types.SqliteConfig{ Path: tmpDir + "/headscale_test.db", }, @@ -67,3 +74,53 @@ func newTestDB() (*HSDatabase, error) { return db, nil } + +func newPostgresTestDB(t *testing.T) *HSDatabase { + t.Helper() + + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + t.Fatal(err) + } + + log.Printf("database path: %s", tmpDir+"/headscale_test.db") + + ctx := context.Background() + srv, err := postgrestest.Start(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(srv.Cleanup) + + u, err := srv.CreateDatabase(ctx) + if err != nil { + t.Fatal(err) + } + t.Logf("created local postgres: %s", u) + pu, _ := url.Parse(u) + + pass, _ := pu.User.Password() + port, _ := strconv.Atoi(pu.Port()) + + db, err = NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + "", + emptyCache(), + ) + if err != nil { + t.Fatal(err) + } + + return db +}