From f6276ab9d2856d49d15909c3c0332015d4252af7 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 23 Nov 2024 11:19:52 +0100 Subject: [PATCH] fix postgres constraints, add postgres testing This commit fixes the constraint syntax so it is both valid for sqlite and postgres. To validate this, I've added a new postgres testing library and a helper that will spin up local postgres, setup a db and use it in the constraints tests. This should also help testing db stuff in the future. postgres has been added to the nix dev shell and is now required for running the unit tests. Signed-off-by: Kristoffer Dalby --- flake.nix | 3 +- go.mod | 2 ++ go.sum | 3 ++ hscontrol/db/db.go | 10 +++--- hscontrol/db/db_test.go | 29 ++++++++++++------ hscontrol/db/node_test.go | 6 ++-- hscontrol/db/suite_test.go | 63 ++++++++++++++++++++++++++++++++++++-- 7 files changed, 95 insertions(+), 21 deletions(-) diff --git a/flake.nix b/flake.nix index 8faae71e..90a2aad8 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-Qoqu2k4vvnbRFLmT/v8lI+HCEWqJsHFs8uZRfNmwQpo="; + vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg="; subPackages = ["cmd/headscale"]; @@ -102,6 +102,7 @@ ko yq-go ripgrep + postgresql # 'dot' is needed for pprof graphs # go tool pprof -http=: diff --git a/go.mod b/go.mod index 7eac4652..8d51fc6a 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( gorm.io/gorm v1.25.11 tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 zgo.at/zcache/v2 v2.1.0 + zombiezen.com/go/postgrestest v1.0.1 ) require ( @@ -134,6 +135,7 @@ require ( github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index cc15ef6c..9315dbb6 100644 --- a/go.sum +++ b/go.sum @@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= @@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg= zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs= zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= +zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= +zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a4f06554..0d9120c2 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -505,11 +505,11 @@ func NewHeadscaleDatabase( // - A provider_identifier is unique // - A user name is unique if there is no provider_identifier is not set for _, idx := range []string{ - "DROP INDEX IF EXISTS `idx_provider_identifier`", - "DROP INDEX IF EXISTS `idx_name_provider_identifier`", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_provider_identifier` ON `users` (`provider_identifier`) WHERE provider_identifier IS NOT NULL;", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_provider_identifier` ON `users` (`name`,`provider_identifier`);", - "CREATE UNIQUE INDEX IF NOT EXISTS `idx_name_no_provider_identifier` ON `users` (`name`) WHERE provider_identifier IS NULL;", + "DROP INDEX IF EXISTS idx_provider_identifier", + "DROP INDEX IF EXISTS idx_name_provider_identifier", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;", } { err = tx.Exec(idx).Error if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 34115647..bafe1e1b 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "slices" "sort" + "strings" "testing" "time" @@ -259,6 +260,16 @@ func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } +// requireConstraintFailed checks if the error is a constraint failure with +// either SQLite and PostgreSQL error messages. +func requireConstraintFailed(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { + require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) + } +} + func TestConstraints(t *testing.T) { tests := []struct { name string @@ -270,9 +281,7 @@ func TestConstraints(t *testing.T) { _, err := CreateUser(db, "user1") require.NoError(t, err) _, err = CreateUser(db, "user1") - require.Error(t, err) - assert.Contains(t, err.Error(), "UNIQUE constraint failed:") - // require.Contains(t, err.Error(), "user already exists") + requireConstraintFailed(t, err) }, }, { @@ -294,8 +303,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -317,8 +325,7 @@ func TestConstraints(t *testing.T) { user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} err = db.Save(&user).Error - require.Error(t, err) - require.Contains(t, err.Error(), "UNIQUE constraint failed:") + requireConstraintFailed(t, err) }, }, { @@ -354,8 +361,12 @@ func TestConstraints(t *testing.T) { } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db, err := newTestDB() + t.Run(tt.name+"-postgres", func(t *testing.T) { + db := newPostgresTestDB(t) + tt.run(t, db.DB.Debug()) + }) + t.Run(tt.name+"-sqlite", func(t *testing.T) { + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating database: %s", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 6c1d1099..bb29b00a 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -558,7 +558,7 @@ func TestAutoApproveRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - adb, err := newTestDB() + adb, err := newSQLiteTestDB() require.NoError(t, err) pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) @@ -692,7 +692,7 @@ func generateRandomNumber(t *testing.T, max int64) int64 { } func TestListEphemeralNodes(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } @@ -748,7 +748,7 @@ func TestListEphemeralNodes(t *testing.T) { } func TestRenameNode(t *testing.T) { - db, err := newTestDB() + db, err := newSQLiteTestDB() if err != nil { t.Fatalf("creating db: %s", err) } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 6cc46d3d..fb7ce1df 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -1,12 +1,17 @@ package db import ( + "context" "log" + "net/url" "os" + "strconv" + "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" "gopkg.in/check.v1" + "zombiezen.com/go/postgrestest" ) func Test(t *testing.T) { @@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) { // } var err error - db, err = newTestDB() + db, err = newSQLiteTestDB() if err != nil { c.Fatal(err) } } -func newTestDB() (*HSDatabase, error) { +// TODO(kradalby): make this a t.Helper when we dont depend +// on check test framework. +func newSQLiteTestDB() (*HSDatabase, error) { var err error tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") if err != nil { @@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) { db, err = NewHeadscaleDatabase( types.DatabaseConfig{ - Type: "sqlite3", + Type: types.DatabaseSqlite, Sqlite: types.SqliteConfig{ Path: tmpDir + "/headscale_test.db", }, @@ -67,3 +74,53 @@ func newTestDB() (*HSDatabase, error) { return db, nil } + +func newPostgresTestDB(t *testing.T) *HSDatabase { + t.Helper() + + var err error + tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") + if err != nil { + t.Fatal(err) + } + + log.Printf("database path: %s", tmpDir+"/headscale_test.db") + + ctx := context.Background() + srv, err := postgrestest.Start(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(srv.Cleanup) + + u, err := srv.CreateDatabase(ctx) + if err != nil { + t.Fatal(err) + } + t.Logf("created local postgres: %s", u) + pu, _ := url.Parse(u) + + pass, _ := pu.User.Password() + port, _ := strconv.Atoi(pu.Port()) + + db, err = NewHeadscaleDatabase( + types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + "", + emptyCache(), + ) + if err != nil { + t.Fatal(err) + } + + return db +}