From 19bc8b6e01cf4bd690fd29752e6c0bbf5bdd9060 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 26 Oct 2024 12:27:12 -0500 Subject: [PATCH] report if filter has changed Signed-off-by: Kristoffer Dalby --- hscontrol/grpcv1.go | 26 ++---- hscontrol/mapper/mapper.go | 7 +- hscontrol/mapper/mapper_test.go | 1 - hscontrol/policy/pm.go | 44 +++++---- hscontrol/policy/pm_test.go | 158 ++++++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 42 deletions(-) create mode 100644 hscontrol/policy/pm_test.go diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index f907ec0d..a221d519 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -728,20 +728,7 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - users, err := api.h.db.ListUsers() - if err != nil { - return nil, fmt.Errorf("loading users from database to validate policy: %w", err) - } - - err = api.h.polMan.SetNodes(nodes) - if err != nil { - return nil, fmt.Errorf("setting nodes: %w", err) - } - err = api.h.polMan.SetUsers(users) - if err != nil { - return nil, fmt.Errorf("setting users: %w", err) - } - err = api.h.polMan.SetPolicy([]byte(p)) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { return nil, fmt.Errorf("setting policy: %w", err) } @@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 5ad66782..51c96f8c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - users []types.User, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp, err := m.baseWithConfigMapResponse(node, capVer) @@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse( if err != nil { return nil, err } - users, err := m.db.ListUsers() - if err != nil { - return nil, err - } - resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, err } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a1f3eb38..4ee8c644 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) { got, err := mappy.fullMapResponse( tt.node, tt.peers, - []types.User{user1, user2}, 0, ) diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index a94dd746..8ca9f1db 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" "tailscale.com/tailcfg" + "tailscale.com/util/deephash" ) type PolicyManager interface { @@ -18,9 +19,9 @@ type PolicyManager interface { Tags(*types.Node) []string ApproversForRoute(netip.Prefix) []string IPsForUser(string) (*netipx.IPSet, error) - SetPolicy([]byte) error - SetUsers(users []types.User) error - SetNodes(nodes types.Nodes) error + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) } func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { @@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic nodes: nodes, } - err = pm.updateLocked() + _, err = pm.updateLocked() if err != nil { return nil, err } @@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod nodes: nodes, } - err := pm.updateLocked() + _, err := pm.updateLocked() if err != nil { return nil, err } @@ -74,24 +75,33 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod } type PolicyManagerV1 struct { - mu sync.Mutex - pol *ACLPolicy - users []types.User - nodes types.Nodes - filter []tailcfg.FilterRule + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule } // updateLocked updates the filter rules based on the current policy and nodes. // It must be called with the lock held. -func (pm *PolicyManagerV1) updateLocked() error { +func (pm *PolicyManagerV1) updateLocked() (bool, error) { filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) if err != nil { - return fmt.Errorf("compiling filter rules: %w", err) + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil } pm.filter = filter + pm.filterHash = filterHash - return nil + return true, nil } func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { @@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) } -func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { pol, err := LoadACLPolicyFromBytes(polB) if err != nil { - return fmt.Errorf("parsing policy: %w", err) + return false, fmt.Errorf("parsing policy: %w", err) } pm.mu.Lock() @@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error { } // SetUsers updates the users in the policy manager and updates the filter rules. -func (pm *PolicyManagerV1) SetUsers(users []types.User) error { +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error { } // SetNodes updates the nodes in the policy manager and updates the filter rules. -func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error { +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() pm.nodes = nodes diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 00000000..24b78e4d --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +}