From 50165ce9e1746303a14790fe0b6e9222ca455942 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 21 Oct 2024 11:58:59 -0600 Subject: [PATCH 1/5] resolve user identifier to stable ID Currently, the policy approached node-to-user matching with a quite naive approach, looking at the username provided in the policy and matching it with the username on the nodes. This worked ok as long as usernames were unique and did not change. As usernames are no longer guaranteed to be unique in an OIDC environment, we can't rely on this. This changes the mechanism that matches the user string (now user token) with nodes: - first find all potential users by looking up: - database ID - provider ID (OIDC) - username/email If more than one user is matching, then the query is rejected, and zero matching nodes are returned. When a single user is found, the node is matched against the User database ID, which is also present on the actual node. This means that from this commit, users can use the following to identify users in the policy: - provider identity (iss + sub) - username - email - database id There are more changes coming to this, so it is not recommended to start using any of these new abilities, with the exception of email, which will not change since it includes an @. 
Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 8 +- hscontrol/db/node_test.go | 4 +- hscontrol/db/routes.go | 7 +- hscontrol/grpcv1.go | 8 +- hscontrol/mapper/mapper.go | 19 +- hscontrol/mapper/mapper_test.go | 16 +- hscontrol/policy/acls.go | 92 +++++-- hscontrol/policy/acls_test.go | 425 ++++++++++++++++++++++++-------- 8 files changed, 445 insertions(+), 134 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 5c85b064..7fb68bc9 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -1026,14 +1026,18 @@ func (h *Headscale) loadACLPolicy() error { if err != nil { return fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..5bcbd546 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, 
check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..c89a10f8 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("looking up users to expand route alias: %w", err) + } + // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..e3291d8f 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := api.h.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return nil, fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 3db1e159..5205a112 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers 
types.Nodes, + users []types.User, pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { @@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse( pol, node, capVer, + users, peers, peers, m.cfg, @@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, err + } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) if err != nil { return nil, err } @@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse( return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("listing users for map response: %w", err) + } + var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse( pol, node, mapRequest.Version, + users, peers, changedNodes, m.cfg, @@ -508,16 +520,17 @@ func appendPeerChanges( pol *policy.ACLPolicy, node *types.Node, capVer tailcfg.CapabilityVersion, + users []types.User, peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) + packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) if err != nil { return err } - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) if err != nil { return err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 37ed5c42..8dd51808 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := 
types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, + []types.User{user1, user2}, tt.pol, 0, ) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..8e2d1961 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { 
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, 
fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) 
expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // shortcurcuit if we have no nodes to get ips from. @@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. 
+// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that mactches the user database ID +// - It is the Provider Identifier from OIDC +// - It matches the username or email of a user +// +// If the token matches more than one user, zero nodes will returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if user.ProviderIdentifier == userToken { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..f13d7f42 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2,8 +2,10 @@ package policy import ( "errors" + "math/rand/v2" "net/netip" "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -14,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +378,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: 
iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) { } func Test_listNodesInUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + } + 
type args struct { nodes types.Nodes user string @@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, 
+ }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-email-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match email, then provider id + &types.Node{ID: 3, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-username-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match username, then provider id + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-duplicate-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-unique-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "marc", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + }, + }, + { + name: "all-users-no-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + 
&types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[1]}, + }, + }, + { + name: "all-users-no-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working@headscale.net", + }, + want: nil, + }, + { + name: "all-users-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 3, User: users[2]}, + }, + }, + { + name: "all-users-no-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/4321", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) + for range 1000 { + ns := test.args.nodes + rand.Shuffle(len(ns), func(i, j int) { + ns[i], ns[j] = ns[j], ns[i] + }) + got := filterNodesByUser(ns, users, test.args.user) + sort.Slice(got, func(i, j int) bool { + return got[i].ID < got[j].ID + }) - if diff := cmp.Diff(test.want, 
got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { + t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) + } } }) } @@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: 
"marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := 
tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) { } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,12 +2247,12 @@ 
func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -2301,12 +2522,12 @@ func 
TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, 
types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := 
&types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ From 6afb554e20cb06ed535ce9abd29a7abbdb8a72e4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 11:42:14 -0500 Subject: [PATCH 2/5] wrap policy in policy manager interface Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 33 ++++--- hscontrol/db/node_test.go | 15 ++- hscontrol/db/routes.go | 20 +--- hscontrol/grpcv1.go | 27 +++--- hscontrol/mapper/mapper.go | 44 ++++----- hscontrol/mapper/mapper_test.go | 4 +- hscontrol/mapper/tail.go | 8 +- hscontrol/mapper/tail_test.go | 5 +- hscontrol/policy/pm.go | 164 ++++++++++++++++++++++++++++++++ hscontrol/poll.go | 16 ++-- 10 files changed, 246 insertions(+), 90 deletions(-) create mode 100644 hscontrol/policy/pm.go diff --git a/hscontrol/app.go b/hscontrol/app.go index 7fb68bc9..3489d18f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,7 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -499,7 +499,7 @@ func (h *Headscale) Serve() error { // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need 
a STUN server @@ -774,7 +774,7 @@ func (h *Headscale) Serve() error { log.Error().Err(err).Msg("failed to reload ACL policy") } - if h.ACLPolicy != nil { + if h.polMan != nil { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -995,8 +995,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { func (h *Headscale) loadACLPolicy() error { var ( - pol *policy.ACLPolicy - err error + pm policy.PolicyManager ) switch h.cfg.Policy.Mode { @@ -1009,10 +1008,6 @@ func (h *Headscale) loadACLPolicy() error { } absPath := util.AbsolutePathFromConfigPath(path) - pol, err = policy.LoadACLPolicyFromPath(absPath) - if err != nil { - return fmt.Errorf("failed to load ACL policy from file: %w", err) - } // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still @@ -1031,13 +1026,13 @@ func (h *Headscale) loadACLPolicy() error { return fmt.Errorf("loading users from database to validate policy: %w", err) } - _, err = pol.CompileFilterRules(users, nodes) + pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) if err != nil { - return fmt.Errorf("verifying policy rules: %w", err) + return fmt.Errorf("loading policy from file: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) + _, err = pm.SSHPolicy(nodes[0]) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } @@ -1053,9 +1048,17 @@ func (h *Headscale) loadACLPolicy() error { return fmt.Errorf("failed to get policy from database: %w", err) } - pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data)) + nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("failed to parse policy: %w", err) + return fmt.Errorf("loading nodes from database to validate policy: %w", err) + } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } + pm, err = 
policy.NewPolicyManager([]byte(p.Data), users, nodes) + if err != nil { + return fmt.Errorf("loading policy from database: %w", err) } default: log.Fatal(). @@ -1063,7 +1066,7 @@ func (h *Headscale) loadACLPolicy() error { Msg("Unknown ACL policy mode") } - h.ACLPolicy = pol + h.polMan = pm return nil } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 5bcbd546..46dce68b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { adb, err := newTestDB() assert.NoError(t, err) - pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) - - assert.NoError(t, err) - assert.NotNil(t, pol) user, err := adb.CreateUser("test") assert.NoError(t, err) @@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) { node0ByID, err := adb.GetNodeByID(0) assert.NoError(t, err) + users, err := adb.ListUsers() + assert.NoError(t, err) + + nodes, err := adb.ListNodes() + assert.NoError(t, err) + + pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) + assert.NoError(t, err) + // TODO(kradalby): Check state update - err = adb.EnableAutoApprovedRoutes(pol, node0ByID) + err = adb.EnableAutoApprovedRoutes(pm, node0ByID) assert.NoError(t, err) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index c89a10f8..ebb08563 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -598,18 +598,18 @@ func failoverRoute( } func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { return hsdb.Write(func(tx *gorm.DB) error { - return EnableAutoApprovedRoutes(tx, aclPolicy, node) + return EnableAutoApprovedRoutes(tx, polMan, node) }) } // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. 
func EnableAutoApprovedRoutes( tx *gorm.DB, - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { if node.IPv4 == nil && node.IPv6 == nil { @@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes( continue } - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) - } + routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix)) log.Trace(). Str("node", node.Hostname). @@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - users, err := ListUsers(tx) - if err != nil { - return fmt.Errorf("looking up users to expand route alias: %w", err) - } - // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) + approvedIps, err := polMan.IPsForUser(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index e3291d8f..f907ec0d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -21,7 +21,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -450,10 +449,7 @@ func (api headscaleV1APIServer) ListNodes( resp.Online = true } - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - node, - ) - resp.InvalidTags = invalidTags + validTags := api.h.polMan.Tags(node) resp.ValidTags = validTags response[index] = resp } @@ -723,11 +719,6 @@ func (api headscaleV1APIServer) SetPolicy( p := 
request.GetPolicy() - pol, err := policy.LoadACLPolicyFromBytes([]byte(p)) - if err != nil { - return nil, fmt.Errorf("loading ACL policy file: %w", err) - } - // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -742,13 +733,21 @@ func (api headscaleV1APIServer) SetPolicy( return nil, fmt.Errorf("loading users from database to validate policy: %w", err) } - _, err = pol.CompileFilterRules(users, nodes) + err = api.h.polMan.SetNodes(nodes) if err != nil { - return nil, fmt.Errorf("verifying policy rules: %w", err) + return nil, fmt.Errorf("setting nodes: %w", err) + } + err = api.h.polMan.SetUsers(users) + if err != nil { + return nil, fmt.Errorf("setting users: %w", err) + } + err = api.h.polMan.SetPolicy([]byte(p)) + if err != nil { + return nil, fmt.Errorf("setting policy: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) + _, err = api.h.polMan.SSHPolicy(nodes[0]) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -759,8 +758,6 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - api.h.ACLPolicy = pol - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5205a112..6899dc25 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -55,6 +55,7 @@ type Mapper struct { cfg *types.Config derpMap *tailcfg.DERPMap notif *notifier.Notifier + polMan policy.PolicyManager uid string created time.Time @@ -71,6 +72,7 @@ func NewMapper( cfg *types.Config, derpMap *tailcfg.DERPMap, notif *notifier.Notifier, + polMan policy.PolicyManager, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) @@ -79,6 +81,7 @@ func NewMapper( cfg: cfg, derpMap: 
derpMap, notif: notif, + polMan: polMan, uid: uid, created: time.Now(), @@ -154,10 +157,9 @@ func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, users []types.User, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + resp, err := m.baseWithConfigMapResponse(node, capVer) if err != nil { return nil, err } @@ -165,7 +167,7 @@ func (m *Mapper) fullMapResponse( err = appendPeerChanges( resp, true, // full change - pol, + m.polMan, node, capVer, users, @@ -184,7 +186,6 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { peers, err := m.ListPeers(node.ID) @@ -196,7 +197,7 @@ func (m *Mapper) FullMapResponse( return nil, err } - resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version) if err != nil { return nil, err } @@ -210,10 +211,9 @@ func (m *Mapper) FullMapResponse( func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) if err != nil { return nil, err } @@ -249,7 +249,6 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() @@ -284,7 +283,7 @@ func (m *Mapper) PeerChangedResponse( err = appendPeerChanges( &resp, false, // partial change - pol, + m.polMan, node, mapRequest.Version, users, @@ -315,7 +314,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no 
patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) if err != nil { return nil, err } @@ -330,7 +329,6 @@ func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, node *types.Node, changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, ) ([]byte, error) { resp := m.baseMapResponse() resp.PeersChangedPatch = changed @@ -459,12 +457,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // incremental. func (m *Mapper) baseWithConfigMapResponse( node *types.Node, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) if err != nil { return nil, err } @@ -517,7 +514,7 @@ func appendPeerChanges( resp *tailcfg.MapResponse, fullChange bool, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, users []types.User, @@ -525,27 +522,24 @@ func appendPeerChanges( changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) - if err != nil { - return err - } + filter := polMan.Filter() - sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) + sshPolicy, err := polMan.SSHPolicy(node) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. 
- if len(packetFilter) > 0 { - changed = policy.FilterNodesByACL(node, changed, packetFilter) + if len(filter) > 0 { + changed = policy.FilterNodesByACL(node, changed, filter) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, pol, cfg) + tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { return err } @@ -570,7 +564,7 @@ func appendPeerChanges( // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, packetFilter), + "base": policy.ReduceFilterRules(node, filter), } } else { // This is a hack to avoid sending an empty list of packet filters. @@ -578,11 +572,11 @@ func appendPeerChanges( // be omitted, causing the client to consider it unchanged, keeping the // previous packet filter. Worst case, this can cause a node that previously // has access to a node to _not_ loose access if an empty (allow none) is sent. 
- reduced := policy.ReduceFilterRules(node, packetFilter) + reduced := policy.ReduceFilterRules(node, filter) if len(reduced) > 0 { resp.PacketFilter = reduced } else { - resp.PacketFilter = packetFilter + resp.PacketFilter = filter } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 8dd51808..a1f3eb38 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -461,18 +461,20 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node)) + mappy := NewMapper( nil, tt.cfg, tt.derpMap, nil, + polMan, ) got, err := mappy.fullMapResponse( tt.node, tt.peers, []types.User{user1, user2}, - tt.pol, 0, ) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 24c521dc..4082df2b 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -14,7 +14,7 @@ import ( func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -23,7 +23,7 @@ func tailNodes( node, err := tailNode( node, capVer, - pol, + polMan, cfg, ) if err != nil { @@ -40,7 +40,7 @@ func tailNodes( func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +81,7 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - tags, _ := pol.TagsOfNode(node) + tags := polMan.Tags(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index b6692c16..9d7f1fed 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -184,6 +184,7 @@ 
func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, DNSConfig: tt.dnsConfig, @@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) { got, err := tailNode( tt.node, 0, - tt.pol, + polMan, cfg, ) @@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) { tn, err := tailNode( node, 0, - &policy.ACLPolicy{}, + &policy.PolicyManagerV1{}, &types.Config{}, ) if err != nil { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new file mode 100644 index 00000000..a94dd746 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,164 @@ +package policy + +import ( + "fmt" + "io" + "net/netip" + "os" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +type PolicyManager interface { + Filter() []tailcfg.FilterRule + SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + Tags(*types.Node) []string + ApproversForRoute(netip.Prefix) []string + IPsForUser(string) (*netipx.IPSet, error) + SetPolicy([]byte) error + SetUsers(users []types.User) error + SetNodes(nodes types.Nodes) error +} + +func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { + policyFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer policyFile.Close() + + policyBytes, err := io.ReadAll(policyFile) + if err != nil { + return nil, err + } + + return NewPolicyManager(policyBytes, users, nodes) +} + +func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +func 
NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + err := pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +type PolicyManagerV1 struct { + mu sync.Mutex + pol *ACLPolicy + users []types.User + nodes types.Nodes + filter []tailcfg.FilterRule +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManagerV1) updateLocked() error { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return fmt.Errorf("compiling filter rules: %w", err) + } + + pm.filter = filter + + return nil +} + +func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) +} + +func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetUsers(users []types.User) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. 
+func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManagerV1) Tags(node *types.Node) []string { + if pm == nil { + return nil + } + + tags, _ := pm.pol.TagsOfNode(node) + return tags +} + +func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { + // TODO(kradalby): This can be a parse error of the address in the policy, + // in the new policy this will be typed and not a problem, in this policy + // we will just return empty list + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + return approvers +} + +func (pm *PolicyManagerV1) IPsForUser(user string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, user) + if err != nil { + return nil, err + } + return ips, nil +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a8ae01f4..d41744cd 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() { switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - 
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "remove" case types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") @@ -488,9 +488,9 @@ func (m *mapSession) handleEndpointUpdate() { return } - if m.h.ACLPolicy != nil { + if m.h.polMan != nil { // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) if err != nil { m.errf(err, "Error running auto approved routes") mapResponseEndpointUpdates.WithLabelValues("error").Inc() @@ -544,7 +544,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) if err != nil { m.errf(err, "Failed to 
create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) From 8ecba121cc3b25a2c45718652cd67914ac41755d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 11:43:18 -0500 Subject: [PATCH 3/5] remove unused args Signed-off-by: Kristoffer Dalby --- hscontrol/mapper/mapper.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 6899dc25..5ad66782 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -170,8 +170,6 @@ func (m *Mapper) fullMapResponse( m.polMan, node, capVer, - users, - peers, peers, m.cfg, ) @@ -258,11 +256,6 @@ func (m *Mapper) PeerChangedResponse( return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("listing users for map response: %w", err) - } - var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -286,8 +279,6 @@ func (m *Mapper) PeerChangedResponse( m.polMan, node, mapRequest.Version, - users, - peers, changedNodes, m.cfg, ) @@ -517,8 +508,6 @@ func appendPeerChanges( polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, - users []types.User, - peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { From 19bc8b6e01cf4bd690fd29752e6c0bbf5bdd9060 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:27:12 -0500 Subject: [PATCH 4/5] report if filter has changed Signed-off-by: Kristoffer Dalby --- hscontrol/grpcv1.go | 26 ++---- hscontrol/mapper/mapper.go | 7 +- hscontrol/mapper/mapper_test.go | 1 - hscontrol/policy/pm.go | 44 +++++---- hscontrol/policy/pm_test.go | 158 ++++++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 42 deletions(-) create mode 100644 hscontrol/policy/pm_test.go diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index f907ec0d..a221d519 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -728,20 +728,7 @@ func (api 
headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - users, err := api.h.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("loading users from database to validate policy: %w", err) - } - - err = api.h.polMan.SetNodes(nodes) - if err != nil { - return nil, fmt.Errorf("setting nodes: %w", err) - } - err = api.h.polMan.SetUsers(users) - if err != nil { - return nil, fmt.Errorf("setting users: %w", err) - } - err = api.h.polMan.SetPolicy([]byte(p)) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } @@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5ad66782..51c96f8c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - users []types.User, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp, err := m.baseWithConfigMapResponse(node, capVer) @@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, err - } - resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, 
err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a1f3eb38..4ee8c644 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, - []types.User{user1, user2}, 0, ) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index a94dd746..8ca9f1db 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/tailcfg" + "tailscale.com/util/deephash" ) type PolicyManager interface { @@ -18,9 +19,9 @@ type PolicyManager interface { Tags(*types.Node) []string ApproversForRoute(netip.Prefix) []string IPsForUser(string) (*netipx.IPSet, error) - SetPolicy([]byte) error - SetUsers(users []types.User) error - SetNodes(nodes types.Nodes) error + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) } func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { @@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic nodes: nodes, } - err = pm.updateLocked() + _, err = pm.updateLocked() if err != nil { return nil, err } @@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod nodes: nodes, } - err := pm.updateLocked() + _, err := pm.updateLocked() if err != nil { return nil, err } @@ -74,24 +75,33 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod } type PolicyManagerV1 struct { - mu sync.Mutex - pol *ACLPolicy - users []types.User - nodes types.Nodes - filter []tailcfg.FilterRule + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule } // updateLocked updates the filter rules based on 
the current policy and nodes. // It must be called with the lock held. -func (pm *PolicyManagerV1) updateLocked() error { +func (pm *PolicyManagerV1) updateLocked() (bool, error) { filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) if err != nil { - return fmt.Errorf("compiling filter rules: %w", err) + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil } pm.filter = filter + pm.filterHash = filterHash - return nil + return true, nil } func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { @@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) } -func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { pol, err := LoadACLPolicyFromBytes(polB) if err != nil { - return fmt.Errorf("parsing policy: %w", err) + return false, fmt.Errorf("parsing policy: %w", err) } pm.mu.Lock() @@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { } // SetUsers updates the users in the policy manager and updates the filter rules. -func (pm *PolicyManagerV1) SetUsers(users []types.User) error { +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error { } // SetNodes updates the nodes in the policy manager and updates the filter rules. 
-func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error { +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 00000000..24b78e4d --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + 
wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +} From 8d5b04f3d3ec6694c3f6a21c70ff966f174dc1f3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:53:04 -0500 Subject: [PATCH 5/5] hook up user and node changes to policy Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ hscontrol/auth.go | 7 +++++++ hscontrol/grpcv1.go | 15 +++++++++++++++ hscontrol/oidc.go | 14 ++++++++++++++ 4 files changed, 80 insertions(+) diff --git a/hscontrol/app.go b/hscontrol/app.go index 3489d18f..b4d36caa 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -165,6 +165,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -472,6 +473,48 @@ 
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. 
func (h *Headscale) Serve() error { if profilingEnabled { @@ -770,6 +813,7 @@ func (h *Headscale) Serve() error { Msg("Received SIGHUP, reloading ACL and Config") // TODO(kradalby): Reload config on SIGHUP + // TODO(kradalby): Only update if we set a new policy if err := h.loadACLPolicy(); err != nil { log.Error().Err(err).Msg("failed to reload ACL policy") } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 67545031..2b23aad3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a221d519..51134e7e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -57,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -86,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -220,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84267b41..5028e244 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" 
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -53,6 +54,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -65,6 +67,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -96,6 +99,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode( return fmt.Errorf("could not register node: %w", err) } + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return fmt.Errorf("updating resources using node: %w", err) + } + return nil }