mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
fix constraints
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
5e7c3153b9
commit
281025bb16
5 changed files with 122 additions and 11 deletions
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -257,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) {
|
||||||
func emptyCache() *zcache.Cache[string, types.Node] {
|
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||||
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConstraints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
run func(*testing.T, *gorm.DB)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-duplicate-username-if-no-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = CreateUser(db, "user1")
|
||||||
|
require.Error(t, err)
|
||||||
|
// assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username")
|
||||||
|
require.Contains(t, err.Error(), "user already exists")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-username-and-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1.1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-cli-then-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-oidc-then-cli",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db, err := newTestDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.run(t, db.DB)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := types.User{}
|
user := types.User{
|
||||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
Name: name,
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
}
|
||||||
user.Name = name
|
|
||||||
if err := tx.Create(&user).Error; err != nil {
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
return nil, fmt.Errorf("creating user: %w", err)
|
return nil, fmt.Errorf("creating user: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -177,6 +175,10 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
if len(users) != 1 {
|
if len(users) != 1 {
|
||||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||||
}
|
}
|
||||||
|
|
|
@ -460,7 +460,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||||
// This is to prevent users that have already been migrated to the new OIDC format
|
// This is to prevent users that have already been migrated to the new OIDC format
|
||||||
// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
|
// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
|
||||||
// account takeover.
|
// account takeover.
|
||||||
if user != nil && user.ProviderIdentifier != "" {
|
if user != nil && user.ProviderIdentifier.Valid {
|
||||||
log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.")
|
log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.")
|
||||||
user = &types.User{}
|
user = &types.User{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"database/sql"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
@ -26,7 +27,7 @@ type User struct {
|
||||||
|
|
||||||
// Username for the user, is used if email is empty
|
// Username for the user, is used if email is empty
|
||||||
// Should not be used, please use Username().
|
// Should not be used, please use Username().
|
||||||
Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"`
|
Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"`
|
||||||
|
|
||||||
// Typically the full name of the user
|
// Typically the full name of the user
|
||||||
DisplayName string
|
DisplayName string
|
||||||
|
@ -38,7 +39,7 @@ type User struct {
|
||||||
// Unique identifier of the user from OIDC,
|
// Unique identifier of the user from OIDC,
|
||||||
// comes from `sub` claim in the OIDC token
|
// comes from `sub` claim in the OIDC token
|
||||||
// and is used to lookup the user.
|
// and is used to lookup the user.
|
||||||
ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"`
|
ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"`
|
||||||
|
|
||||||
// Provider is the origin of the user account,
|
// Provider is the origin of the user account,
|
||||||
// same as RegistrationMethod, without authkey.
|
// same as RegistrationMethod, without authkey.
|
||||||
|
@ -55,7 +56,7 @@ type User struct {
|
||||||
// should be used throughout headscale, in information returned to the
|
// should be used throughout headscale, in information returned to the
|
||||||
// user and the Policy engine.
|
// user and the Policy engine.
|
||||||
func (u *User) Username() string {
|
func (u *User) Username() string {
|
||||||
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
|
||||||
|
|
||||||
// TODO(kradalby): Wire up all of this for the future
|
// TODO(kradalby): Wire up all of this for the future
|
||||||
// if !strings.Contains(username, "@") {
|
// if !strings.Contains(username, "@") {
|
||||||
|
@ -118,7 +119,7 @@ func (u *User) Proto() *v1.User {
|
||||||
CreatedAt: timestamppb.New(u.CreatedAt),
|
CreatedAt: timestamppb.New(u.CreatedAt),
|
||||||
DisplayName: u.DisplayName,
|
DisplayName: u.DisplayName,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
ProviderId: u.ProviderIdentifier,
|
ProviderId: u.ProviderIdentifier.String,
|
||||||
Provider: u.Provider,
|
Provider: u.Provider,
|
||||||
ProfilePicUrl: u.ProfilePicURL,
|
ProfilePicUrl: u.ProfilePicURL,
|
||||||
}
|
}
|
||||||
|
@ -145,7 +146,7 @@ func (c *OIDCClaims) Identifier() string {
|
||||||
// FromClaim overrides a User from OIDC claims.
|
// FromClaim overrides a User from OIDC claims.
|
||||||
// All fields will be updated, except for the ID.
|
// All fields will be updated, except for the ID.
|
||||||
func (u *User) FromClaim(claims *OIDCClaims) {
|
func (u *User) FromClaim(claims *OIDCClaims) {
|
||||||
u.ProviderIdentifier = claims.Identifier()
|
u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
|
||||||
u.DisplayName = claims.Name
|
u.DisplayName = claims.Name
|
||||||
if claims.EmailVerified {
|
if claims.EmailVerified {
|
||||||
u.Email = claims.Email
|
u.Email = claims.Email
|
||||||
|
|
|
@ -54,7 +54,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
scenario := AuthOIDCScenario{
|
scenario := AuthOIDCScenario{
|
||||||
Scenario: baseScenario,
|
Scenario: baseScenario,
|
||||||
}
|
}
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
// defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
// Logins to MockOIDC is served by a queue with a strict order,
|
// Logins to MockOIDC is served by a queue with a strict order,
|
||||||
// if we use more than one node per user, the order of the logins
|
// if we use more than one node per user, the order of the logins
|
||||||
|
|
Loading…
Reference in a new issue