mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-30 02:43:05 +00:00
Compare commits
No commits in common. "7bd21685c94a56ad3d0579757ddcae2d41872e6a" and "950d062ea36fa5ab598a9b285d960d47969b44b1" have entirely different histories.
7bd21685c9
...
950d062ea3
12 changed files with 155 additions and 182 deletions
|
@ -120,12 +120,12 @@ func TestMigrations(t *testing.T) {
|
||||||
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
||||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest
|
kratest, err := ListPreAuthKeys(rx, "kratest")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra
|
testkra, err := ListPreAuthKeys(rx, "testkra")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) {
|
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||||
return getNode(rx, uid, name)
|
return getNode(rx, user, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNode finds a Node by name and user and returns the Node struct.
|
// getNode finds a Node by name and user and returns the Node struct.
|
||||||
func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
|
||||||
nodes, err := ListNodesByUser(tx, uid)
|
nodes, err := ListNodesByUser(tx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,10 +29,10 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -50,7 +50,7 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
trx := db.DB.Save(node)
|
trx := db.DB.Save(node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -87,7 +87,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -135,7 +135,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
_, err = db.getNode(user.Name, "testnode3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
for _, name := range []string{"test", "admin"} {
|
for _, name := range []string{"test", "admin"} {
|
||||||
user, err := db.CreateUser(name)
|
user, err := db.CreateUser(name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
stor = append(stor, base{user, pak})
|
stor = append(stor, base{user, pak})
|
||||||
}
|
}
|
||||||
|
@ -281,10 +281,10 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -302,7 +302,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
}
|
}
|
||||||
db.DB.Save(node)
|
db.DB.Save(node)
|
||||||
|
|
||||||
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
|
nodeFromDB, err := db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(nodeFromDB, check.NotNil)
|
c.Assert(nodeFromDB, check.NotNil)
|
||||||
|
|
||||||
|
@ -312,7 +312,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
|
nodeFromDB, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||||
|
@ -322,10 +322,10 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -348,7 +348,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
sTags := []string{"tag:test", "tag:foo"}
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
err = db.SetTags(node.ID, sTags)
|
err = db.SetTags(node.ID, sTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
node, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
||||||
|
|
||||||
|
@ -356,7 +356,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
err = db.SetTags(node.ID, eTags)
|
err = db.SetTags(node.ID, eTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
node, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
node.ForcedTags,
|
node.ForcedTags,
|
||||||
|
@ -367,7 +367,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
// test removing tags
|
// test removing tags
|
||||||
err = db.SetTags(node.ID, []string{})
|
err = db.SetTags(node.ID, []string{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
node, err = db.getNode("test", "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
||||||
}
|
}
|
||||||
|
@ -567,7 +567,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||||
user, err := adb.CreateUser("test")
|
user, err := adb.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -699,10 +699,10 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
|
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
|
|
@ -23,27 +23,29 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
uid types.UserID,
|
// TODO(kradalby): Should be ID, not name
|
||||||
|
userName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||||
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
|
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func CreatePreAuthKey(
|
func CreatePreAuthKey(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
uid types.UserID,
|
// TODO(kradalby): Should be ID, not name
|
||||||
|
userName string,
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
user, err := GetUserByID(tx, uid)
|
user, err := GetUserByUsername(tx, userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -87,15 +89,15 @@ func CreatePreAuthKey(
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
return ListPreAuthKeysByUser(rx, uid)
|
return ListPreAuthKeys(rx, userName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||||
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
||||||
user, err := GetUserByID(tx, uid)
|
user, err := GetUserByUsername(tx, userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,14 +11,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
// ID does not exist
|
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||||
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -26,18 +26,17 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
c.Assert(len(key.Key), check.Equals, 48)
|
c.Assert(len(key.Key), check.Equals, 48)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert(key.User.ID, check.Equals, user.ID)
|
c.Assert(key.User.Name, check.Equals, user.Name)
|
||||||
|
|
||||||
// ID does not exist
|
_, err = db.ListPreAuthKeys("bogus")
|
||||||
_, err = db.ListPreAuthKeys(1000000)
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
keys, err := db.ListPreAuthKeys(user.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert((keys)[0].User.ID, check.Equals, user.ID)
|
c.Assert((keys)[0].User.Name, check.Equals, user.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
|
@ -45,7 +44,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now().Add(-5 * time.Second)
|
now := time.Now().Add(-5 * time.Second)
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -63,7 +62,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -75,7 +74,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test4")
|
user, err := db.CreateUser("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -97,7 +96,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test5")
|
user, err := db.CreateUser("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -119,7 +118,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -131,7 +130,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
|
@ -148,7 +147,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
db.DB.Save(&pak)
|
db.DB.Save(&pak)
|
||||||
|
@ -161,15 +160,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test8")
|
user, err := db.CreateUser("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
|
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
tags := []string{"tag:test1", "tag:test2"}
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
|
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
listedPaks, err := db.ListPreAuthKeys("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||||
sort.Sort(sort.StringSlice(gotTags))
|
sort.Sort(sort.StringSlice(gotTags))
|
||||||
|
|
|
@ -639,7 +639,7 @@ func EnableAutoApprovedRoutes(
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Uint("user.id", node.User.ID).
|
Str("user", node.User.Name).
|
||||||
Strs("routeApprovers", routeApprovers).
|
Strs("routeApprovers", routeApprovers).
|
||||||
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
||||||
Msg("looking up route for autoapproving")
|
Msg("looking up route for autoapproving")
|
||||||
|
|
|
@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "test_get_route_node")
|
_, err = db.getNode("test", "test_get_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
_, err = db.getNode("test", "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
|
|
@ -40,21 +40,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error {
|
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return DestroyUser(tx, uid)
|
return DestroyUser(tx, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are nodes associated with it.
|
// not exist or if there are nodes associated with it.
|
||||||
func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
user, err := GetUserByID(tx, uid)
|
user, err := GetUserByUsername(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return ErrUserNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := ListNodesByUser(tx, uid)
|
nodes, err := ListNodesByUser(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := ListPreAuthKeysByUser(tx, uid)
|
keys, err := ListPreAuthKeys(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -80,17 +80,17 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error {
|
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return RenameUser(tx, uid, newName)
|
return RenameUser(tx, oldName, newName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenameUser renames a User. Returns error if the User does
|
// RenameUser renames a User. Returns error if the User does
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := GetUserByID(tx, uid)
|
oldUser, err := GetUserByUsername(tx, oldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,25 +98,50 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
_, err = GetUserByUsername(tx, newName)
|
||||||
|
if err == nil {
|
||||||
|
return ErrUserExists
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrUserNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
oldUser.Name = newName
|
oldUser.Name = newName
|
||||||
|
|
||||||
if err := tx.Save(&oldUser).Error; err != nil {
|
if result := tx.Save(&oldUser); result.Error != nil {
|
||||||
return err
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
return GetUserByID(rx, uid)
|
return GetUserByUsername(rx, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if result := tx.First(&user, "id = ?", uid); errors.Is(
|
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||||
|
result.Error,
|
||||||
|
gorm.ErrRecordNotFound,
|
||||||
|
) {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
|
||||||
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
|
return GetUserByID(rx, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
|
||||||
|
user := types.User{}
|
||||||
|
if result := tx.First(&user, "id = ?", id); errors.Is(
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -144,65 +169,54 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||||
return ListUsers(rx, where...)
|
return ListUsers(rx)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
||||||
if len(where) > 1 {
|
|
||||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
|
||||||
}
|
|
||||||
|
|
||||||
var user *types.User
|
|
||||||
if len(where) == 1 {
|
|
||||||
user = where[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
users := []types.User{}
|
users := []types.User{}
|
||||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
if err := tx.Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByName returns a user if the provided username is
|
// ListNodesByUser gets all the nodes in a given user.
|
||||||
// unique, and otherwise an error.
|
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
||||||
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
err := util.CheckForFQDNRules(name)
|
||||||
users, err := hsdb.ListUsers(&types.User{Name: name})
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
user, err := GetUserByUsername(tx, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(users) != 1 {
|
|
||||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &users[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListNodesByUser gets all the nodes in a given user.
|
|
||||||
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
|
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
|
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error {
|
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return AssignNodeToUser(tx, node, uid)
|
return AssignNodeToUser(tx, node, username)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignNodeToUser assigns a Node to a user.
|
// AssignNodeToUser assigns a Node to a user.
|
||||||
func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error {
|
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
||||||
user, err := GetUserByID(tx, uid)
|
err := util.CheckForFQDNRules(username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
user, err := GetUserByUsername(tx, username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
@ -19,24 +17,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.DestroyUser(types.UserID(user.ID))
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUserByID(types.UserID(user.ID))
|
_, err = db.GetUserByName("test")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
err := db.DestroyUser(9998)
|
err := db.DestroyUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser(types.UserID(user.ID))
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
|
@ -46,7 +44,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
user, err = db.CreateUser("test")
|
user, err = db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -59,7 +57,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
trx := db.DB.Save(&node)
|
trx := db.DB.Save(&node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser(types.UserID(user.ID))
|
err = db.DestroyUser("test")
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,28 +70,24 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
|
err = db.RenameUser("test", "test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
users, err = db.ListUsers(&types.User{Name: "test"})
|
_, err = db.GetUserByName("test")
|
||||||
c.Assert(err, check.Equals, nil)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
c.Assert(len(users), check.Equals, 0)
|
|
||||||
|
|
||||||
users, err = db.ListUsers(&types.User{Name: "test-renamed"})
|
_, err = db.GetUserByName("test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
|
||||||
|
|
||||||
err = db.RenameUser(99988, "test")
|
err = db.RenameUser("test-does-not-exit", "test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
userTest2, err := db.CreateUser("test2")
|
userTest2, err := db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed")
|
err = db.RenameUser("test2", "test-renamed")
|
||||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
c.Assert(err, check.Equals, ErrUserExists)
|
||||||
c.Fatalf("expected failure with unique constraint, got: %s", err.Error())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
|
@ -103,7 +97,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
newUser, err := db.CreateUser("new")
|
newUser, err := db.CreateUser("new")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -117,15 +111,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
err = db.AssignNodeToUser(&node, newUser.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, 9584849)
|
err = db.AssignNodeToUser(&node, "non-existing-user")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
err = db.AssignNodeToUser(&node, newUser.Name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
|
@ -65,34 +65,24 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameUserRequest,
|
request *v1.RenameUserRequest,
|
||||||
) (*v1.RenameUserResponse, error) {
|
) (*v1.RenameUserResponse, error) {
|
||||||
oldUser, err := api.h.db.GetUserByName(request.GetOldName())
|
err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
user, err := api.h.db.GetUserByName(request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := api.h.db.GetUserByName(request.GetNewName())
|
return &v1.RenameUserResponse{User: user.Proto()}, nil
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &v1.RenameUserResponse{User: newUser.Proto()}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) DeleteUser(
|
func (api headscaleV1APIServer) DeleteUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteUserRequest,
|
request *v1.DeleteUserRequest,
|
||||||
) (*v1.DeleteUserResponse, error) {
|
) (*v1.DeleteUserResponse, error) {
|
||||||
user, err := api.h.db.GetUserByName(request.GetName())
|
err := api.h.db.DestroyUser(request.GetName())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = api.h.db.DestroyUser(types.UserID(user.ID))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -141,13 +131,8 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.db.GetUserByName(request.GetUser())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
||||||
types.UserID(user.ID),
|
request.GetUser(),
|
||||||
request.GetReusable(),
|
request.GetReusable(),
|
||||||
request.GetEphemeral(),
|
request.GetEphemeral(),
|
||||||
&expiration,
|
&expiration,
|
||||||
|
@ -183,12 +168,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListPreAuthKeysRequest,
|
request *v1.ListPreAuthKeysRequest,
|
||||||
) (*v1.ListPreAuthKeysResponse, error) {
|
) (*v1.ListPreAuthKeysResponse, error) {
|
||||||
user, err := api.h.db.GetUserByName(request.GetUser())
|
preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -426,20 +406,10 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListNodesRequest,
|
request *v1.ListNodesRequest,
|
||||||
) (*v1.ListNodesResponse, error) {
|
) (*v1.ListNodesResponse, error) {
|
||||||
// TODO(kradalby): it looks like this can be simplified a lot,
|
|
||||||
// the filtering of nodes by user, vs nodes as a whole can
|
|
||||||
// probably be done once.
|
|
||||||
// TODO(kradalby): This should be done in one tx.
|
|
||||||
|
|
||||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
user, err := api.h.db.GetUserByName(request.GetUser())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return db.ListNodesByUser(rx, types.UserID(user.ID))
|
return db.ListNodesByUser(rx, request.GetUser())
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -495,18 +465,12 @@ func (api headscaleV1APIServer) MoveNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveNodeRequest,
|
request *v1.MoveNodeRequest,
|
||||||
) (*v1.MoveNodeResponse, error) {
|
) (*v1.MoveNodeResponse, error) {
|
||||||
// TODO(kradalby): This should be done in one tx.
|
|
||||||
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.db.GetUserByName(request.GetUser())
|
err = api.h.db.AssignNodeToUser(node, request.GetUser())
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ type PreAuthKey struct {
|
||||||
|
|
||||||
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||||
protoKey := v1.PreAuthKey{
|
protoKey := v1.PreAuthKey{
|
||||||
User: key.User.Username(),
|
User: key.User.Name,
|
||||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Ephemeral: key.Ephemeral,
|
Ephemeral: key.Ephemeral,
|
||||||
|
|
|
@ -21,12 +21,12 @@ type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
// The index `idx_name_provider_identifier` is to enforce uniqueness
|
// The index `idx_name_provider_identifier` is to enforce uniqueness
|
||||||
// between Name and ProviderIdentifier. This ensures that
|
// between Name and ProviderIdentifier. This ensures that
|
||||||
// you can have multiple users with the same name in OIDC,
|
// you can have multiple usersnames of the same name in OIDC,
|
||||||
// but not if you only run with CLI users.
|
// but not if you only run with CLI users.
|
||||||
|
|
||||||
// Username for the user, is used if email is empty
|
// Username for the user, is used if email is empty
|
||||||
// Should not be used, please use Username().
|
// Should not be used, please use Username().
|
||||||
Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"`
|
Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"`
|
||||||
|
|
||||||
// Typically the full name of the user
|
// Typically the full name of the user
|
||||||
DisplayName string
|
DisplayName string
|
||||||
|
@ -54,9 +54,9 @@ type User struct {
|
||||||
// enabled with OIDC, which means that there is a domain involved which
|
// enabled with OIDC, which means that there is a domain involved which
|
||||||
// should be used throughout headscale, in information returned to the
|
// should be used throughout headscale, in information returned to the
|
||||||
// user and the Policy engine.
|
// user and the Policy engine.
|
||||||
|
// If the username does not contain an '@' it will be added to the end.
|
||||||
func (u *User) Username() string {
|
func (u *User) Username() string {
|
||||||
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
||||||
|
|
||||||
// TODO(kradalby): Wire up all of this for the future
|
// TODO(kradalby): Wire up all of this for the future
|
||||||
// if !strings.Contains(username, "@") {
|
// if !strings.Contains(username, "@") {
|
||||||
// username = username + "@"
|
// username = username + "@"
|
||||||
|
|
Loading…
Reference in a new issue