diff --git a/flake.nix b/flake.nix index 8faae71e..90a2aad8 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-Qoqu2k4vvnbRFLmT/v8lI+HCEWqJsHFs8uZRfNmwQpo="; + vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg="; subPackages = ["cmd/headscale"]; @@ -102,6 +102,7 @@ ko yq-go ripgrep + postgresql # 'dot' is needed for pprof graphs # go tool pprof -http=: diff --git a/go.mod b/go.mod index 7eac4652..8d51fc6a 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( gorm.io/gorm v1.25.11 tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 zgo.at/zcache/v2 v2.1.0 + zombiezen.com/go/postgrestest v1.0.1 ) require ( @@ -134,6 +135,7 @@ require ( github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index cc15ef6c..9315dbb6 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= @@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg= zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs= zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= +zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= +zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a4f06554..0d9120c2 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -505,11 +505,11 @@ func NewHeadscaleDatabase( // - A provider_identifier is unique // - A user name is unique if there is no provider_identifier is not set for _, idx := range []string{ - "DROP INDEX IF EXISTS `idx_provider_identifier`", - "DROP INDEX IF EXISTS `idx_name_provider_identifier`", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;", + "DROP INDEX IF EXISTS idx_provider_identifier", + "DROP INDEX IF EXISTS idx_name_provider_identifier", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;", } { err = tx.Exec(idx).Error if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 34115647..bafe1e1b 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "slices" "sort" + "strings" "testing" "time" @@ -259,6 +260,16 @@ func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } +// requireConstraintFailed checks if the error is a constraint failure with +// either SQLite and PostgreSQL error messages. +func requireConstraintFailed(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { + require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) + } +} + func TestConstraints(t *testing.T) { tests := []struct { name string @@ -270,9 +281,7 @@ func TestConstraints(t *testing.T) { _, err := CreateUser(db, "user1") require.NoError(t, err) _, err = CreateUser(db, "user1") - require.Error(t, err) - assert.Contains(t, err.Error(), "UNIQUE constraint failed:") - // require.Contains(t, err.Error(), "user already exists") + requireConstraintFailed(t, err) }, }, { @@ -294,8 +303,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -317,8 +325,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -354,8 +361,12 @@ func TestConstraints(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, err := newTestDB() + t.Run(tt.name+"-postgres", func(t *testing.T) { + db := newPostgresTestDB(t) + tt.run(t, db.DB.Debug()) + }) + t.Run(tt.name+"-sqlite", func(t *testing.T) { + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating database: %s", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 6c1d1099..bb29b00a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -558,7 +558,7 @@ func TestAutoApproveRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - adb, err := newTestDB() + adb, err := newSQLiteTestDB() require.NoError(t, err) pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) @@ -692,7 +692,7 @@ func generateRandomNumber(t *testing.T, max int64) int64 { } func TestListEphemeralNodes(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } @@ -748,7 +748,7 @@ func TestListEphemeralNodes(t *testing.T) { } func TestRenameNode(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 6cc46d3d..fb7ce1df 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -1,12 +1,17 @@ package db import ( + "context" "log" + "net/url" "os" + "strconv" + "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" + "zombiezen.com/go/postgrestest" ) func Test(t *testing.T) { @@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - db, err = newTestDB() + db, err = newSQLiteTestDB() if err != nil { c.Fatal(err) } } -func newTestDB() (*HSDatabase, error) { +// TODO(kradalby): make this a t.Helper when we dont depend +// on check test framework. +func newSQLiteTestDB() (*HSDatabase, error) { var err error tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") if err != nil { @@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) { db, err = NewHeadscaleDatabase( types.DatabaseConfig{ - Type: "sqlite3", + Type: types.DatabaseSqlite, Sqlite: types.SqliteConfig{ Path: tmpDir + "/headscale_test.db", }, @@ -67,3 +74,53 @@ func newTestDB() (*HSDatabase, error) { return db, nil } + +func newPostgresTestDB(t *testing.T) *HSDatabase { + t.Helper() + + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + t.Fatal(err) + } + + log.Printf("database path: %s", tmpDir+"/headscale_test.db") + + ctx := context.Background() + srv, err := postgrestest.Start(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(srv.Cleanup) + + u, err := srv.CreateDatabase(ctx) + if err != nil { + t.Fatal(err) + } + t.Logf("created local postgres: %s", u) + pu, _ := url.Parse(u) + + pass, _ := pu.User.Password() + port, _ := strconv.Atoi(pu.Port()) + + db, err = NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + "", + emptyCache(), + ) + if err != nil { + t.Fatal(err) + } + + return db +}