From 50165ce9e1746303a14790fe0b6e9222ca455942 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 21 Oct 2024 11:58:59 -0600 Subject: [PATCH 01/13] resolve user identifier to stable ID currently, the policy approach node to user matching with a quite naive approach looking at the username provided in the policy and matched it with the username on the nodes. This worked ok as long as usernames were unique and did not change. As usernames are no longer guarenteed to be unique in an OIDC environment we cant rely on this. This changes the mechanism that matches the user string (now user token) with nodes: - first find all potential users by looking up: - database ID - provider ID (OIDC) - username/email If more than one user is matching, then the query is rejected, and zero matching nodes are returned. When a single user is found, the node is matched against the User database ID, which are also present on the actual node. This means that from this commit, users can use the following to identify users in the policy: - provider identity (iss + sub) - username - email - database id There are more changes coming to this, so it is not recommended to start using any of these new abilities, with the exception of email, which will not change since it includes an @. Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 8 +- hscontrol/db/node_test.go | 4 +- hscontrol/db/routes.go | 7 +- hscontrol/grpcv1.go | 8 +- hscontrol/mapper/mapper.go | 19 +- hscontrol/mapper/mapper_test.go | 16 +- hscontrol/policy/acls.go | 92 +++++-- hscontrol/policy/acls_test.go | 425 ++++++++++++++++++++++++-------- 8 files changed, 445 insertions(+), 134 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 5c85b064..7fb68bc9 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -1026,14 +1026,18 @@ func (h *Headscale) loadACLPolicy() error { if err != nil { return fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..5bcbd546 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..c89a10f8 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("looking up users to expand route alias: %w", err) + } + // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..e3291d8f 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } + users, err := api.h.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("loading users from database to validate policy: %w", err) + } - _, err = pol.CompileFilterRules(nodes) + _, err = pol.CompileFilterRules(users, nodes) if err != nil { return nil, fmt.Errorf("verifying policy rules: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 3db1e159..5205a112 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, + users []types.User, pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { @@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse( pol, node, capVer, + users, peers, peers, m.cfg, @@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, err + } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) if err != nil { return nil, err } @@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse( return nil, err } + users, err := m.db.ListUsers() + if err != nil { + return nil, fmt.Errorf("listing users for map response: %w", err) + } + var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse( pol, node, mapRequest.Version, + users, peers, changedNodes, m.cfg, @@ -508,16 +520,17 @@ func appendPeerChanges( pol *policy.ACLPolicy, node *types.Node, capVer tailcfg.CapabilityVersion, + users []types.User, peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) + packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) if err != nil { return err } - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) if err != nil { return err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 37ed5c42..8dd51808 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, + []types.User{user1, user2}, tt.pol, 0, ) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..8e2d1961 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // shortcurcuit if we have no nodes to get ips from. @@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. +// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that mactches the user database ID +// - It is the Provider Identifier from OIDC +// - It matches the username or email of a user +// +// If the token matches more than one user, zero nodes will returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if user.ProviderIdentifier == userToken { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..f13d7f42 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2,8 +2,10 @@ package policy import ( "errors" + "math/rand/v2" "net/netip" "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -14,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +378,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) { } func Test_listNodesInUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + } + type args struct { nodes types.Nodes user string @@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-email-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match email, then provider id + &types.Node{ID: 3, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-username-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match username, then provider id + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-duplicate-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-unique-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "marc", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + }, + }, + { + name: "all-users-no-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[1]}, + }, + }, + { + name: "all-users-no-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working@headscale.net", + }, + want: nil, + }, + { + name: "all-users-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 3, User: users[2]}, + }, + }, + { + name: "all-users-no-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/4321", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) + for range 1000 { + ns := test.args.nodes + rand.Shuffle(len(ns), func(i, j int) { + ns[i], ns[j] = ns[j], ns[i] + }) + got := filterNodesByUser(ns, users, test.args.user) + sort.Slice(got, func(i, j int) bool { + return got[i].ID < got[j].ID + }) - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { + t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) + } } }) } @@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) { } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ From 6afb554e20cb06ed535ce9abd29a7abbdb8a72e4 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 11:42:14 -0500 Subject: [PATCH 02/13] wrap policy in policy manager interface Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 33 ++++--- hscontrol/db/node_test.go | 15 ++- hscontrol/db/routes.go | 20 +--- hscontrol/grpcv1.go | 27 +++--- hscontrol/mapper/mapper.go | 44 ++++----- hscontrol/mapper/mapper_test.go | 4 +- hscontrol/mapper/tail.go | 8 +- hscontrol/mapper/tail_test.go | 5 +- hscontrol/policy/pm.go | 164 ++++++++++++++++++++++++++++++++ hscontrol/poll.go | 16 ++-- 10 files changed, 246 insertions(+), 90 deletions(-) create mode 100644 hscontrol/policy/pm.go diff --git a/hscontrol/app.go b/hscontrol/app.go index 7fb68bc9..3489d18f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,7 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -499,7 +499,7 @@ func (h *Headscale) Serve() error { // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -774,7 +774,7 @@ func (h *Headscale) Serve() error { log.Error().Err(err).Msg("failed to reload ACL policy") } - if h.ACLPolicy != nil { + if h.polMan != nil { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -995,8 +995,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { func (h *Headscale) loadACLPolicy() error { var ( - pol *policy.ACLPolicy - err error + pm policy.PolicyManager ) switch h.cfg.Policy.Mode { @@ -1009,10 +1008,6 @@ func (h *Headscale) loadACLPolicy() error { } absPath := util.AbsolutePathFromConfigPath(path) - pol, err = policy.LoadACLPolicyFromPath(absPath) - if err != nil { - return fmt.Errorf("failed to load ACL policy from file: %w", err) - } // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still @@ -1031,13 +1026,13 @@ func (h *Headscale) loadACLPolicy() error { return fmt.Errorf("loading users from database to validate policy: %w", err) } - _, err = pol.CompileFilterRules(users, nodes) + pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) if err != nil { - return fmt.Errorf("verifying policy rules: %w", err) + return fmt.Errorf("loading policy from file: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) + _, err = pm.SSHPolicy(nodes[0]) if err != nil { return fmt.Errorf("verifying SSH rules: %w", err) } @@ -1053,9 +1048,17 @@ func (h *Headscale) loadACLPolicy() error { return fmt.Errorf("failed to get policy from database: %w", err) } - pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data)) + nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("failed to parse policy: %w", err) + return fmt.Errorf("loading nodes from database to validate policy: %w", err) + } + users, err := h.db.ListUsers() + if err != nil { + return fmt.Errorf("loading users from database to validate policy: %w", err) + } + pm, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) + if err != nil { + return fmt.Errorf("loading policy from database: %w", err) } default: log.Fatal(). @@ -1063,7 +1066,7 @@ func (h *Headscale) loadACLPolicy() error { Msg("Unknown ACL policy mode") } - h.ACLPolicy = pol + h.polMan = pm return nil } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 5bcbd546..46dce68b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { adb, err := newTestDB() assert.NoError(t, err) - pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) - - assert.NoError(t, err) - assert.NotNil(t, pol) user, err := adb.CreateUser("test") assert.NoError(t, err) @@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) { node0ByID, err := adb.GetNodeByID(0) assert.NoError(t, err) + users, err := adb.ListUsers() + assert.NoError(t, err) + + nodes, err := adb.ListNodes() + assert.NoError(t, err) + + pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) + assert.NoError(t, err) + // TODO(kradalby): Check state update - err = adb.EnableAutoApprovedRoutes(pol, node0ByID) + err = adb.EnableAutoApprovedRoutes(pm, node0ByID) assert.NoError(t, err) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index c89a10f8..ebb08563 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -598,18 +598,18 @@ func failoverRoute( } func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { return hsdb.Write(func(tx *gorm.DB) error { - return EnableAutoApprovedRoutes(tx, aclPolicy, node) + return EnableAutoApprovedRoutes(tx, polMan, node) }) } // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func EnableAutoApprovedRoutes( tx *gorm.DB, - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { if node.IPv4 == nil && node.IPv6 == nil { @@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes( continue } - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) - } + routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix)) log.Trace(). Str("node", node.Hostname). @@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes( if approvedAlias == node.User.Username() { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - users, err := ListUsers(tx) - if err != nil { - return fmt.Errorf("looking up users to expand route alias: %w", err) - } - // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias) + approvedIps, err := polMan.IPsForUser(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index e3291d8f..f907ec0d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -21,7 +21,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -450,10 +449,7 @@ func (api headscaleV1APIServer) ListNodes( resp.Online = true } - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - node, - ) - resp.InvalidTags = invalidTags + validTags := api.h.polMan.Tags(node) resp.ValidTags = validTags response[index] = resp } @@ -723,11 +719,6 @@ func (api headscaleV1APIServer) SetPolicy( p := request.GetPolicy() - pol, err := policy.LoadACLPolicyFromBytes([]byte(p)) - if err != nil { - return nil, fmt.Errorf("loading ACL policy file: %w", err) - } - // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -742,13 +733,21 @@ func (api headscaleV1APIServer) SetPolicy( return nil, fmt.Errorf("loading users from database to validate policy: %w", err) } - _, err = pol.CompileFilterRules(users, nodes) + err = api.h.polMan.SetNodes(nodes) if err != nil { - return nil, fmt.Errorf("verifying policy rules: %w", err) + return nil, fmt.Errorf("setting nodes: %w", err) + } + err = api.h.polMan.SetUsers(users) + if err != nil { + return nil, fmt.Errorf("setting users: %w", err) + } + err = api.h.polMan.SetPolicy([]byte(p)) + if err != nil { + return nil, fmt.Errorf("setting policy: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], users, nodes) + _, err = api.h.polMan.SSHPolicy(nodes[0]) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -759,8 +758,6 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - api.h.ACLPolicy = pol - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ Type: types.StateFullUpdate, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5205a112..6899dc25 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -55,6 +55,7 @@ type Mapper struct { cfg *types.Config derpMap *tailcfg.DERPMap notif *notifier.Notifier + polMan policy.PolicyManager uid string created time.Time @@ -71,6 +72,7 @@ func NewMapper( cfg *types.Config, derpMap *tailcfg.DERPMap, notif *notifier.Notifier, + polMan policy.PolicyManager, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) @@ -79,6 +81,7 @@ func NewMapper( cfg: cfg, derpMap: derpMap, notif: notif, + polMan: polMan, uid: uid, created: time.Now(), @@ -154,10 +157,9 @@ func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, users []types.User, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + resp, err := m.baseWithConfigMapResponse(node, capVer) if err != nil { return nil, err } @@ -165,7 +167,7 @@ func (m *Mapper) fullMapResponse( err = appendPeerChanges( resp, true, // full change - pol, + m.polMan, node, capVer, users, @@ -184,7 +186,6 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { peers, err := m.ListPeers(node.ID) @@ -196,7 +197,7 @@ func (m *Mapper) FullMapResponse( return nil, err } - resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version) if err != nil { return nil, err } @@ -210,10 +211,9 @@ func (m *Mapper) FullMapResponse( func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) if err != nil { return nil, err } @@ -249,7 +249,6 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() @@ -284,7 +283,7 @@ func (m *Mapper) PeerChangedResponse( err = appendPeerChanges( &resp, false, // partial change - pol, + m.polMan, node, mapRequest.Version, users, @@ -315,7 +314,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) if err != nil { return nil, err } @@ -330,7 +329,6 @@ func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, node *types.Node, changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, ) ([]byte, error) { resp := m.baseMapResponse() resp.PeersChangedPatch = changed @@ -459,12 +457,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // incremental. func (m *Mapper) baseWithConfigMapResponse( node *types.Node, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) if err != nil { return nil, err } @@ -517,7 +514,7 @@ func appendPeerChanges( resp *tailcfg.MapResponse, fullChange bool, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, users []types.User, @@ -525,27 +522,24 @@ func appendPeerChanges( changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(users, append(peers, node)) - if err != nil { - return err - } + filter := polMan.Filter() - sshPolicy, err := pol.CompileSSHPolicy(node, users, peers) + sshPolicy, err := polMan.SSHPolicy(node) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. - if len(packetFilter) > 0 { - changed = policy.FilterNodesByACL(node, changed, packetFilter) + if len(filter) > 0 { + changed = policy.FilterNodesByACL(node, changed, filter) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, pol, cfg) + tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { return err } @@ -570,7 +564,7 @@ func appendPeerChanges( // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, packetFilter), + "base": policy.ReduceFilterRules(node, filter), } } else { // This is a hack to avoid sending an empty list of packet filters. @@ -578,11 +572,11 @@ func appendPeerChanges( // be omitted, causing the client to consider it unchanged, keeping the // previous packet filter. Worst case, this can cause a node that previously // has access to a node to _not_ loose access if an empty (allow none) is sent. - reduced := policy.ReduceFilterRules(node, packetFilter) + reduced := policy.ReduceFilterRules(node, filter) if len(reduced) > 0 { resp.PacketFilter = reduced } else { - resp.PacketFilter = packetFilter + resp.PacketFilter = filter } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 8dd51808..a1f3eb38 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -461,18 +461,20 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node)) + mappy := NewMapper( nil, tt.cfg, tt.derpMap, nil, + polMan, ) got, err := mappy.fullMapResponse( tt.node, tt.peers, []types.User{user1, user2}, - tt.pol, 0, ) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 24c521dc..4082df2b 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -14,7 +14,7 @@ import ( func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -23,7 +23,7 @@ func tailNodes( node, err := tailNode( node, capVer, - pol, + polMan, cfg, ) if err != nil { @@ -40,7 +40,7 @@ func tailNodes( func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +81,7 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - tags, _ := pol.TagsOfNode(node) + tags := polMan.Tags(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index b6692c16..9d7f1fed 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, DNSConfig: tt.dnsConfig, @@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) { got, err := tailNode( tt.node, 0, - tt.pol, + polMan, cfg, ) @@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) { tn, err := tailNode( node, 0, - &policy.ACLPolicy{}, + &policy.PolicyManagerV1{}, &types.Config{}, ) if err != nil { diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new file mode 100644 index 00000000..a94dd746 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,164 @@ +package policy + +import ( + "fmt" + "io" + "net/netip" + "os" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" +) + +type PolicyManager interface { + Filter() []tailcfg.FilterRule + SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + Tags(*types.Node) []string + ApproversForRoute(netip.Prefix) []string + IPsForUser(string) (*netipx.IPSet, error) + SetPolicy([]byte) error + SetUsers(users []types.User) error + SetNodes(nodes types.Nodes) error +} + +func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { + policyFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer policyFile.Close() + + policyBytes, err := io.ReadAll(policyFile) + if err != nil { + return nil, err + } + + return NewPolicyManager(policyBytes, users, nodes) +} + +func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + err := pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +type PolicyManagerV1 struct { + mu sync.Mutex + pol *ACLPolicy + users []types.User + nodes types.Nodes + filter []tailcfg.FilterRule +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManagerV1) updateLocked() error { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return fmt.Errorf("compiling filter rules: %w", err) + } + + pm.filter = filter + + return nil +} + +func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) +} + +func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetUsers(users []types.User) error { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManagerV1) Tags(node *types.Node) []string { + if pm == nil { + return nil + } + + tags, _ := pm.pol.TagsOfNode(node) + return tags +} + +func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { + // TODO(kradalby): This can be a parse error of the address in the policy, + // in the new policy this will be typed and not a problem, in this policy + // we will just return empty list + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + return approvers +} + +func (pm *PolicyManagerV1) IPsForUser(user string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, user) + if err != nil { + return nil, err + } + return ips, nil +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a8ae01f4..d41744cd 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() { switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "remove" case types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") @@ -488,9 +488,9 @@ func (m *mapSession) handleEndpointUpdate() { return } - if m.h.ACLPolicy != nil { + if m.h.polMan != nil { // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) if err != nil { m.errf(err, "Error running auto approved routes") mapResponseEndpointUpdates.WithLabelValues("error").Inc() @@ -544,7 +544,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) From 8ecba121cc3b25a2c45718652cd67914ac41755d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 11:43:18 -0500 Subject: [PATCH 03/13] remove unused args Signed-off-by: Kristoffer Dalby --- hscontrol/mapper/mapper.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 6899dc25..5ad66782 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -170,8 +170,6 @@ func (m *Mapper) fullMapResponse( m.polMan, node, capVer, - users, - peers, peers, m.cfg, ) @@ -258,11 +256,6 @@ func (m *Mapper) PeerChangedResponse( return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("listing users for map response: %w", err) - } - var removedIDs []tailcfg.NodeID var changedIDs []types.NodeID for nodeID, nodeChanged := range changed { @@ -286,8 +279,6 @@ func (m *Mapper) PeerChangedResponse( m.polMan, node, mapRequest.Version, - users, - peers, changedNodes, m.cfg, ) @@ -517,8 +508,6 @@ func appendPeerChanges( polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, - users []types.User, - peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { From 19bc8b6e01cf4bd690fd29752e6c0bbf5bdd9060 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:27:12 -0500 Subject: [PATCH 04/13] report if filter has changed Signed-off-by: Kristoffer Dalby --- hscontrol/grpcv1.go | 26 ++---- hscontrol/mapper/mapper.go | 7 +- hscontrol/mapper/mapper_test.go | 1 - hscontrol/policy/pm.go | 44 +++++---- hscontrol/policy/pm_test.go | 158 ++++++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 42 deletions(-) create mode 100644 hscontrol/policy/pm_test.go diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index f907ec0d..a221d519 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -728,20 +728,7 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - users, err := api.h.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("loading users from database to validate policy: %w", err) - } - - err = api.h.polMan.SetNodes(nodes) - if err != nil { - return nil, fmt.Errorf("setting nodes: %w", err) - } - err = api.h.polMan.SetUsers(users) - if err != nil { - return nil, fmt.Errorf("setting users: %w", err) - } - err = api.h.polMan.SetPolicy([]byte(p)) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } @@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5ad66782..51c96f8c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - users []types.User, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp, err := m.baseWithConfigMapResponse(node, capVer) @@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, err - } - resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a1f3eb38..4ee8c644 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, - []types.User{user1, user2}, 0, ) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index a94dd746..8ca9f1db 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/tailcfg" + "tailscale.com/util/deephash" ) type PolicyManager interface { @@ -18,9 +19,9 @@ type PolicyManager interface { Tags(*types.Node) []string ApproversForRoute(netip.Prefix) []string IPsForUser(string) (*netipx.IPSet, error) - SetPolicy([]byte) error - SetUsers(users []types.User) error - SetNodes(nodes types.Nodes) error + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) } func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { @@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic nodes: nodes, } - err = pm.updateLocked() + _, err = pm.updateLocked() if err != nil { return nil, err } @@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod nodes: nodes, } - err := pm.updateLocked() + _, err := pm.updateLocked() if err != nil { return nil, err } @@ -74,24 +75,33 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod } type PolicyManagerV1 struct { - mu sync.Mutex - pol *ACLPolicy - users []types.User - nodes types.Nodes - filter []tailcfg.FilterRule + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule } // updateLocked updates the filter rules based on the current policy and nodes. // It must be called with the lock held. -func (pm *PolicyManagerV1) updateLocked() error { +func (pm *PolicyManagerV1) updateLocked() (bool, error) { filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) if err != nil { - return fmt.Errorf("compiling filter rules: %w", err) + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil } pm.filter = filter + pm.filterHash = filterHash - return nil + return true, nil } func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { @@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) } -func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { pol, err := LoadACLPolicyFromBytes(polB) if err != nil { - return fmt.Errorf("parsing policy: %w", err) + return false, fmt.Errorf("parsing policy: %w", err) } pm.mu.Lock() @@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { } // SetUsers updates the users in the policy manager and updates the filter rules. -func (pm *PolicyManagerV1) SetUsers(users []types.User) error { +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error { } // SetNodes updates the nodes in the policy manager and updates the filter rules. -func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error { +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 00000000..24b78e4d --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +} From 8d5b04f3d3ec6694c3f6a21c70ff966f174dc1f3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:53:04 -0500 Subject: [PATCH 05/13] hook up user and node changes to policy Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ hscontrol/auth.go | 7 +++++++ hscontrol/grpcv1.go | 15 +++++++++++++++ hscontrol/oidc.go | 14 ++++++++++++++ 4 files changed, 80 insertions(+) diff --git a/hscontrol/app.go b/hscontrol/app.go index 3489d18f..b4d36caa 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -165,6 +165,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -472,6 +473,48 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if profilingEnabled { @@ -770,6 +813,7 @@ func (h *Headscale) Serve() error { Msg("Received SIGHUP, reloading ACL and Config") // TODO(kradalby): Reload config on SIGHUP + // TODO(kradalby): Only update if we set a new policy if err := h.loadACLPolicy(); err != nil { log.Error().Err(err).Msg("failed to reload ACL policy") } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 67545031..2b23aad3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a221d519..51134e7e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -57,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -86,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -220,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84267b41..5028e244 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -53,6 +54,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -65,6 +67,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -96,6 +99,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode( return fmt.Errorf("could not register node: %w", err) } + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return fmt.Errorf("updating resources using node: %w", err) + } + return nil } From 50b62ddfb3b81755e42b0fbb12b8b09cca8a2535 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 18:19:14 -0400 Subject: [PATCH 06/13] fix loading policy manager Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 116 ++++++++++++++++++++--------------------- hscontrol/policy/pm.go | 10 ++-- 2 files changed, 65 insertions(+), 61 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index b4d36caa..a0e105ca 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - polMan policy.PolicyManager + polManOnce sync.Once + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -531,8 +532,7 @@ func (h *Headscale) Serve() error { } var err error - - if err = h.loadACLPolicy(); err != nil { + if err = h.loadPolicyManager(); err != nil { return fmt.Errorf("failed to load ACL policy: %w", err) } @@ -814,7 +814,7 @@ func (h *Headscale) Serve() error { // TODO(kradalby): Reload config on SIGHUP // TODO(kradalby): Only update if we set a new policy - if err := h.loadACLPolicy(); err != nil { + if err := h.loadPolicyManager(); err != nil { log.Error().Err(err).Msg("failed to reload ACL policy") } @@ -1037,22 +1037,9 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } -func (h *Headscale) loadACLPolicy() error { - var ( - pm policy.PolicyManager - ) - - switch h.cfg.Policy.Mode { - case types.PolicyModeFile: - path := h.cfg.Policy.Path - - // It is fine to start headscale without a policy file. - if len(path) == 0 { - return nil - } - - absPath := util.AbsolutePathFromConfigPath(path) - +func (h *Headscale) loadPolicyManager() error { + var errOut error + h.polManOnce.Do(func() { // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -1063,54 +1050,67 @@ func (h *Headscale) loadACLPolicy() error { // allowed to be written to the database. nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) + errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) + return } users, err := h.db.ListUsers() if err != nil { - return fmt.Errorf("loading users from database to validate policy: %w", err) + errOut = fmt.Errorf("loading users from database to validate policy: %w", err) + return } - pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) - if err != nil { - return fmt.Errorf("loading policy from file: %w", err) - } + switch h.cfg.Policy.Mode { + case types.PolicyModeFile: + path := h.cfg.Policy.Path - if len(nodes) > 0 { - _, err = pm.SSHPolicy(nodes[0]) + // It is fine to start headscale without a policy file. + if len(path) == 0 { + h.polMan, err = policy.NewPolicyManager(nil, users, nodes) + if err != nil { + errOut = fmt.Errorf("policy manager with no policy: %w", err) + } + + return + } + + absPath := util.AbsolutePathFromConfigPath(path) + + h.polMan, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) if err != nil { - return fmt.Errorf("verifying SSH rules: %w", err) - } - } - - case types.PolicyModeDB: - p, err := h.db.GetPolicy() - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil + errOut = fmt.Errorf("loading policy from file (%s): %w", absPath, err) + return } - return fmt.Errorf("failed to get policy from database: %w", err) - } + if len(nodes) > 0 { + _, err = h.polMan.SSHPolicy(nodes[0]) + if err != nil { + errOut = fmt.Errorf("verifying SSH rules: %w", err) + return + } + } - nodes, err := h.db.ListNodes() - if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) - } - users, err := h.db.ListUsers() - if err != nil { - return fmt.Errorf("loading users from database to validate policy: %w", err) - } - pm, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) - if err != nil { - return fmt.Errorf("loading policy from database: %w", err) - } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") - } + case types.PolicyModeDB: + p, err := h.db.GetPolicy() + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return + } - h.polMan = pm + errOut = fmt.Errorf("failed to get policy from database: %w", err) + return + } - return nil + h.polMan, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) + if err != nil { + errOut = fmt.Errorf("loading policy from database: %w", err) + return + } + default: + log.Fatal(). + Str("mode", string(h.cfg.Policy.Mode)). + Msg("Unknown ACL policy mode") + } + }) + + return errOut } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 8ca9f1db..a5d736bd 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -40,9 +40,13 @@ func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes } func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { - pol, err := LoadACLPolicyFromBytes(polB) - if err != nil { - return nil, fmt.Errorf("parsing policy: %w", err) + var pol *ACLPolicy + var err error + if polB != nil && len(polB) > 0 { + pol, err = LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } } pm := PolicyManagerV1{ From f2ab5e05c91a5b7ada10250504c363eddfeb55da Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 18:29:27 -0400 Subject: [PATCH 07/13] split out reading policy and applying Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 111 ++++++++++++++++++++++++++--------------------- 1 file changed, 62 insertions(+), 49 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index a0e105ca..7c63f389 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -812,13 +812,21 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP - // TODO(kradalby): Only update if we set a new policy if err := h.loadPolicyManager(); err != nil { - log.Error().Err(err).Msg("failed to reload ACL policy") + log.Error().Err(err).Msg("failed to reload Policy") } - if h.polMan != nil { + pol, err := h.policyBytes() + if err != nil { + log.Error().Err(err).Msg("failed to get policy blob") + } + + changed, err := h.polMan.SetPolicy(pol) + if err != nil { + log.Error().Err(err).Msg("failed to set new policy") + } + + if changed { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -1037,6 +1045,43 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } +// policyBytes returns the appropriate policy for the +// current configuration as a []byte array. +func (h *Headscale) policyBytes() ([]byte, error) { + switch h.cfg.Policy.Mode { + case types.PolicyModeFile: + path := h.cfg.Policy.Path + + // It is fine to start headscale without a policy file. + if len(path) == 0 { + return nil, nil + } + + absPath := util.AbsolutePathFromConfigPath(path) + policyFile, err := os.Open(absPath) + if err != nil { + return nil, err + } + defer policyFile.Close() + + return io.ReadAll(policyFile) + + case types.PolicyModeDB: + p, err := h.db.GetPolicy() + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err + } + + return []byte(p.Data), err + } + + return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode) +} + func (h *Headscale) loadPolicyManager() error { var errOut error h.polManOnce.Do(func() { @@ -1059,56 +1104,24 @@ func (h *Headscale) loadPolicyManager() error { return } - switch h.cfg.Policy.Mode { - case types.PolicyModeFile: - path := h.cfg.Policy.Path + pol, err := h.policyBytes() + if err != nil { + errOut = fmt.Errorf("loading policy bytes: %w", err) + return + } - // It is fine to start headscale without a policy file. - if len(path) == 0 { - h.polMan, err = policy.NewPolicyManager(nil, users, nodes) - if err != nil { - errOut = fmt.Errorf("policy manager with no policy: %w", err) - } + h.polMan, err = policy.NewPolicyManager(pol, users, nodes) + if err != nil { + errOut = fmt.Errorf("creating policy manager: %w", err) + return + } - return - } - - absPath := util.AbsolutePathFromConfigPath(path) - - h.polMan, err = policy.NewPolicyManagerFromPath(absPath, users, nodes) + if len(nodes) > 0 { + _, err = h.polMan.SSHPolicy(nodes[0]) if err != nil { - errOut = fmt.Errorf("loading policy from file (%s): %w", absPath, err) + errOut = fmt.Errorf("verifying SSH rules: %w", err) return } - - if len(nodes) > 0 { - _, err = h.polMan.SSHPolicy(nodes[0]) - if err != nil { - errOut = fmt.Errorf("verifying SSH rules: %w", err) - return - } - } - - case types.PolicyModeDB: - p, err := h.db.GetPolicy() - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return - } - - errOut = fmt.Errorf("failed to get policy from database: %w", err) - return - } - - h.polMan, err = policy.NewPolicyManager([]byte(p.Data), users, nodes) - if err != nil { - errOut = fmt.Errorf("loading policy from database: %w", err) - return - } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") } }) From 7f665023d82fd74984486a6872d2e679dc609521 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 27 Oct 2024 11:50:47 -0400 Subject: [PATCH 08/13] fix nil pointer in oidc for policy Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 7c63f389..eea7315b 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -154,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } }) + if err = app.loadPolicyManager(); err != nil { + return nil, fmt.Errorf("failed to load ACL policy: %w", err) + } + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { @@ -531,11 +535,6 @@ func (h *Headscale) Serve() error { } } - var err error - if err = h.loadPolicyManager(); err != nil { - return fmt.Errorf("failed to load ACL policy: %w", err) - } - if dumpConfig { spew.Dump(h.cfg) } From dbf2faa4bf55410c530ec2eb2cc8fc0728d67a49 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 27 Oct 2024 12:56:56 -0400 Subject: [PATCH 09/13] fix nil in router Signed-off-by: Kristoffer Dalby --- hscontrol/policy/pm.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index a5d736bd..0e175557 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -165,6 +165,9 @@ func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { // TODO(kradalby): This can be a parse error of the address in the policy, // in the new policy this will be typed and not a problem, in this policy // we will just return empty list + if pm.pol == nil { + return nil + } approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) return approvers } From 85a038cfcac932047abd8665d2d5173fdd4e59d0 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 27 Oct 2024 13:53:45 -0400 Subject: [PATCH 10/13] tags approved via acl Signed-off-by: Kristoffer Dalby --- integration/cli_test.go | 167 ++++++++++++++++----------------------- integration/hsic/hsic.go | 4 + 2 files changed, 72 insertions(+), 99 deletions(-) diff --git a/integration/cli_test.go b/integration/cli_test.go index 2b81e814..e8ace2f1 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -13,6 +13,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -786,117 +787,85 @@ func TestNodeTagCommand(t *testing.T) { ) } -func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { +func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags")) - assertNoErr(t, err) - - headscale, err := scenario.Headscale() - assertNoErr(t, err) - - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", + tests := []struct { + name string + policy *policy.ACLPolicy + wantTag bool + }{ + { + name: "no-policy", + wantTag: false, }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetInvalidTags() != nil { - for _, tag := range node.GetInvalidTags() { - if tag == "tag:test" { - found = true - } - } - } - } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:test' in the list of nodes", - ) -} - -func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy( - &policy.ACLPolicy{ - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + { + name: "with-policy", + policy: &policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:test": {"user1"}, }, }, - TagOwners: map[string][]string{ - "tag:exists": {"user1"}, - }, + wantTag: true, }, - )) - assertNoErr(t, err) + } - headscale, err := scenario.Headscale() - assertNoErr(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + // defer scenario.ShutdownAssertNoPanics(t) - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", - }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetValidTags() != nil { - for _, tag := range node.GetValidTags() { - if tag == "tag:exists" { - found = true + spec := map[string]int{ + "user1": 1, + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{tsic.WithTags([]string{"tag:test"})}, + hsic.WithTestName("cliadvtags"), + hsic.WithACLPolicy(tt.policy), + ) + assertNoErr(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + // Test list all nodes after added seconds + resultMachines := make([]*v1.Node, spec["user1"]) + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--tags", + "--output", "json", + }, + &resultMachines, + ) + assert.Nil(t, err) + found := false + for _, node := range resultMachines { + if tags := node.GetValidTags(); tags != nil { + found = slices.Contains(tags, "tag:test") } } - } + assert.Equalf( + t, + tt.wantTag, + found, + "'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag, + ) + }) } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:exists' in the list of nodes", - ) } func TestNodeCommand(t *testing.T) { diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index c2ae3336..b747b8b1 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -81,6 +81,10 @@ type Option = func(c *HeadscaleInContainer) // HeadscaleInContainer instance. func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { + if acl == nil { + return + } + // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath From 24f3895b2bd6446c9451efd9a01368979248eadf Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 27 Oct 2024 13:58:55 -0400 Subject: [PATCH 11/13] update error string Signed-off-by: Kristoffer Dalby --- integration/cli_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/cli_test.go b/integration/cli_test.go index e8ace2f1..19e865cd 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1701,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, "verifying policy rules: invalid action") + assert.ErrorContains(t, err, "compiling filter rules: invalid action") // The new policy was invalid, the old one should still be in place, which // is none. From a942fcf50ace9ef0059f189fc399421159057b75 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 29 Oct 2024 08:40:15 -0400 Subject: [PATCH 12/13] fix autoapprove Signed-off-by: Kristoffer Dalby --- hscontrol/app.go | 4 ++++ hscontrol/db/routes.go | 2 +- hscontrol/policy/pm.go | 6 +++--- hscontrol/poll.go | 3 +++ 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index eea7315b..6859bf49 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -478,6 +478,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { users, err := db.ListUsers() if err != nil { @@ -499,6 +501,8 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not return nil } +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { nodes, err := db.ListNodes() if err != nil { diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index ebb08563..dcf238ab 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -644,7 +644,7 @@ func EnableAutoApprovedRoutes( approvedRoutes = append(approvedRoutes, advertisedRoute) } else { // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := polMan.IPsForUser(approvedAlias) + approvedIps, err := polMan.ExpandAlias(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 0e175557..7dbaed33 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -18,7 +18,7 @@ type PolicyManager interface { SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) Tags(*types.Node) []string ApproversForRoute(netip.Prefix) []string - IPsForUser(string) (*netipx.IPSet, error) + ExpandAlias(string) (*netipx.IPSet, error) SetPolicy([]byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes types.Nodes) (bool, error) @@ -172,8 +172,8 @@ func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { return approvers } -func (pm *PolicyManagerV1) IPsForUser(user string) (*netipx.IPSet, error) { - ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, user) +func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias) if err != nil { return nil, err } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index d41744cd..e6047d45 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -488,6 +488,9 @@ func (m *mapSession) handleEndpointUpdate() { return } + // TODO(kradalby): Only update the node that has actually changed + nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier) + if m.h.polMan != nil { // update routes with peer information err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) From 014ee87066198460fcd8b99b31427ae2322e38b8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 29 Oct 2024 15:02:53 -0400 Subject: [PATCH 13/13] update integration test build Signed-off-by: Kristoffer Dalby --- .github/workflows/test-integration.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 65324f77..d177bdb8 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -30,8 +30,7 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand - - TestNodeAdvertiseTagNoACLCommand - - TestNodeAdvertiseTagWithACLCommand + - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeExpireCommand - TestNodeRenameCommand