report if filter has changed

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-26 12:27:12 -05:00
parent 8ecba121cc
commit 19bc8b6e01
No known key found for this signature in database
5 changed files with 194 additions and 42 deletions

View file

@ -728,20 +728,7 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
users, err := api.h.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
}
err = api.h.polMan.SetNodes(nodes)
if err != nil {
return nil, fmt.Errorf("setting nodes: %w", err)
}
err = api.h.polMan.SetUsers(users)
if err != nil {
return nil, fmt.Errorf("setting users: %w", err)
}
err = api.h.polMan.SetPolicy([]byte(p))
changed, err := api.h.polMan.SetPolicy([]byte(p))
if err != nil {
return nil, fmt.Errorf("setting policy: %w", err)
}
@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
// Only send update if the packet filter has changed.
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
response := &v1.SetPolicyResponse{
Policy: updated.Data,

View file

@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse(
node *types.Node,
peers types.Nodes,
users []types.User,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)
@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse(
if err != nil {
return nil, err
}
users, err := m.db.ListUsers()
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version)
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
if err != nil {
return nil, err
}

View file

@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) {
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
[]types.User{user1, user2},
0,
)

View file

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
type PolicyManager interface {
@ -18,9 +19,9 @@ type PolicyManager interface {
Tags(*types.Node) []string
ApproversForRoute(netip.Prefix) []string
IPsForUser(string) (*netipx.IPSet, error)
SetPolicy([]byte) error
SetUsers(users []types.User) error
SetNodes(nodes types.Nodes) error
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic
nodes: nodes,
}
err = pm.updateLocked()
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
nodes: nodes,
}
err := pm.updateLocked()
_, err := pm.updateLocked()
if err != nil {
return nil, err
}
@ -74,24 +75,33 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
}
type PolicyManagerV1 struct {
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filter []tailcfg.FilterRule
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManagerV1) updateLocked() error {
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return fmt.Errorf("compiling filter rules: %w", err)
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
if filterHash == pm.filterHash {
return false, nil
}
pm.filter = filter
pm.filterHash = filterHash
return nil
return true, nil
}
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}
func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
pol, err := LoadACLPolicyFromBytes(polB)
if err != nil {
return fmt.Errorf("parsing policy: %w", err)
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error {
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes

158
hscontrol/policy/pm_test.go Normal file
View file

@ -0,0 +1,158 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func TestPolicySetChange(t *testing.T) {
users := []types.User{
{
Model: gorm.Model{ID: 1},
Name: "testuser",
},
}
tests := []struct {
name string
users []types.User
nodes types.Nodes
policy []byte
wantUsersChange bool
wantNodesChange bool
wantPolicyChange bool
wantFilter []tailcfg.FilterRule
}{
{
name: "set-nodes",
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantNodesChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users",
users: users,
wantUsersChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users-and-node",
users: users,
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantUsersChange: false,
wantNodesChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-policy",
policy: []byte(`
{
"acls": [
{
"action": "accept",
"src": [
"100.64.0.61",
],
"dst": [
"100.64.0.62:*",
],
},
],
}
`),
wantPolicyChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.61/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol := `
{
"groups": {
"group:example": [
"testuser",
],
},
"hosts": {
"host-1": "100.64.0.1",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
require.NoError(t, err)
if tt.policy != nil {
change, err := pm.SetPolicy(tt.policy)
require.NoError(t, err)
assert.Equal(t, tt.wantPolicyChange, change)
}
if tt.users != nil {
change, err := pm.SetUsers(tt.users)
require.NoError(t, err)
assert.Equal(t, tt.wantUsersChange, change)
}
if tt.nodes != nil {
change, err := pm.SetNodes(tt.nodes)
require.NoError(t, err)
assert.Equal(t, tt.wantNodesChange, change)
}
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
}
})
}
}