diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 529dc696..28681213 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -498,6 +498,25 @@ func NewHeadscaleDatabase( return err } + // Set up indexes and unique constraints outside of GORM, it does not support + // conditional unique constraints. + // This ensures the following: + // - A user name and provider_identifier is unique + // - A provider_identifier is unique + // - A user name is unique if there is no provider_identifier is not set + for _, idx := range []string{ + "DROP INDEX IF EXISTS `idx_provider_identifier`", + "DROP INDEX IF EXISTS `idx_name_provider_identifier`", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);", + "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;", + } { + err = tx.Exec(idx).Error + if err != nil { + return fmt.Errorf("creating username index: %w", err) + } + } + return nil }, Rollback: func(db *gorm.DB) error { return nil }, diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index a291ad7d..34115647 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -271,8 +271,8 @@ func TestConstraints(t *testing.T) { require.NoError(t, err) _, err = CreateUser(db, "user1") require.Error(t, err) - // assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username") - require.Contains(t, err.Error(), "user already exists") + assert.Contains(t, err.Error(), "UNIQUE constraint failed:") + // require.Contains(t, err.Error(), "user already exists") }, }, { @@ -295,7 +295,7 @@ func TestConstraints(t *testing.T) { err = db.Save(&user).Error require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + require.Contains(t, err.Error(), "UNIQUE constraint failed:") }, }, { @@ -318,7 +318,7 @@ func TestConstraints(t *testing.T) { err = db.Save(&user).Error require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + require.Contains(t, err.Error(), "UNIQUE constraint failed:") }, }, { @@ -328,9 +328,9 @@ func TestConstraints(t *testing.T) { require.NoError(t, err) user := types.User{ - Name: "user1", + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, } - user.ProviderIdentifier.String = "http://test.com/user1" err = db.Save(&user).Error require.NoError(t, err) @@ -340,9 +340,9 @@ func TestConstraints(t *testing.T) { name: "allow-duplicate-username-oidc-then-cli", run: func(t *testing.T, db *gorm.DB) { user := types.User{ - Name: "user1", + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, } - user.ProviderIdentifier.String = "http://test.com/user1" err := db.Save(&user).Error require.NoError(t, err) @@ -360,7 +360,7 @@ func TestConstraints(t *testing.T) { t.Fatalf("creating database: %s", err) } - tt.run(t, db.DB) + tt.run(t, db.DB.Debug()) }) } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f36be708..8194dea6 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -27,7 +27,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"` + Name string // Typically the full name of the user DisplayName string @@ -39,7 +39,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"` + ProviderIdentifier sql.NullString // Provider is the origin of the user account, // same as RegistrationMethod, without authkey.