Restore foreign keys and add constraints (#1562)

* fix #1482, restore foregin keys, add constraints

* #1562, fix tests, fix formatting

* #1562: fix tests

* #1562: fix local run of test_integration
This commit is contained in:
MichaelKo 2024-05-16 02:40:14 +02:00 committed by GitHub
parent 2bac80cfbf
commit 7fd2485000
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 149 additions and 61 deletions

View file

@ -57,6 +57,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869) - Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869)
- Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877) - Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877)
- Add `autogroup:internet` to Policy [#1917](https://github.com/juanfont/headscale/pull/1917) - Add `autogroup:internet` to Policy [#1917](https://github.com/juanfont/headscale/pull/1917)
- Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

View file

@ -31,6 +31,7 @@ test_integration:
--name headscale-test-suite \ --name headscale-test-suite \
-v $$PWD:$$PWD -w $$PWD/integration \ -v $$PWD:$$PWD -w $$PWD/integration \
-v /var/run/docker.sock:/var/run/docker.sock \ -v /var/run/docker.sock:/var/run/docker.sock \
-v $$PWD/control_logs:/tmp/control \
golang:1 \ golang:1 \
go run gotest.tools/gotestsum@latest -- -failfast ./... -timeout 120m -parallel 8 go run gotest.tools/gotestsum@latest -- -failfast ./... -timeout 120m -parallel 8

View file

@ -314,7 +314,11 @@ func (h *Headscale) handleAuthKey(
Msg("node was already registered before, refreshing with new auth key") Msg("node was already registered before, refreshing with new auth key")
node.NodeKey = nodeKey node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID) pakID := uint(pak.ID)
if pakID != 0 {
node.AuthKeyID = &pakID
}
node.Expiry = &registerRequest.Expiry node.Expiry = &registerRequest.Expiry
node.User = pak.User node.User = pak.User
node.UserID = pak.UserID node.UserID = pak.UserID
@ -373,7 +377,6 @@ func (h *Headscale) handleAuthKey(
Expiry: &registerRequest.Expiry, Expiry: &registerRequest.Expiry,
NodeKey: nodeKey, NodeKey: nodeKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.Proto().GetAclTags(), ForcedTags: pak.Proto().GetAclTags(),
} }
@ -389,6 +392,10 @@ func (h *Headscale) handleAuthKey(
return return
} }
pakID := uint(pak.ID)
if pakID != 0 {
nodeToRegister.AuthKeyID = &pakID
}
node, err = h.db.RegisterNode( node, err = h.db.RegisterNode(
nodeToRegister, nodeToRegister,
ipv4, ipv6, ipv4, ipv6,

View file

@ -91,7 +91,8 @@ func NewHeadscaleDatabase(
_ = tx.Migrator(). _ = tx.Migrator().
RenameColumn(&types.Node{}, "nickname", "given_name") RenameColumn(&types.Node{}, "nickname", "given_name")
// If the Node table has a column for registered, dbConn.Model(&types.Node{}).Where("auth_key_id = ?", 0).Update("auth_key_id", nil)
// If the Node table has a column for registered,
// find all occourences of "false" and drop them. Then // find all occourences of "false" and drop them. Then
// remove the column. // remove the column.
if tx.Migrator().HasColumn(&types.Node{}, "registered") { if tx.Migrator().HasColumn(&types.Node{}, "registered") {
@ -441,8 +442,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
db, err := gorm.Open( db, err := gorm.Open(
sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"), sqlite.Open(cfg.Sqlite.Path+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{ &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger,
Logger: dbLogger,
}, },
) )
@ -488,8 +488,7 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
} }
db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{ db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true, Logger: dbLogger,
Logger: dbLogger,
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -87,8 +87,11 @@ func TestIPAllocatorSequential(t *testing.T) {
name: "simple-with-db", name: "simple-with-db",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-with-db") db := dbForTest(t, "simple-with-db")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"), IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"), IPv6: nap("fd7a:115c:a1e0::1"),
}) })
@ -112,8 +115,11 @@ func TestIPAllocatorSequential(t *testing.T) {
name: "before-after-free-middle-in-db", name: "before-after-free-middle-in-db",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "before-after-free-middle-in-db") db := dbForTest(t, "before-after-free-middle-in-db")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.2"), IPv4: nap("100.64.0.2"),
IPv6: nap("fd7a:115c:a1e0::2"), IPv6: nap("fd7a:115c:a1e0::2"),
}) })
@ -307,8 +313,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-ipv6", name: "simple-backfill-ipv6",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6") db := dbForTest(t, "simple-backfill-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"), IPv4: nap("100.64.0.1"),
}) })
@ -337,8 +346,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-ipv4", name: "simple-backfill-ipv4",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv4") db := dbForTest(t, "simple-backfill-ipv4")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv6: nap("fd7a:115c:a1e0::1"), IPv6: nap("fd7a:115c:a1e0::1"),
}) })
@ -367,8 +379,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-remove-ipv6", name: "simple-backfill-remove-ipv6",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv6") db := dbForTest(t, "simple-backfill-remove-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"), IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"), IPv6: nap("fd7a:115c:a1e0::1"),
}) })
@ -392,8 +407,11 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "simple-backfill-remove-ipv4", name: "simple-backfill-remove-ipv4",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-remove-ipv4") db := dbForTest(t, "simple-backfill-remove-ipv4")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"), IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"), IPv6: nap("fd7a:115c:a1e0::1"),
}) })
@ -417,17 +435,23 @@ func TestBackfillIPAddresses(t *testing.T) {
name: "multi-backfill-ipv6", name: "multi-backfill-ipv6",
dbFunc: func() *HSDatabase { dbFunc: func() *HSDatabase {
db := dbForTest(t, "simple-backfill-ipv6") db := dbForTest(t, "simple-backfill-ipv6")
user := types.User{Name: ""}
db.DB.Save(&user)
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.1"), IPv4: nap("100.64.0.1"),
}) })
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.2"), IPv4: nap("100.64.0.2"),
}) })
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.3"), IPv4: nap("100.64.0.3"),
}) })
db.DB.Save(&types.Node{ db.DB.Save(&types.Node{
User: user,
IPv4: nap("100.64.0.4"), IPv4: nap("100.64.0.4"),
}) })
@ -451,6 +475,8 @@ func TestBackfillIPAddresses(t *testing.T) {
"MachineKeyDatabaseField", "MachineKeyDatabaseField",
"NodeKeyDatabaseField", "NodeKeyDatabaseField",
"DiscoKeyDatabaseField", "DiscoKeyDatabaseField",
"User",
"UserID",
"Endpoints", "Endpoints",
"HostinfoDatabaseField", "HostinfoDatabaseField",
"Hostinfo", "Hostinfo",

View file

@ -279,7 +279,7 @@ func DeleteNode(tx *gorm.DB,
} }
// Unscoped causes the node to be fully removed from the database. // Unscoped causes the node to be fully removed from the database.
if err := tx.Unscoped().Delete(&node).Error; err != nil { if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil {
return changed, err return changed, err
} }

View file

@ -29,6 +29,7 @@ func (s *Suite) TestGetNode(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
@ -37,9 +38,10 @@ func (s *Suite) TestGetNode(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(node) trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
_, err = db.getNode("test", "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -58,6 +60,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -65,9 +68,10 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -88,6 +92,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -95,9 +100,10 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -117,9 +123,9 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
Hostname: "testnode3", Hostname: "testnode3",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(1),
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -138,6 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) {
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
pakID := uint(pak.ID)
for index := 0; index <= 10; index++ { for index := 0; index <= 10; index++ {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
@ -149,9 +156,10 @@ func (s *Suite) TestListPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index), Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
} }
node0ByID, err := db.GetNodeByID(0) node0ByID, err := db.GetNodeByID(0)
@ -188,6 +196,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for index := 0; index <= 10; index++ { for index := 0; index <= 10; index++ {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(stor[index%2].key.ID)
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))) v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
node := types.Node{ node := types.Node{
@ -198,9 +207,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hostname: "testnode" + strconv.Itoa(index), Hostname: "testnode" + strconv.Itoa(index),
UserID: stor[index%2].user.ID, UserID: stor[index%2].user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(stor[index%2].key.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
} }
aclPolicy := &policy.ACLPolicy{ aclPolicy := &policy.ACLPolicy{
@ -272,6 +282,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
@ -280,7 +291,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Expiry: &time.Time{}, Expiry: &time.Time{},
} }
db.DB.Save(node) db.DB.Save(node)
@ -316,6 +327,7 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
machineKey2 := key.NewMachine() machineKey2 := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -324,9 +336,11 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
GivenName: "hostname-1", GivenName: "hostname-1",
UserID: user1.ID, UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(node)
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
@ -357,6 +371,7 @@ func (s *Suite) TestSetTags(c *check.C) {
nodeKey := key.NewNode() nodeKey := key.NewNode()
machineKey := key.NewMachine() machineKey := key.NewMachine()
pakID := uint(pak.ID)
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -364,9 +379,11 @@ func (s *Suite) TestSetTags(c *check.C) {
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(node)
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
// assign simple tags // assign simple tags
sTags := []string{"tag:test", "tag:foo"} sTags := []string{"tag:test", "tag:foo"}
@ -548,6 +565,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
route2 := netip.MustParsePrefix("10.11.0.0/24") route2 := netip.MustParsePrefix("10.11.0.0/24")
v4 := netip.MustParseAddr("100.64.0.1") v4 := netip.MustParseAddr("100.64.0.1")
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
MachineKey: machineKey.Public(), MachineKey: machineKey.Public(),
@ -555,7 +573,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
Hostname: "test", Hostname: "test",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"}, RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
@ -563,7 +581,8 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
IPv4: &v4, IPv4: &v4,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
sendUpdate, err := db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

View file

@ -197,9 +197,10 @@ func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
} }
nodes := types.Nodes{} nodes := types.Nodes{}
pakID := uint(pak.ID)
if err := tx. if err := tx.
Preload("AuthKey"). Preload("AuthKey").
Where(&types.Node{AuthKeyID: uint(pak.ID)}). Where(&types.Node{AuthKeyID: &pakID}).
Find(&nodes).Error; err != nil { Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }

View file

@ -76,14 +76,16 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
@ -97,14 +99,16 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 1, ID: 1,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -131,15 +135,17 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().Add(-time.Second * 30) now := time.Now().Add(-time.Second * 30)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.ValidatePreAuthKey(pak.Key) _, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -165,13 +171,14 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().Add(-time.Second * 30) now := time.Now().Add(-time.Second * 30)
pakId := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testest", Hostname: "testest",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
LastSeen: &now, LastSeen: &now,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakId,
} }
db.DB.Save(&node) db.DB.Save(&node)

View file

@ -43,15 +43,17 @@ func (s *Suite) TestGetRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route}, RoutableIPs: []netip.Prefix{route},
} }
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "test_get_route_node", Hostname: "test_get_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
su, err := db.SaveNodeRoutes(&node) su, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -93,15 +95,17 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
sendUpdate, err := db.SaveNodeRoutes(&node) sendUpdate, err := db.SaveNodeRoutes(&node)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -165,15 +169,17 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
hostInfo1 := tailcfg.Hostinfo{ hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2}, RoutableIPs: []netip.Prefix{route, route2},
} }
pakID := uint(pak.ID)
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
} }
db.DB.Save(&node1) trx := db.DB.Save(&node1)
c.Assert(trx.Error, check.IsNil)
sendUpdate, err := db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -193,7 +199,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
db.DB.Save(&node2) db.DB.Save(&node2)
@ -247,16 +253,18 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
} }
now := time.Now() now := time.Now()
pakID := uint(pak.ID)
node1 := types.Node{ node1 := types.Node{
ID: 1, ID: 1,
Hostname: "test_enable_route_node", Hostname: "test_enable_route_node",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
Hostinfo: &hostInfo1, Hostinfo: &hostInfo1,
LastSeen: &now, LastSeen: &now,
} }
db.DB.Save(&node1) trx := db.DB.Save(&node1)
c.Assert(trx.Error, check.IsNil)
sendUpdate, err := db.SaveNodeRoutes(&node1) sendUpdate, err := db.SaveNodeRoutes(&node1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -617,7 +625,16 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
db := dbForTest(t, tt.name) db := dbForTest(t, tt.name)
user := types.User{Name: tt.name}
if err := db.DB.Save(&user).Error; err != nil {
t.Fatalf("failed to create user: %s", err)
}
for _, route := range tt.routes { for _, route := range tt.routes {
route.Node.User = user
if err := db.DB.Save(&route.Node).Error; err != nil {
t.Fatalf("failed to create node: %s", err)
}
if err := db.DB.Save(&route).Error; err != nil { if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err) t.Fatalf("failed to create route: %s", err)
} }
@ -1013,8 +1030,16 @@ func TestFailoverRouteTx(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
db := dbForTest(t, tt.name) db := dbForTest(t, tt.name)
user := types.User{Name: "test"}
if err := db.DB.Save(&user).Error; err != nil {
t.Fatalf("failed to create user: %s", err)
}
for _, route := range tt.routes { for _, route := range tt.routes {
route.Node.User = user
if err := db.DB.Save(&route.Node).Error; err != nil {
t.Fatalf("failed to create node: %s", err)
}
if err := db.DB.Save(&route).Error; err != nil { if err := db.DB.Save(&route).Error; err != nil {
t.Fatalf("failed to create route: %s", err) t.Fatalf("failed to create route: %s", err)
} }

View file

@ -46,14 +46,16 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testnode", Hostname: "testnode",
UserID: user.ID, UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
err = db.DestroyUser("test") err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes) c.Assert(err, check.Equals, ErrUserStillHasNodes)
@ -98,14 +100,16 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pakID := uint(pak.ID)
node := types.Node{ node := types.Node{
ID: 0, ID: 0,
Hostname: "testnode", Hostname: "testnode",
UserID: oldUser.ID, UserID: oldUser.ID,
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID), AuthKeyID: &pakID,
} }
db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
c.Assert(node.UserID, check.Equals, oldUser.ID) c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, newUser.Name) err = db.AssignNodeToUser(&node, newUser.Name)

View file

@ -187,10 +187,9 @@ func Test_fullMapResponse(t *testing.T) {
UserID: 0, UserID: 0,
User: types.User{Name: "mini"}, User: types.User{Name: "mini"},
ForcedTags: []string{}, ForcedTags: []string{},
AuthKeyID: 0, AuthKey: &types.PreAuthKey{},
AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen,
LastSeen: &lastSeen, Expiry: &expire,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{ Routes: []types.Route{
{ {

View file

@ -97,7 +97,6 @@ func TestTailNode(t *testing.T) {
Name: "mini", Name: "mini",
}, },
ForcedTags: []string{}, ForcedTags: []string{},
AuthKeyID: 0,
AuthKey: &types.PreAuthKey{}, AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,

View file

@ -108,20 +108,20 @@ type Node struct {
// parts of headscale. // parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"` GivenName string `gorm:"type:varchar(63);unique_index"`
UserID uint UserID uint
User User `gorm:"foreignKey:UserID"` User User `gorm:"constraint:OnDelete:CASCADE;"`
RegisterMethod string RegisterMethod string
ForcedTags StringList ForcedTags StringList
// TODO(kradalby): This seems like irrelevant information? // TODO(kradalby): This seems like irrelevant information?
AuthKeyID uint AuthKeyID *uint `sql:"DEFAULT:NULL"`
AuthKey *PreAuthKey AuthKey *PreAuthKey `gorm:"constraint:OnDelete:SET NULL;"`
LastSeen *time.Time LastSeen *time.Time
Expiry *time.Time Expiry *time.Time
Routes []Route Routes []Route `gorm:"constraint:OnDelete:CASCADE;"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time

View file

@ -14,11 +14,11 @@ type PreAuthKey struct {
ID uint64 `gorm:"primary_key"` ID uint64 `gorm:"primary_key"`
Key string Key string
UserID uint UserID uint
User User User User `gorm:"constraint:OnDelete:CASCADE;"`
Reusable bool Reusable bool
Ephemeral bool `gorm:"default:false"` Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"` Used bool `gorm:"default:false"`
ACLTags []PreAuthKeyACLTag ACLTags []PreAuthKeyACLTag `gorm:"constraint:OnDelete:CASCADE;"`
CreatedAt *time.Time CreatedAt *time.Time
Expiration *time.Time Expiration *time.Time