diff --git a/hscontrol/app.go b/hscontrol/app.go index 5c85b064..7fb68bc9 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -1026,14 +1026,18 @@ func (h *Headscale) loadACLPolicy() error { if err != nil { return fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..5bcbd546 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..c89a10f8 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("looking up users to expand route alias: %w", err) + } + // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..e3291d8f 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := api.h.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return nil, fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 20aa674d..2d6abd7e 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -155,6 +155,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, + users []types.User, pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { @@ -169,6 +170,7 @@ func (m *Mapper) fullMapResponse( pol, node, capVer, + users, peers, peers, m.cfg, @@ -191,8 +193,12 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, err + } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) if err != nil { return nil, err } @@ -255,6 +261,11 @@ func (m *Mapper) PeerChangedResponse( return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("listing users for map response: %w", err) + } + var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -278,6 +289,7 @@ func (m *Mapper) PeerChangedResponse( pol, node, mapRequest.Version, + users, peers, changedNodes, m.cfg, @@ -510,16 +522,17 @@ func appendPeerChanges( pol *policy.ACLPolicy, node *types.Node, capVer tailcfg.CapabilityVersion, + users []types.User, peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) + packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) if err != nil { return err } - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) if err != nil { return err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 32ea5352..566a8fc4 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -171,6 +171,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -185,8 +188,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -265,8 +268,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -320,8 +323,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -480,6 +483,7 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, + []types.User{user1, user2}, tt.pol, 0, ) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..b1690af3 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,7 +178,7 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) } @@ -197,6 +199,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +284,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +316,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -363,6 +367,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +517,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +544,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +559,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +572,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +581,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -751,16 +758,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +779,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +812,7 @@ func (pol *ACLPolicy) expandIPsFromTag( // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +829,12 @@ func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // shortcurcuit if we have no nodes to get ips from. @@ -953,10 +963,43 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. +// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that mactches the user database ID +// - It is the Provider Identifier from OIDC +// - It matches the username or email of a user +// +// If the token matches more than one user, zero nodes will returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if id, err := strconv.ParseUint(userToken, 10, 64); err == nil && user.ID == uint(id) { + potentialUsers = append(potentialUsers, user) + break + } + if user.ProviderIdentifier == userToken { + potentialUsers = append(potentialUsers, user) + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..cd17611f 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +376,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +540,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +556,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +575,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +591,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -861,6 +868,13 @@ func Test_expandPorts(t *testing.T) { } func Test_listNodesInUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + } + type args struct { nodes types.Nodes user string @@ -874,47 +888,101 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-userid", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, + }, + user: "2", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) + got := filterNodesByUser(test.args.nodes, users, test.args.user) if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) @@ -940,6 +1008,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1063,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1096,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1259,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1268,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1277,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1334,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1369,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1383,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1424,12 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1442,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1790,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1834,6 +1910,13 @@ func TestTheInternet(t *testing.T) { } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +1938,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +1971,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +1982,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2050,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,12 +2109,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2113,7 +2196,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,12 +2205,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2215,7 +2298,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -2224,12 +2307,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2292,7 +2375,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -2301,12 +2384,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2362,7 +2445,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2455,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2483,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3391,7 +3475,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3474,14 +3558,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,7 +3585,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3532,7 +3619,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,7 +3637,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3583,7 +3671,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,7 +3697,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3637,15 +3726,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +3747,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,7 +3767,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{