From 3dc452dee4862873cc0216ec50ea7885d5d73e67 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 21 Oct 2024 11:58:59 -0600 Subject: [PATCH 1/3] resolve user identifier to stable ID currently, the policy approach node to user matching with a quite naive approach looking at the username provided in the policy and matched it with the username on the nodes. This worked ok as long as usernames were unique and did not change. As usernames are no longer guarenteed to be unique in an OIDC environment we cant rely on this. This changes the mechanism that matches the user string (now user token) with nodes: - first find all potential users by looking up: - database ID - provider ID (OIDC) - username/email If more than one user is matching, then the query is rejected, and zero matching nodes are returned. When a single user is found, the node is matched against the User database ID, which are also present on the actual node. This means that from this commit, users can use the following to identify users in the policy: - provider identity (iss + sub) - username - email - database id There are more changes coming to this, so it is not recommended to start using any of these new abilities, with the exception of email, which will not change since it includes an @. Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 8 +- hscontrol/db/node_test.go | 4 +- hscontrol/db/routes.go | 7 +- hscontrol/grpcv1.go | 8 +- hscontrol/mapper/mapper.go | 19 +- hscontrol/mapper/mapper_test.go | 16 +- hscontrol/policy/acls.go | 92 +++++-- hscontrol/policy/acls_test.go | 425 ++++++++++++++++++++++++-------- 8 files changed, 445 insertions(+), 134 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 737e8098..3f651877 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -1027,14 +1027,18 @@ func (h *Headscale) loadACLPolicy() error { if err != nil { return fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..5bcbd546 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..c89a10f8 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("looking up users to expand route alias: %w", err) + } + // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..e3291d8f 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := api.h.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return nil, fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 3db1e159..5205a112 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, + users []types.User, pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { @@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse( pol, node, capVer, + users, peers, peers, m.cfg, @@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, err + } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) if err != nil { return nil, err } @@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse( return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("listing users for map response: %w", err) + } + var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse( pol, node, mapRequest.Version, + users, peers, changedNodes, m.cfg, @@ -508,16 +520,17 @@ func appendPeerChanges( pol *policy.ACLPolicy, node *types.Node, capVer tailcfg.CapabilityVersion, + users []types.User, peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) + packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) if err != nil { return err } - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) if err != nil { return err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 37ed5c42..8dd51808 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, + []types.User{user1, user2}, tt.pol, 0, ) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..8e2d1961 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // shortcurcuit if we have no nodes to get ips from. @@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. +// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that mactches the user database ID +// - It is the Provider Identifier from OIDC +// - It matches the username or email of a user +// +// If the token matches more than one user, zero nodes will returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if user.ProviderIdentifier == userToken { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..f13d7f42 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2,8 +2,10 @@ package policy import ( "errors" + "math/rand/v2" "net/netip" "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -14,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +378,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) { } func Test_listNodesInUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + } + type args struct { nodes types.Nodes user string @@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-email-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match email, then provider id + &types.Node{ID: 3, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-username-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match username, then provider id + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-duplicate-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-unique-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "marc", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + }, + }, + { + name: "all-users-no-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[1]}, + }, + }, + { + name: "all-users-no-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working@headscale.net", + }, + want: nil, + }, + { + name: "all-users-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 3, User: users[2]}, + }, + }, + { + name: "all-users-no-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/4321", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) + for range 1000 { + ns := test.args.nodes + rand.Shuffle(len(ns), func(i, j int) { + ns[i], ns[j] = ns[j], ns[i] + }) + got := filterNodesByUser(ns, users, test.args.user) + sort.Slice(got, func(i, j int) bool { + return got[i].ID < got[j].ID + }) - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { + t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) + } } }) } @@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) { } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ From 31d398c793809bef6f02d1616b70cd47e7fb9899 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 17 Nov 2024 18:19:23 -0700 Subject: [PATCH 2/3] ensure provider id is found out of order Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls.go | 5 ++++- hscontrol/policy/acls_test.go | 28 ++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 8e2d1961..2dcbb88a 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -991,7 +991,10 @@ func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) var potentialUsers []types.User for _, user := range users { if user.ProviderIdentifier == userToken { - potentialUsers = append(potentialUsers, user) + // If a user is matching with a known unique field, + // disgard all other users and only keep the current + // user. + potentialUsers = []types.User{user} break } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index f13d7f42..1e1a5860 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -869,13 +869,18 @@ func Test_expandPorts(t *testing.T) { } } -func Test_listNodesInUser(t *testing.T) { +func Test_filterNodesByUser(t *testing.T) { users := []types.User{ {Model: gorm.Model{ID: 1}, Name: "marc"}, {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + {Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 7}, Name: "1"}, + {Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"}, + {Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"}, + {Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"}, } type args struct { @@ -947,6 +952,7 @@ func Test_listNodesInUser(t *testing.T) { nodes: types.Nodes{ &types.Node{ID: 1, User: users[1]}, &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 8, User: users[7]}, }, user: "joe@headscale.net", }, @@ -1057,6 +1063,7 @@ func Test_listNodesInUser(t *testing.T) { &types.Node{ID: 3, User: users[2]}, &types.Node{ID: 4, User: users[3]}, &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 8, User: users[7]}, }, user: "joe@headscale.net", }, @@ -1064,6 +1071,17 @@ func Test_listNodesInUser(t *testing.T) { &types.Node{ID: 2, User: users[1]}, }, }, + { + name: "email-as-username-duplicate", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[7]}, + &types.Node{ID: 2, User: users[8]}, + }, + user: "alex@headscale.net", + }, + want: nil, + }, { name: "all-users-no-email-random-order", args: args{ @@ -1087,6 +1105,7 @@ func Test_listNodesInUser(t *testing.T) { &types.Node{ID: 3, User: users[2]}, &types.Node{ID: 4, User: users[3]}, &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 6, User: users[5]}, }, user: "http://oidc.org/1234", }, @@ -1103,6 +1122,7 @@ func Test_listNodesInUser(t *testing.T) { &types.Node{ID: 3, User: users[2]}, &types.Node{ID: 4, User: users[3]}, &types.Node{ID: 5, User: users[4]}, + &types.Node{ID: 6, User: users[5]}, }, user: "http://oidc.org/4321", }, @@ -1116,7 +1136,11 @@ func Test_listNodesInUser(t *testing.T) { rand.Shuffle(len(ns), func(i, j int) { ns[i], ns[j] = ns[j], ns[i] }) - got := filterNodesByUser(ns, users, test.args.user) + us := users + rand.Shuffle(len(us), func(i, j int) { + us[i], us[j] = us[j], us[i] + }) + got := filterNodesByUser(ns, us, test.args.user) sort.Slice(got, func(i, j int) bool { return got[i].ID < got[j].ID }) From f8ec54d8162847b1450a738cc28edfaef3b6e710 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 18 Nov 2024 15:58:42 +0100 Subject: [PATCH 3/3] only set username and email if valid Signed-off-by: Kristoffer Dalby --- hscontrol/policy/acls.go | 2 +- hscontrol/policy/acls_test.go | 3 ++- hscontrol/types/users.go | 24 ++++++++++++++++++------ 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 2dcbb88a..5d41d000 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -990,7 +990,7 @@ func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) var potentialUsers []types.User for _, user := range users { - if user.ProviderIdentifier == userToken { + if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken { // If a user is matching with a known unique field, // disgard all other users and only keep the current // user. diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1e1a5860..1b9a3fd0 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -1,6 +1,7 @@ package policy import ( + "database/sql" "errors" "math/rand/v2" "net/netip" @@ -873,7 +874,7 @@ func Test_filterNodesByUser(t *testing.T) { users := []types.User{ {Model: gorm.Model{ID: 1}, Name: "marc"}, {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, - {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}}, {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, {Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"}, diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index f983d7f5..72cc9e1e 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,6 +2,8 @@ package types import ( "cmp" + "database/sql" + "net/mail" "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -34,7 +36,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"index"` + ProviderIdentifier sql.NullString `gorm:"index"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -51,7 +53,7 @@ type User struct { // should be used throughout headscale, in information returned to the // user and the Policy engine. func (u *User) Username() string { - return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) } // DisplayNameOrUsername returns the DisplayName if it exists, otherwise @@ -107,7 +109,7 @@ func (u *User) Proto() *v1.User { CreatedAt: timestamppb.New(u.CreatedAt), DisplayName: u.DisplayName, Email: u.Email, - ProviderId: u.ProviderIdentifier, + ProviderId: u.ProviderIdentifier.String, Provider: u.Provider, ProfilePicUrl: u.ProfilePicURL, } @@ -129,10 +131,20 @@ type OIDCClaims struct { // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Sub + err := util.CheckForFQDNRules(claims.Username) + if err == nil { + u.Name = claims.Username + } + + if claims.EmailVerified { + _, err = mail.ParseAddress(claims.Email) + if err == nil { + u.Email = claims.Email + } + } + + u.ProviderIdentifier.String = claims.Sub u.DisplayName = claims.Name - u.Email = claims.Email - u.Name = claims.Username u.ProfilePicURL = claims.ProfilePictureURL u.Provider = util.RegisterMethodOIDC }