fix route table migration wiping routes 0.22 -> 0.23 (#2076)

This commit is contained in:
Kristoffer Dalby 2024-08-27 18:54:28 +02:00 committed by GitHub
parent 827e3e83ae
commit cf6a606d74
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 204 additions and 19 deletions

View file

@ -34,4 +34,4 @@ jobs:
- name: Run tests - name: Run tests
if: steps.changed-files.outputs.files == 'true' if: steps.changed-files.outputs.files == 'true'
run: nix develop --check run: nix develop --command -- gotestsum

View file

@ -51,8 +51,8 @@ func NewHeadscaleDatabase(
dbConn, dbConn,
gormigrate.DefaultOptions, gormigrate.DefaultOptions,
[]*gormigrate.Migration{ []*gormigrate.Migration{
// New migrations should be added as transactions at the end of this list. // New migrations must be added as transactions at the end of this list.
// The initial commit here is quite messy, completely out of order and // The initial migration here is quite messy, completely out of order and
// has no versioning and is the tech debt of not having versioned migrations // has no versioning and is the tech debt of not having versioned migrations
// prior to this point. This first migration is all DB changes to bring a DB // prior to this point. This first migration is all DB changes to bring a DB
// up to 0.23.0. // up to 0.23.0.
@ -123,10 +123,22 @@ func NewHeadscaleDatabase(
} }
} }
// Only run automigrate Route table if it does not exist. It has only been
// changed ones, when machines where renamed to nodes, which is covered
// further up. This whole initial integration is a mess and if AutoMigrate
// is ran on a 0.22 to 0.23 update, it will wipe all the routes.
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
if err != nil {
return err
}
}
if !tx.Migrator().HasTable(&types.Route{}) {
err = tx.AutoMigrate(&types.Route{}) err = tx.AutoMigrate(&types.Route{})
if err != nil { if err != nil {
return err return err
} }
}
err = tx.AutoMigrate(&types.Node{}) err = tx.AutoMigrate(&types.Node{})
if err != nil { if err != nil {

168
hscontrol/db/db_test.go Normal file
View file

@ -0,0 +1,168 @@
package db
import (
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func TestMigrations(t *testing.T) {
ipp := func(p string) types.IPPrefix {
return types.IPPrefix(netip.MustParsePrefix(p))
}
r := func(id uint64, p string, a, e, i bool) types.Route {
return types.Route{
NodeID: id,
Prefix: ipp(p),
Advertised: a,
Enabled: e,
IsPrimary: i,
}
}
tests := []struct {
dbPath string
wantFunc func(*testing.T, *HSDatabase)
wantErr string
}{
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-are-dropped-2063.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
})
assert.NoError(t, err)
assert.Len(t, routes, 10)
want := types.Routes{
r(1, "0.0.0.0/0", true, true, false),
r(1, "::/0", true, true, false),
r(1, "10.9.110.0/24", true, true, true),
r(26, "172.100.100.0/24", true, true, true),
r(26, "172.100.100.0/24", true, false, false),
r(31, "0.0.0.0/0", true, true, false),
r(31, "0.0.0.0/0", true, false, false),
r(31, "::/0", true, true, false),
r(31, "::/0", true, false, false),
r(32, "192.168.0.24/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
},
{
dbPath: "testdata/0-22-3-to-0-23-0-routes-fail-foreign-key-2076.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) {
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
return GetRoutes(rx)
})
assert.NoError(t, err)
assert.Len(t, routes, 4)
want := types.Routes{
// These routes exists, but have no nodes associated with them
// when the migration starts.
// r(1, "0.0.0.0/0", true, true, false),
// r(1, "::/0", true, true, false),
// r(3, "0.0.0.0/0", true, true, false),
// r(3, "::/0", true, true, false),
// r(5, "0.0.0.0/0", true, true, false),
// r(5, "::/0", true, true, false),
// r(6, "0.0.0.0/0", true, true, false),
// r(6, "::/0", true, true, false),
// r(6, "10.0.0.0/8", true, false, false),
// r(7, "0.0.0.0/0", true, true, false),
// r(7, "::/0", true, true, false),
// r(7, "10.0.0.0/8", true, false, false),
// r(9, "0.0.0.0/0", true, true, false),
// r(9, "::/0", true, true, false),
// r(9, "10.0.0.0/8", true, true, false),
// r(11, "0.0.0.0/0", true, true, false),
// r(11, "::/0", true, true, false),
// r(11, "10.0.0.0/8", true, true, true),
// r(12, "0.0.0.0/0", true, true, false),
// r(12, "::/0", true, true, false),
// r(12, "10.0.0.0/8", true, false, false),
//
// These nodes exists, so routes should be kept.
r(13, "10.0.0.0/8", true, false, false),
r(13, "0.0.0.0/0", true, true, false),
r(13, "::/0", true, true, false),
r(13, "10.18.80.2/32", true, true, true),
}
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
return x == y
})); diff != "" {
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
}
},
},
}
for _, tt := range tests {
t.Run(tt.dbPath, func(t *testing.T) {
dbPath, err := testCopyOfDatabase(tt.dbPath)
if err != nil {
t.Fatalf("copying db for test: %s", err)
}
hsdb, err := NewHeadscaleDatabase(types.DatabaseConfig{
Type: "sqlite3",
Sqlite: types.SqliteConfig{
Path: dbPath,
},
}, "")
if err != nil && tt.wantErr != err.Error() {
t.Errorf("TestMigrations() unexpected error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantFunc != nil {
tt.wantFunc(t, hsdb)
}
})
}
}
func testCopyOfDatabase(src string) (string, error) {
sourceFileStat, err := os.Stat(src)
if err != nil {
return "", err
}
if !sourceFileStat.Mode().IsRegular() {
return "", fmt.Errorf("%s is not a regular file", src)
}
source, err := os.Open(src)
if err != nil {
return "", err
}
defer source.Close()
tmpDir, err := os.MkdirTemp("", "hsdb-test-*")
if err != nil {
return "", err
}
fn := filepath.Base(src)
dst := filepath.Join(tmpDir, fn)
destination, err := os.Create(dst)
if err != nil {
return "", err
}
defer destination.Close()
_, err = io.Copy(destination, source)
return dst, err
}

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"sort" "sort"
"sync"
"time" "time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
@ -12,7 +13,6 @@ import (
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -724,7 +724,7 @@ func ExpireExpiredNodes(tx *gorm.DB,
// It is used to delete ephemeral nodes that have disconnected and should be // It is used to delete ephemeral nodes that have disconnected and should be
// cleaned up. // cleaned up.
type EphemeralGarbageCollector struct { type EphemeralGarbageCollector struct {
mu deadlock.Mutex mu sync.Mutex
deleteFunc func(types.NodeID) deleteFunc func(types.NodeID)
toBeDeleted map[types.NodeID]*time.Timer toBeDeleted map[types.NodeID]*time.Timer
@ -752,10 +752,9 @@ func (e *EphemeralGarbageCollector) Close() {
// Schedule schedules a node for deletion after the expiry duration. // Schedule schedules a node for deletion after the expiry duration.
func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock()
timer := time.NewTimer(expiry) timer := time.NewTimer(expiry)
e.toBeDeleted[nodeID] = timer e.toBeDeleted[nodeID] = timer
e.mu.Unlock()
go func() { go func() {
select { select {

View file

@ -609,12 +609,14 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
}) })
go e.Start() go e.Start()
e.Schedule(1, 1*time.Second) go e.Schedule(1, 1*time.Second)
e.Schedule(2, 2*time.Second) go e.Schedule(2, 2*time.Second)
e.Schedule(3, 3*time.Second) go e.Schedule(3, 3*time.Second)
e.Schedule(4, 4*time.Second) go e.Schedule(4, 4*time.Second)
e.Cancel(2)
e.Cancel(4) time.Sleep(time.Second)
go e.Cancel(2)
go e.Cancel(4)
time.Sleep(6 * time.Second) time.Sleep(6 * time.Second)

View file

@ -4,7 +4,9 @@ import (
"net/netip" "net/netip"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"tailscale.com/types/ipproto"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/views"
) )
var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool { var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool {
@ -31,6 +33,8 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool {
return x.String() == y.String() return x.String() == y.String()
}) })
var ViewSliceIPProtoComparer = cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) })
var Comparers []cmp.Option = []cmp.Option{ var Comparers []cmp.Option = []cmp.Option{
IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, ViewSliceIPProtoComparer,
} }

View file

@ -1170,7 +1170,7 @@ func TestSubnetRouteACL(t *testing.T) {
}, },
} }
if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.PrefixComparer); diff != "" { if diff := cmp.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff) t.Errorf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)
} }
@ -1220,7 +1220,7 @@ func TestSubnetRouteACL(t *testing.T) {
}, },
} }
if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.PrefixComparer); diff != "" { if diff := cmp.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" {
t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff) t.Errorf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)
} }
} }