Compare commits

..

No commits in common. "7bd21685c94a56ad3d0579757ddcae2d41872e6a" and "950d062ea36fa5ab598a9b285d960d47969b44b1" have entirely different histories.

12 changed files with 155 additions and 182 deletions

View file

@ -120,12 +120,12 @@ func TestMigrations(t *testing.T) {
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite", dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
wantFunc: func(t *testing.T, h *HSDatabase) { wantFunc: func(t *testing.T, h *HSDatabase) {
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest kratest, err := ListPreAuthKeys(rx, "kratest")
if err != nil { if err != nil {
return nil, err return nil, err
} }
testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra testkra, err := ListPreAuthKeys(rx, "testkra")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
}) })
} }
func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, uid, name) return getNode(rx, user, name)
}) })
} }
// getNode finds a Node by name and user and returns the Node struct. // getNode finds a Node by name and user and returns the Node struct.
func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
nodes, err := ListNodesByUser(tx, uid) nodes, err := ListNodesByUser(tx, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -29,10 +29,10 @@ func (s *Suite) TestGetNode(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -50,7 +50,7 @@ func (s *Suite) TestGetNode(c *check.C) {
trx := db.DB.Save(node) trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
@ -58,7 +58,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -87,7 +87,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -135,7 +135,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode3") _, err = db.getNode(user.Name, "testnode3")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
@ -143,7 +143,7 @@ func (s *Suite) TestListPeers(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0) _, err = db.GetNodeByID(0)
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
for _, name := range []string{"test", "admin"} { for _, name := range []string{"test", "admin"} {
user, err := db.CreateUser(name) user, err := db.CreateUser(name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
stor = append(stor, base{user, pak}) stor = append(stor, base{user, pak})
} }
@ -281,10 +281,10 @@ func (s *Suite) TestExpireNode(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -302,7 +302,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
} }
db.DB.Save(node) db.DB.Save(node)
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") nodeFromDB, err := db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB, check.NotNil) c.Assert(nodeFromDB, check.NotNil)
@ -312,7 +312,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
err = db.NodeSetExpiry(nodeFromDB.ID, now) err = db.NodeSetExpiry(nodeFromDB.ID, now)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") nodeFromDB, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(nodeFromDB.IsExpired(), check.Equals, true) c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
@ -322,10 +322,10 @@ func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode") _, err = db.getNode("test", "testnode")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -348,7 +348,7 @@ func (s *Suite) TestSetTags(c *check.C) {
sTags := []string{"tag:test", "tag:foo"} sTags := []string{"tag:test", "tag:foo"}
err = db.SetTags(node.ID, sTags) err = db.SetTags(node.ID, sTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode") node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, sTags) c.Assert(node.ForcedTags, check.DeepEquals, sTags)
@ -356,7 +356,7 @@ func (s *Suite) TestSetTags(c *check.C) {
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
err = db.SetTags(node.ID, eTags) err = db.SetTags(node.ID, eTags)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode") node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert( c.Assert(
node.ForcedTags, node.ForcedTags,
@ -367,7 +367,7 @@ func (s *Suite) TestSetTags(c *check.C) {
// test removing tags // test removing tags
err = db.SetTags(node.ID, []string{}) err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode") node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, []string{}) c.Assert(node.ForcedTags, check.DeepEquals, []string{})
} }
@ -567,7 +567,7 @@ func TestAutoApproveRoutes(t *testing.T) {
user, err := adb.CreateUser("test") user, err := adb.CreateUser("test")
assert.NoError(t, err) assert.NoError(t, err)
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
nodeKey := key.NewNode() nodeKey := key.NewNode()
@ -699,10 +699,10 @@ func TestListEphemeralNodes(t *testing.T) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
assert.NoError(t, err) assert.NoError(t, err)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
node := types.Node{ node := types.Node{

View file

@ -23,27 +23,29 @@ var (
) )
func (hsdb *HSDatabase) CreatePreAuthKey( func (hsdb *HSDatabase) CreatePreAuthKey(
uid types.UserID, // TODO(kradalby): Should be ID, not name
userName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKey, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
}) })
} }
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func CreatePreAuthKey( func CreatePreAuthKey(
tx *gorm.DB, tx *gorm.DB,
uid types.UserID, // TODO(kradalby): Should be ID, not name
userName string,
reusable bool, reusable bool,
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
aclTags []string, aclTags []string,
) (*types.PreAuthKey, error) { ) (*types.PreAuthKey, error) {
user, err := GetUserByID(tx, uid) user, err := GetUserByUsername(tx, userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,15 +89,15 @@ func CreatePreAuthKey(
return &key, nil return &key, nil
} }
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
return ListPreAuthKeysByUser(rx, uid) return ListPreAuthKeys(rx, userName)
}) })
} }
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. // ListPreAuthKeys returns the list of PreAuthKeys for a user.
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
user, err := GetUserByID(tx, uid) user, err := GetUserByUsername(tx, userName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,14 +11,14 @@ import (
) )
func (*Suite) TestCreatePreAuthKey(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) {
// ID does not exist _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
// Did we get a valid key? // Did we get a valid key?
@ -26,18 +26,17 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
c.Assert(len(key.Key), check.Equals, 48) c.Assert(len(key.Key), check.Equals, 48)
// Make sure the User association is populated // Make sure the User association is populated
c.Assert(key.User.ID, check.Equals, user.ID) c.Assert(key.User.Name, check.Equals, user.Name)
// ID does not exist _, err = db.ListPreAuthKeys("bogus")
_, err = db.ListPreAuthKeys(1000000)
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) keys, err := db.ListPreAuthKeys(user.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
// Make sure the User association is populated // Make sure the User association is populated
c.Assert((keys)[0].User.ID, check.Equals, user.ID) c.Assert((keys)[0].User.Name, check.Equals, user.Name)
} }
func (*Suite) TestExpiredPreAuthKey(c *check.C) { func (*Suite) TestExpiredPreAuthKey(c *check.C) {
@ -45,7 +44,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().Add(-5 * time.Second) now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -63,7 +62,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
user, err := db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -75,7 +74,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
user, err := db.CreateUser("test4") user, err := db.CreateUser("test4")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -97,7 +96,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
user, err := db.CreateUser("test5") user, err := db.CreateUser("test5")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -119,7 +118,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
user, err := db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key) key, err := db.ValidatePreAuthKey(pak.Key)
@ -131,7 +130,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := db.CreateUser("test3") user, err := db.CreateUser("test3")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil) c.Assert(pak.Expiration, check.IsNil)
@ -148,7 +147,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
user, err := db.CreateUser("test6") user, err := db.CreateUser("test6")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak.Used = true pak.Used = true
db.DB.Save(&pak) db.DB.Save(&pak)
@ -161,15 +160,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
user, err := db.CreateUser("test8") user, err := db.CreateUser("test8")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"} tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) listedPaks, err := db.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
gotTags := listedPaks[0].Proto().GetAclTags() gotTags := listedPaks[0].Proto().GetAclTags()
sort.Sort(sort.StringSlice(gotTags)) sort.Sort(sort.StringSlice(gotTags))

View file

@ -639,7 +639,7 @@ func EnableAutoApprovedRoutes(
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
Uint("user.id", node.User.ID). Str("user", node.User.Name).
Strs("routeApprovers", routeApprovers). Strs("routeApprovers", routeApprovers).
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()). Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
Msg("looking up route for autoapproving") Msg("looking up route for autoapproving")

View file

@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "test_get_route_node") _, err = db.getNode("test", "test_get_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24") route, err := netip.ParsePrefix("10.0.0.0/24")
@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix( route, err := netip.ParsePrefix(
@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") _, err = db.getNode("test", "test_enable_route_node")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix( prefix, err := netip.ParsePrefix(

View file

@ -40,21 +40,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
return &user, nil return &user, nil
} }
func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error { func (hsdb *HSDatabase) DestroyUser(name string) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return DestroyUser(tx, uid) return DestroyUser(tx, name)
}) })
} }
// DestroyUser destroys a User. Returns error if the User does // DestroyUser destroys a User. Returns error if the User does
// not exist or if there are nodes associated with it. // not exist or if there are nodes associated with it.
func DestroyUser(tx *gorm.DB, uid types.UserID) error { func DestroyUser(tx *gorm.DB, name string) error {
user, err := GetUserByID(tx, uid) user, err := GetUserByUsername(tx, name)
if err != nil { if err != nil {
return err return ErrUserNotFound
} }
nodes, err := ListNodesByUser(tx, uid) nodes, err := ListNodesByUser(tx, name)
if err != nil { if err != nil {
return err return err
} }
@ -62,7 +62,7 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
return ErrUserStillHasNodes return ErrUserStillHasNodes
} }
keys, err := ListPreAuthKeysByUser(tx, uid) keys, err := ListPreAuthKeys(tx, name)
if err != nil { if err != nil {
return err return err
} }
@ -80,17 +80,17 @@ func DestroyUser(tx *gorm.DB, uid types.UserID) error {
return nil return nil
} }
func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error { func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return RenameUser(tx, uid, newName) return RenameUser(tx, oldName, newName)
}) })
} }
// RenameUser renames a User. Returns error if the User does // RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name. // not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { func RenameUser(tx *gorm.DB, oldName, newName string) error {
var err error var err error
oldUser, err := GetUserByID(tx, uid) oldUser, err := GetUserByUsername(tx, oldName)
if err != nil { if err != nil {
return err return err
} }
@ -98,25 +98,50 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
if err != nil { if err != nil {
return err return err
} }
_, err = GetUserByUsername(tx, newName)
if err == nil {
return ErrUserExists
}
if !errors.Is(err, ErrUserNotFound) {
return err
}
oldUser.Name = newName oldUser.Name = newName
if err := tx.Save(&oldUser).Error; err != nil { if result := tx.Save(&oldUser); result.Error != nil {
return err return result.Error
} }
return nil return nil
} }
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByID(rx, uid) return GetUserByUsername(rx, name)
}) })
} }
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) {
user := types.User{} user := types.User{}
if result := tx.First(&user, "id = ?", uid); errors.Is( if result := tx.First(&user, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrUserNotFound
}
return &user, nil
}
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByID(rx, id)
})
}
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
user := types.User{}
if result := tx.First(&user, "id = ?", id); errors.Is(
result.Error, result.Error,
gorm.ErrRecordNotFound, gorm.ErrRecordNotFound,
) { ) {
@ -144,65 +169,54 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
return &user, nil return &user, nil
} }
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx, where...) return ListUsers(rx)
}) })
} }
// ListUsers gets all the existing users. // ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { func ListUsers(tx *gorm.DB) ([]types.User, error) {
if len(where) > 1 {
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
}
var user *types.User
if len(where) == 1 {
user = where[0]
}
users := []types.User{} users := []types.User{}
if err := tx.Where(user).Find(&users).Error; err != nil { if err := tx.Find(&users).Error; err != nil {
return nil, err return nil, err
} }
return users, nil return users, nil
} }
// GetUserByName returns a user if the provided username is // ListNodesByUser gets all the nodes in a given user.
// unique, and otherwise an error. func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { err := util.CheckForFQDNRules(name)
users, err := hsdb.ListUsers(&types.User{Name: name}) if err != nil {
return nil, err
}
user, err := GetUserByUsername(tx, name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(users) != 1 {
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
}
return &users[0], nil
}
// ListNodesByUser gets all the nodes in a given user.
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
nodes := types.Nodes{} nodes := types.Nodes{}
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }
return nodes, nil return nodes, nil
} }
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error { func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return AssignNodeToUser(tx, node, uid) return AssignNodeToUser(tx, node, username)
}) })
} }
// AssignNodeToUser assigns a Node to a user. // AssignNodeToUser assigns a Node to a user.
func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error { func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
user, err := GetUserByID(tx, uid) err := util.CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := GetUserByUsername(tx, username)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,8 +1,6 @@
package db package db
import ( import (
"strings"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1" "gopkg.in/check.v1"
@ -19,24 +17,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1) c.Assert(len(users), check.Equals, 1)
err = db.DestroyUser(types.UserID(user.ID)) err = db.DestroyUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.GetUserByID(types.UserID(user.ID)) _, err = db.GetUserByName("test")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }
func (s *Suite) TestDestroyUserErrors(c *check.C) { func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := db.DestroyUser(9998) err := db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = db.DestroyUser(types.UserID(user.ID)) err = db.DestroyUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
@ -46,7 +44,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
user, err = db.CreateUser("test") user, err = db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -59,7 +57,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
trx := db.DB.Save(&node) trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
err = db.DestroyUser(types.UserID(user.ID)) err = db.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes) c.Assert(err, check.Equals, ErrUserStillHasNodes)
} }
@ -72,28 +70,24 @@ func (s *Suite) TestRenameUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1) c.Assert(len(users), check.Equals, 1)
err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") err = db.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
users, err = db.ListUsers(&types.User{Name: "test"}) _, err = db.GetUserByName("test")
c.Assert(err, check.Equals, nil) c.Assert(err, check.Equals, ErrUserNotFound)
c.Assert(len(users), check.Equals, 0)
users, err = db.ListUsers(&types.User{Name: "test-renamed"}) _, err = db.GetUserByName("test-renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = db.RenameUser(99988, "test") err = db.RenameUser("test-does-not-exit", "test")
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := db.CreateUser("test2") userTest2, err := db.CreateUser("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(userTest2.Name, check.Equals, "test2") c.Assert(userTest2.Name, check.Equals, "test2")
err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") err = db.RenameUser("test2", "test-renamed")
if !strings.Contains(err.Error(), "UNIQUE constraint failed") { c.Assert(err, check.Equals, ErrUserExists)
c.Fatalf("expected failure with unique constraint, got: %s", err.Error())
}
} }
func (s *Suite) TestSetMachineUser(c *check.C) { func (s *Suite) TestSetMachineUser(c *check.C) {
@ -103,7 +97,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
newUser, err := db.CreateUser("new") newUser, err := db.CreateUser("new")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
node := types.Node{ node := types.Node{
@ -117,15 +111,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
c.Assert(trx.Error, check.IsNil) c.Assert(trx.Error, check.IsNil)
c.Assert(node.UserID, check.Equals, oldUser.ID) c.Assert(node.UserID, check.Equals, oldUser.ID)
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) err = db.AssignNodeToUser(&node, newUser.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name) c.Assert(node.User.Name, check.Equals, newUser.Name)
err = db.AssignNodeToUser(&node, 9584849) err = db.AssignNodeToUser(&node, "non-existing-user")
c.Assert(err, check.Equals, ErrUserNotFound) c.Assert(err, check.Equals, ErrUserNotFound)
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) err = db.AssignNodeToUser(&node, newUser.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.UserID, check.Equals, newUser.ID)
c.Assert(node.User.Name, check.Equals, newUser.Name) c.Assert(node.User.Name, check.Equals, newUser.Name)

View file

@ -65,34 +65,24 @@ func (api headscaleV1APIServer) RenameUser(
ctx context.Context, ctx context.Context,
request *v1.RenameUserRequest, request *v1.RenameUserRequest,
) (*v1.RenameUserResponse, error) { ) (*v1.RenameUserResponse, error) {
oldUser, err := api.h.db.GetUserByName(request.GetOldName()) err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) user, err := api.h.db.GetUserByName(request.GetNewName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
newUser, err := api.h.db.GetUserByName(request.GetNewName()) return &v1.RenameUserResponse{User: user.Proto()}, nil
if err != nil {
return nil, err
}
return &v1.RenameUserResponse{User: newUser.Proto()}, nil
} }
func (api headscaleV1APIServer) DeleteUser( func (api headscaleV1APIServer) DeleteUser(
ctx context.Context, ctx context.Context,
request *v1.DeleteUserRequest, request *v1.DeleteUserRequest,
) (*v1.DeleteUserResponse, error) { ) (*v1.DeleteUserResponse, error) {
user, err := api.h.db.GetUserByName(request.GetName()) err := api.h.db.DestroyUser(request.GetName())
if err != nil {
return nil, err
}
err = api.h.db.DestroyUser(types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,13 +131,8 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
} }
} }
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
preAuthKey, err := api.h.db.CreatePreAuthKey( preAuthKey, err := api.h.db.CreatePreAuthKey(
types.UserID(user.ID), request.GetUser(),
request.GetReusable(), request.GetReusable(),
request.GetEphemeral(), request.GetEphemeral(),
&expiration, &expiration,
@ -183,12 +168,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context, ctx context.Context,
request *v1.ListPreAuthKeysRequest, request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) { ) (*v1.ListPreAuthKeysResponse, error) {
user, err := api.h.db.GetUserByName(request.GetUser()) preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser())
if err != nil {
return nil, err
}
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -426,20 +406,10 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context, ctx context.Context,
request *v1.ListNodesRequest, request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) { ) (*v1.ListNodesResponse, error) {
// TODO(kradalby): it looks like this can be simplified a lot,
// the filtering of nodes by user, vs nodes as a whole can
// probably be done once.
// TODO(kradalby): This should be done in one tx.
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
if request.GetUser() != "" { if request.GetUser() != "" {
user, err := api.h.db.GetUserByName(request.GetUser())
if err != nil {
return nil, err
}
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, types.UserID(user.ID)) return db.ListNodesByUser(rx, request.GetUser())
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -495,18 +465,12 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context, ctx context.Context,
request *v1.MoveNodeRequest, request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) { ) (*v1.MoveNodeResponse, error) {
// TODO(kradalby): This should be done in one tx.
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := api.h.db.GetUserByName(request.GetUser()) err = api.h.db.AssignNodeToUser(node, request.GetUser())
if err != nil {
return nil, err
}
err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -26,7 +26,7 @@ type PreAuthKey struct {
func (key *PreAuthKey) Proto() *v1.PreAuthKey { func (key *PreAuthKey) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{ protoKey := v1.PreAuthKey{
User: key.User.Username(), User: key.User.Name,
Id: strconv.FormatUint(key.ID, util.Base10), Id: strconv.FormatUint(key.ID, util.Base10),
Key: key.Key, Key: key.Key,
Ephemeral: key.Ephemeral, Ephemeral: key.Ephemeral,

View file

@ -21,12 +21,12 @@ type User struct {
gorm.Model gorm.Model
// The index `idx_name_provider_identifier` is to enforce uniqueness // The index `idx_name_provider_identifier` is to enforce uniqueness
// between Name and ProviderIdentifier. This ensures that // between Name and ProviderIdentifier. This ensures that
// you can have multiple users with the same name in OIDC, // you can have multiple usersnames of the same name in OIDC,
// but not if you only run with CLI users. // but not if you only run with CLI users.
// Username for the user, is used if email is empty // Username for the user, is used if email is empty
// Should not be used, please use Username(). // Should not be used, please use Username().
Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"`
// Typically the full name of the user // Typically the full name of the user
DisplayName string DisplayName string
@ -54,9 +54,9 @@ type User struct {
// enabled with OIDC, which means that there is a domain involved which // enabled with OIDC, which means that there is a domain involved which
// should be used throughout headscale, in information returned to the // should be used throughout headscale, in information returned to the
// user and the Policy engine. // user and the Policy engine.
// If the username does not contain an '@' it will be added to the end.
func (u *User) Username() string { func (u *User) Username() string {
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
// TODO(kradalby): Wire up all of this for the future // TODO(kradalby): Wire up all of this for the future
// if !strings.Contains(username, "@") { // if !strings.Contains(username, "@") {
// username = username + "@" // username = username + "@"