mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
report if filter has changed
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
8ecba121cc
commit
19bc8b6e01
5 changed files with 194 additions and 42 deletions
|
@ -728,20 +728,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||||
}
|
}
|
||||||
users, err := api.h.db.ListUsers()
|
changed, err := api.h.polMan.SetPolicy([]byte(p))
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = api.h.polMan.SetNodes(nodes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting nodes: %w", err)
|
|
||||||
}
|
|
||||||
err = api.h.polMan.SetUsers(users)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting users: %w", err)
|
|
||||||
}
|
|
||||||
err = api.h.polMan.SetPolicy([]byte(p))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("setting policy: %w", err)
|
return nil, fmt.Errorf("setting policy: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -758,10 +745,13 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only send update if the packet filter has changed.
|
||||||
|
if changed {
|
||||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StateFullUpdate,
|
Type: types.StateFullUpdate,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
response := &v1.SetPolicyResponse{
|
response := &v1.SetPolicyResponse{
|
||||||
Policy: updated.Data,
|
Policy: updated.Data,
|
||||||
|
|
|
@ -156,7 +156,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||||
func (m *Mapper) fullMapResponse(
|
func (m *Mapper) fullMapResponse(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
users []types.User,
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||||
|
@ -190,12 +189,8 @@ func (m *Mapper) FullMapResponse(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := m.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version)
|
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -474,7 +474,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
got, err := mappy.fullMapResponse(
|
got, err := mappy.fullMapResponse(
|
||||||
tt.node,
|
tt.node,
|
||||||
tt.peers,
|
tt.peers,
|
||||||
[]types.User{user1, user2},
|
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/util/deephash"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PolicyManager interface {
|
type PolicyManager interface {
|
||||||
|
@ -18,9 +19,9 @@ type PolicyManager interface {
|
||||||
Tags(*types.Node) []string
|
Tags(*types.Node) []string
|
||||||
ApproversForRoute(netip.Prefix) []string
|
ApproversForRoute(netip.Prefix) []string
|
||||||
IPsForUser(string) (*netipx.IPSet, error)
|
IPsForUser(string) (*netipx.IPSet, error)
|
||||||
SetPolicy([]byte) error
|
SetPolicy([]byte) (bool, error)
|
||||||
SetUsers(users []types.User) error
|
SetUsers(users []types.User) (bool, error)
|
||||||
SetNodes(nodes types.Nodes) error
|
SetNodes(nodes types.Nodes) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||||
|
@ -50,7 +51,7 @@ func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (Polic
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pm.updateLocked()
|
_, err = pm.updateLocked()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -65,7 +66,7 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := pm.updateLocked()
|
_, err := pm.updateLocked()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -76,22 +77,31 @@ func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nod
|
||||||
type PolicyManagerV1 struct {
|
type PolicyManagerV1 struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
pol *ACLPolicy
|
pol *ACLPolicy
|
||||||
|
|
||||||
users []types.User
|
users []types.User
|
||||||
nodes types.Nodes
|
nodes types.Nodes
|
||||||
|
|
||||||
|
filterHash deephash.Sum
|
||||||
filter []tailcfg.FilterRule
|
filter []tailcfg.FilterRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||||
// It must be called with the lock held.
|
// It must be called with the lock held.
|
||||||
func (pm *PolicyManagerV1) updateLocked() error {
|
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("compiling filter rules: %w", err)
|
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filterHash := deephash.Hash(&filter)
|
||||||
|
if filterHash == pm.filterHash {
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.filter = filter
|
pm.filter = filter
|
||||||
|
pm.filterHash = filterHash
|
||||||
|
|
||||||
return nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||||
|
@ -107,10 +117,10 @@ func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, erro
|
||||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
|
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||||
pol, err := LoadACLPolicyFromBytes(polB)
|
pol, err := LoadACLPolicyFromBytes(polB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing policy: %w", err)
|
return false, fmt.Errorf("parsing policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
|
@ -122,7 +132,7 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
|
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
@ -131,7 +141,7 @@ func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error {
|
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
pm.nodes = nodes
|
pm.nodes = nodes
|
||||||
|
|
158
hscontrol/policy/pm_test.go
Normal file
158
hscontrol/policy/pm_test.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package policy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicySetChange(t *testing.T) {
|
||||||
|
users := []types.User{
|
||||||
|
{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "testuser",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
users []types.User
|
||||||
|
nodes types.Nodes
|
||||||
|
policy []byte
|
||||||
|
wantUsersChange bool
|
||||||
|
wantNodesChange bool
|
||||||
|
wantPolicyChange bool
|
||||||
|
wantFilter []tailcfg.FilterRule
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "set-nodes",
|
||||||
|
nodes: types.Nodes{
|
||||||
|
{
|
||||||
|
IPv4: iap("100.64.0.2"),
|
||||||
|
User: users[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNodesChange: false,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-users",
|
||||||
|
users: users,
|
||||||
|
wantUsersChange: false,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-users-and-node",
|
||||||
|
users: users,
|
||||||
|
nodes: types.Nodes{
|
||||||
|
{
|
||||||
|
IPv4: iap("100.64.0.2"),
|
||||||
|
User: users[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantUsersChange: false,
|
||||||
|
wantNodesChange: true,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.64.0.2/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-policy",
|
||||||
|
policy: []byte(`
|
||||||
|
{
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"100.64.0.61",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"100.64.0.62:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
wantPolicyChange: true,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.64.0.61/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pol := `
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:example": [
|
||||||
|
"testuser",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.64.0.1",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"group:example",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`
|
||||||
|
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.policy != nil {
|
||||||
|
change, err := pm.SetPolicy(tt.policy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantPolicyChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.users != nil {
|
||||||
|
change, err := pm.SetUsers(tt.users)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantUsersChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.nodes != nil {
|
||||||
|
change, err := pm.SetNodes(tt.nodes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantNodesChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
||||||
|
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue