move logic for validating node names (#2127)

* move logic for validating node names

this commits moves the generation of "given names" of nodes
into the registration function, and adds validation of renames
to RenameNode using the same logic.

Fixes #2121

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix double arg

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-09-11 18:27:49 +02:00 committed by GitHub
parent 64319f79ff
commit 064c46f2a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 134 additions and 115 deletions

View file

@ -66,7 +66,7 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest, regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
) { ) {
logInfo, logTrace, logErr := logAuthFunc(regReq, machineKey) logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
now := time.Now().UTC() now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB") logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
@ -105,16 +105,6 @@ func (h *Headscale) handleRegister(
logInfo("Node not found in database, creating new") logInfo("Node not found in database, creating new")
givenName, err := h.db.GenerateGivenName(
machineKey,
regReq.Hostinfo.Hostname,
)
if err != nil {
logErr(err, "Failed to generate given name for node")
return
}
// The node did not have a key to authenticate, which means // The node did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI) // that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback // We create the node and then keep it around until a callback
@ -122,7 +112,6 @@ func (h *Headscale) handleRegister(
newNode := types.Node{ newNode := types.Node{
MachineKey: machineKey, MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname, Hostname: regReq.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: regReq.NodeKey, NodeKey: regReq.NodeKey,
LastSeen: &now, LastSeen: &now,
Expiry: &time.Time{}, Expiry: &time.Time{},
@ -354,21 +343,8 @@ func (h *Headscale) handleAuthKey(
} else { } else {
now := time.Now().UTC() now := time.Now().UTC()
givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed to generate given name for node")
return
}
nodeToRegister := types.Node{ nodeToRegister := types.Node{
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID, UserID: pak.User.ID,
User: pak.User, User: pak.User,
MachineKey: machineKey, MachineKey: machineKey,

View file

@ -90,20 +90,6 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
}) })
} }
func listNodesByGivenName(tx *gorm.DB, givenName string) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Preload("AuthKey.User").
Preload("User").
Preload("Routes").
Where("given_name = ?", givenName).Find(&nodes).Error; err != nil {
return nil, err
}
return nodes, nil
}
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return getNode(rx, user, name) return getNode(rx, user, name)
@ -242,9 +228,9 @@ func SetTags(
} }
// RenameNode takes a Node struct and a new GivenName for the nodes // RenameNode takes a Node struct and a new GivenName for the nodes
// and renames it. // and renames it. If the name is not unique, it will return an error.
func RenameNode(tx *gorm.DB, func RenameNode(tx *gorm.DB,
nodeID uint64, newName string, nodeID types.NodeID, newName string,
) error { ) error {
err := util.CheckForFQDNRules( err := util.CheckForFQDNRules(
newName, newName,
@ -253,6 +239,15 @@ func RenameNode(tx *gorm.DB,
return fmt.Errorf("renaming node: %w", err) return fmt.Errorf("renaming node: %w", err)
} }
uniq, err := isUnqiueName(tx, newName)
if err != nil {
return fmt.Errorf("checking if name is unique: %w", err)
}
if !uniq {
return fmt.Errorf("name is not unique: %s", newName)
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
return fmt.Errorf("failed to rename node in the database: %w", err) return fmt.Errorf("failed to rename node in the database: %w", err)
} }
@ -415,6 +410,15 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
node.IPv4 = ipv4 node.IPv4 = ipv4
node.IPv6 = ipv6 node.IPv6 = ipv6
if node.GivenName == "" {
givenName, err := ensureUniqueGivenName(tx, node.Hostname)
if err != nil {
return nil, fmt.Errorf("failed to ensure unique given name: %w", err)
}
node.GivenName = givenName
}
if err := tx.Save(&node).Error; err != nil { if err := tx.Save(&node).Error; err != nil {
return nil, fmt.Errorf("failed register(save) node in the database: %w", err) return nil, fmt.Errorf("failed register(save) node in the database: %w", err)
} }
@ -642,40 +646,32 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
return normalizedHostname, nil return normalizedHostname, nil
} }
func (hsdb *HSDatabase) GenerateGivenName( func isUnqiueName(tx *gorm.DB, name string) (bool, error) {
mkey key.MachinePublic, nodes := types.Nodes{}
suppliedName string, if err := tx.
) (string, error) { Where("given_name = ?", name).Find(&nodes).Error; err != nil {
return Read(hsdb.DB, func(rx *gorm.DB) (string, error) { return false, err
return GenerateGivenName(rx, mkey, suppliedName) }
})
return len(nodes) == 0, nil
} }
func GenerateGivenName( func ensureUniqueGivenName(
tx *gorm.DB, tx *gorm.DB,
mkey key.MachinePublic, name string,
suppliedName string,
) (string, error) { ) (string, error) {
givenName, err := generateGivenName(suppliedName, false) givenName, err := generateGivenName(name, false)
if err != nil { if err != nil {
return "", err return "", err
} }
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ unique, err := isUnqiueName(tx, givenName)
nodes, err := listNodesByGivenName(tx, givenName)
if err != nil { if err != nil {
return "", err return "", err
} }
var nodeFound *types.Node if !unique {
for idx, node := range nodes { postfixedName, err := generateGivenName(name, true)
if node.GivenName == givenName {
nodeFound = nodes[idx]
}
}
if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() {
postfixedName, err := generateGivenName(suppliedName, true)
if err != nil { if err != nil {
return "", err return "", err
} }

View file

@ -19,6 +19,7 @@ import (
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
@ -313,51 +314,6 @@ func (s *Suite) TestExpireNode(c *check.C) {
c.Assert(nodeFromDB.IsExpired(), check.Equals, true) c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
} }
func (s *Suite) TestGenerateGivenName(c *check.C) {
user1, err := db.CreateUser("user-1")
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode("user-1", "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
machineKey2 := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "hostname-1",
GivenName: "hostname-1",
UserID: user1.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2")
comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-2", comment)
givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1")
comment = check.Commentf("Same user, same node, same hostname, no conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Equals, "hostname-1", comment)
givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1")
comment = check.Commentf("Same user, unique nodes, same hostname, conflict")
c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment)
}
func (s *Suite) TestSetTags(c *check.C) { func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@ -778,3 +734,100 @@ func TestListEphemeralNodes(t *testing.T) {
assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID)
assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname)
} }
func TestRenameNode(t *testing.T) {
db, err := newTestDB()
if err != nil {
t.Fatalf("creating db: %s", err)
}
user, err := db.CreateUser("test")
assert.NoError(t, err)
user2, err := db.CreateUser("test2")
assert.NoError(t, err)
node := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
node2 := types.Node{
ID: 0,
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
err = db.DB.Save(&node).Error
assert.NoError(t, err)
err = db.DB.Save(&node2).Error
assert.NoError(t, err)
err = db.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, node, nil, nil)
if err != nil {
return err
}
_, err = RegisterNode(tx, node2, nil, nil)
return err
})
assert.NoError(t, err)
nodes, err := db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName)
t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName)
assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName)
assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname)
assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName)
assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname)
assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname)
assert.Len(t, nodes[0].Hostname, 4)
assert.Len(t, nodes[1].Hostname, 4)
assert.Len(t, nodes[0].GivenName, 4)
assert.Len(t, nodes[1].GivenName, 13)
// Nodes can be renamed to a unique name
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "newname")
})
assert.NoError(t, err)
nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")
// Nodes can reuse name that is no longer used
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[1].ID, "test")
})
assert.NoError(t, err)
nodes, err = db.ListNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, nodes[0].Hostname, "test")
assert.Equal(t, nodes[0].GivenName, "newname")
assert.Equal(t, nodes[1].GivenName, "test")
// Nodes cannot be renamed to used names
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
}

View file

@ -373,7 +373,7 @@ func (api headscaleV1APIServer) RenameNode(
node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) { node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
err := db.RenameNode( err := db.RenameNode(
tx, tx,
request.GetNodeId(), types.NodeID(request.GetNodeId()),
request.GetNewName(), request.GetNewName(),
) )
if err != nil { if err != nil {
@ -802,18 +802,12 @@ func (api headscaleV1APIServer) DebugCreateNode(
return nil, err return nil, err
} }
givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName())
if err != nil {
return nil, err
}
nodeKey := key.NewNode() nodeKey := key.NewNode()
newNode := types.Node{ newNode := types.Node{
MachineKey: mkey, MachineKey: mkey,
NodeKey: nodeKey.Public(), NodeKey: nodeKey.Public(),
Hostname: request.GetName(), Hostname: request.GetName(),
GivenName: givenName,
User: *user, User: *user,
Expiry: &time.Time{}, Expiry: &time.Time{},