diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b03fc434..f4659332 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,4 +34,4 @@ jobs: - name: Run tests if: steps.changed-files.outputs.files == 'true' - run: nix develop --check + run: nix develop --command -- gotestsum diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 331dba54..3aaa7eeb 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -51,8 +51,8 @@ func NewHeadscaleDatabase( dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{ - // New migrations should be added as transactions at the end of this list. - // The initial commit here is quite messy, completely out of order and + // New migrations must be added as transactions at the end of this list. + // The initial migration here is quite messy, completely out of order and // has no versioning and is the tech debt of not having versioned migrations // prior to this point. This first migration is all DB changes to bring a DB // up to 0.23.0. @@ -123,9 +123,21 @@ func NewHeadscaleDatabase( } } - err = tx.AutoMigrate(&types.Route{}) - if err != nil { - return err + // Only run automigrate Route table if it does not exist. It has only been + // changed ones, when machines where renamed to nodes, which is covered + // further up. This whole initial integration is a mess and if AutoMigrate + // is ran on a 0.22 to 0.23 update, it will wipe all the routes. + if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { + err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error + if err != nil { + return err + } + } + if !tx.Migrator().HasTable(&types.Route{}) { + err = tx.AutoMigrate(&types.Route{}) + if err != nil { + return err + } } err = tx.AutoMigrate(&types.Node{}) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go new file mode 100644 index 00000000..b32d93ce --- /dev/null +++ b/hscontrol/db/db_test.go @@ -0,0 +1,168 @@ +package db + +import ( + "fmt" + "io" + "net/netip" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" +) + +func TestMigrations(t *testing.T) { + ipp := func(p string) types.IPPrefix { + return types.IPPrefix(netip.MustParsePrefix(p)) + } + r := func(id uint64, p string, a, e, i bool) types.Route { + return types.Route{ + NodeID: id, + Prefix: ipp(p), + Advertised: a, + Enabled: e, + IsPrimary: i, + } + } + tests := []struct { + dbPath string + wantFunc func(*testing.T, *HSDatabase) + wantErr string + }{ + { + dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite", + wantFunc: func(t *testing.T, h *HSDatabase) { + routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetRoutes(rx) + }) + assert.NoError(t, err) + + assert.Len(t, routes, 10) + want := types.Routes{ + r(1, "0.0.0.0/0", true, true, false), + r(1, "::/0", true, true, false), + r(1, "10.9.110.0/24", true, true, true), + r(26, "172.100.100.0/24", true, true, true), + r(26, "172.100.100.0/24", true, false, false), + r(31, "0.0.0.0/0", true, true, false), + r(31, "0.0.0.0/0", true, false, false), + r(31, "::/0", true, true, false), + r(31, "::/0", true, false, false), + r(32, "192.168.0.24/32", true, true, true), + } + if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool { + return x == y + })); diff != "" { + t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) + } + }, + }, + { + dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite", + wantFunc: func(t *testing.T, h *HSDatabase) { + routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) { + return GetRoutes(rx) + }) + assert.NoError(t, err) + + assert.Len(t, routes, 4) + want := types.Routes{ + // These routes exists, but have no nodes associated with them + // when the migration starts. + // r(1, "0.0.0.0/0", true, true, false), + // r(1, "::/0", true, true, false), + // r(3, "0.0.0.0/0", true, true, false), + // r(3, "::/0", true, true, false), + // r(5, "0.0.0.0/0", true, true, false), + // r(5, "::/0", true, true, false), + // r(6, "0.0.0.0/0", true, true, false), + // r(6, "::/0", true, true, false), + // r(6, "10.0.0.0/8", true, false, false), + // r(7, "0.0.0.0/0", true, true, false), + // r(7, "::/0", true, true, false), + // r(7, "10.0.0.0/8", true, false, false), + // r(9, "0.0.0.0/0", true, true, false), + // r(9, "::/0", true, true, false), + // r(9, "10.0.0.0/8", true, true, false), + // r(11, "0.0.0.0/0", true, true, false), + // r(11, "::/0", true, true, false), + // r(11, "10.0.0.0/8", true, true, true), + // r(12, "0.0.0.0/0", true, true, false), + // r(12, "::/0", true, true, false), + // r(12, "10.0.0.0/8", true, false, false), + // + // These nodes exists, so routes should be kept. + r(13, "10.0.0.0/8", true, false, false), + r(13, "0.0.0.0/0", true, true, false), + r(13, "::/0", true, true, false), + r(13, "10.18.80.2/32", true, true, true), + } + if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool { + return x == y + })); diff != "" { + t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.dbPath, func(t *testing.T) { + dbPath, err := testCopyOfDatabase(tt.dbPath) + if err != nil { + t.Fatalf("copying db for test: %s", err) + } + + hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, "") + if err != nil && tt.wantErr != err.Error() { + t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantFunc != nil { + tt.wantFunc(t, hsdb) + } + }) + } +} + +func testCopyOfDatabase(src string) (string, error) { + sourceFileStat, err := os.Stat(src) + if err != nil { + return "", err + } + + if !sourceFileStat.Mode().IsRegular() { + return "", fmt.Errorf("%s is not a regular file", src) + } + + source, err := os.Open(src) + if err != nil { + return "", err + } + defer source.Close() + + tmpDir, err := os.MkdirTemp("", "hsdb-test-*") + if err != nil { + return "", err + } + + fn := filepath.Base(src) + dst := filepath.Join(tmpDir, fn) + + destination, err := os.Create(dst) + if err != nil { + return "", err + } + defer destination.Close() + _, err = io.Copy(destination, source) + return dst, err +} diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index a2515ebf..a9e78a45 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "sort" + "sync" "time" "github.com/juanfont/headscale/hscontrol/types" @@ -12,7 +13,6 @@ import ( "github.com/patrickmn/go-cache" "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" - "github.com/sasha-s/go-deadlock" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -724,7 +724,7 @@ func ExpireExpiredNodes(tx *gorm.DB, // It is used to delete ephemeral nodes that have disconnected and should be // cleaned up. type EphemeralGarbageCollector struct { - mu deadlock.Mutex + mu sync.Mutex deleteFunc func(types.NodeID) toBeDeleted map[types.NodeID]*time.Timer @@ -752,10 +752,9 @@ func (e *EphemeralGarbageCollector) Close() { // Schedule schedules a node for deletion after the expiry duration. func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { e.mu.Lock() - defer e.mu.Unlock() - timer := time.NewTimer(expiry) e.toBeDeleted[nodeID] = timer + e.mu.Unlock() go func() { select { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index ad94f064..c83da120 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -609,12 +609,14 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) { }) go e.Start() - e.Schedule(1, 1*time.Second) - e.Schedule(2, 2*time.Second) - e.Schedule(3, 3*time.Second) - e.Schedule(4, 4*time.Second) - e.Cancel(2) - e.Cancel(4) + go e.Schedule(1, 1*time.Second) + go e.Schedule(2, 2*time.Second) + go e.Schedule(3, 3*time.Second) + go e.Schedule(4, 4*time.Second) + + time.Sleep(time.Second) + go e.Cancel(2) + go e.Cancel(4) time.Sleep(6 * time.Second) diff --git a/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite b/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite new file mode 100644 index 00000000..10e1aaec Binary files /dev/null and b/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite differ diff --git a/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite b/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite new file mode 100644 index 00000000..dbe96962 Binary files /dev/null and b/hscontrol/db/testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite differ diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go index 0a23acb4..d93ae1f2 100644 --- a/hscontrol/util/test.go +++ b/hscontrol/util/test.go @@ -4,7 +4,9 @@ import ( "net/netip" "github.com/google/go-cmp/cmp" + "tailscale.com/types/ipproto" "tailscale.com/types/key" + "tailscale.com/types/views" ) var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool { @@ -31,6 +33,8 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { return x.String() == y.String() }) +var ViewSliceIPProtoComparer = cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) }) + var Comparers []cmp.Option = []cmp.Option{ - IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, + IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, ViewSliceIPProtoComparer, } diff --git a/integration/route_test.go b/integration/route_test.go index ed371642..a92258af 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1170,7 +1170,7 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" { + if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) } @@ -1220,7 +1220,7 @@ func TestSubnetRouteACL(t *testing.T) { }, } - if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" { + if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) } }