mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
report if filter has changed
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
8ecba121cc
commit
19bc8b6e01
5 changed files with 194 additions and 42 deletions
|
@ -728,20 +728,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||
}
|
||||
users, err := api.h.db.ListUsers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||
}
|
||||
|
||||
err = api.h.polMan.SetNodes(nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting nodes: %w", err)
|
||||
}
|
||||
err = api.h.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting users: %w", err)
|
||||
}
|
||||
err = api.h.polMan.SetPolicy([]byte(p))
|
||||
changed, err := api.h.polMan.SetPolicy([]byte(p))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting policy: %w", err)
|
||||
}
|
||||
|
@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Only send update if the packet filter has changed.
|
||||
if changed {
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StateFullUpdate,
|
||||
})
|
||||
}
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
Policy: updated.Data,
|
||||
|
|
|
@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|||
func (m *Mapper) fullMapResponse(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
users []types.User,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
|
@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := m.db.ListUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version)
|
||||
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
got, err := mappy.fullMapResponse(
|
||||
tt.node,
|
||||
tt.peers,
|
||||
[]types.User{user1, user2},
|
||||
0,
|
||||
)
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
|
@ -18,9 +19,9 @@ type PolicyManager interface {
|
|||
Tags(*types.Node) []string
|
||||
ApproversForRoute(netip.Prefix) []string
|
||||
IPsForUser(string) (*netipx.IPSet, error)
|
||||
SetPolicy([]byte) error
|
||||
SetUsers(users []types.User) error
|
||||
SetNodes(nodes types.Nodes) error
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
}
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
|
@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic
|
|||
nodes: nodes,
|
||||
}
|
||||
|
||||
err = pm.updateLocked()
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
|
|||
nodes: nodes,
|
||||
}
|
||||
|
||||
err := pm.updateLocked()
|
||||
_, err := pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -76,22 +77,31 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
|
|||
type PolicyManagerV1 struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManagerV1) updateLocked() error {
|
||||
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compiling filter rules: %w", err)
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||
|
@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro
|
|||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing policy: %w", err)
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
|
@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
|
|||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
|
@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
|
|||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error {
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
|
|
158
hscontrol/policy/pm_test.go
Normal file
158
hscontrol/policy/pm_test.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestPolicySetChange(t *testing.T) {
|
||||
users := []types.User{
|
||||
{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "testuser",
|
||||
},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
policy []byte
|
||||
wantUsersChange bool
|
||||
wantNodesChange bool
|
||||
wantPolicyChange bool
|
||||
wantFilter []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "set-nodes",
|
||||
nodes: types.Nodes{
|
||||
{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
wantNodesChange: false,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users",
|
||||
users: users,
|
||||
wantUsersChange: false,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-users-and-node",
|
||||
users: users,
|
||||
nodes: types.Nodes{
|
||||
{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
wantUsersChange: false,
|
||||
wantNodesChange: true,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.2/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set-policy",
|
||||
policy: []byte(`
|
||||
{
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"100.64.0.61",
|
||||
],
|
||||
"dst": [
|
||||
"100.64.0.62:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`),
|
||||
wantPolicyChange: true,
|
||||
wantFilter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.64.0.61/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol := `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.64.0.1",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`
|
||||
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.policy != nil {
|
||||
change, err := pm.SetPolicy(tt.policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantPolicyChange, change)
|
||||
}
|
||||
|
||||
if tt.users != nil {
|
||||
change, err := pm.SetUsers(tt.users)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantUsersChange, change)
|
||||
}
|
||||
|
||||
if tt.nodes != nil {
|
||||
change, err := pm.SetNodes(tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantNodesChange, change)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
||||
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue