Mirror of https://github.com/juanfont/headscale.git, synced 2024-12-02 03:33:05 +00:00
wrap policy in policy manager interface
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent 50165ce9e1
commit 6afb554e20

10 changed files with 246 additions and 90 deletions
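The change replaces direct use of *policy.ACLPolicy (and the ad hoc CompileFilterRules/CompileSSHPolicy calls scattered through the server) with a policy.PolicyManager interface, introduced in hscontrol/policy/pm.go further down. As a rough orientation, here is a hypothetical sketch of how a caller constructs and queries the manager; the allow-all ACL string and the empty user/node slices are illustrative and not taken from the commit.

package main

import (
	"fmt"
	"log"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// Illustrative allow-all ACL; the manager parses it with LoadACLPolicyFromBytes internally.
	acl := []byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`)

	// NewPolicyManager parses the policy and precompiles the filter for the given users and nodes.
	pm, err := policy.NewPolicyManager(acl, []types.User{}, types.Nodes{})
	if err != nil {
		log.Fatal(err)
	}

	// Callers now query the manager instead of compiling rules from an ACLPolicy themselves.
	fmt.Printf("filter rules: %d\n", len(pm.Filter()))
}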
@@ -88,7 +88,7 @@ type Headscale struct {
 DERPMap *tailcfg.DERPMap
 DERPServer *derpServer.DERPServer

-ACLPolicy *policy.ACLPolicy
+polMan policy.PolicyManager

 mapper *mapper.Mapper
 nodeNotifier *notifier.Notifier
@@ -499,7 +499,7 @@ func (h *Headscale) Serve() error {

 // Fetch an initial DERP Map before we start serving
 h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
-h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
+h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)

 if h.cfg.DERP.ServerEnabled {
 // When embedded DERP is enabled we always need a STUN server
@@ -774,7 +774,7 @@ func (h *Headscale) Serve() error {
 log.Error().Err(err).Msg("failed to reload ACL policy")
 }

-if h.ACLPolicy != nil {
+if h.polMan != nil {
 log.Info().
 Msg("ACL policy successfully reloaded, notifying nodes of change")

@@ -995,8 +995,7 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {

 func (h *Headscale) loadACLPolicy() error {
 var (
-pol *policy.ACLPolicy
 err error
+pm policy.PolicyManager
 )

 switch h.cfg.Policy.Mode {
@@ -1009,10 +1008,6 @@ func (h *Headscale) loadACLPolicy() error {
 }

 absPath := util.AbsolutePathFromConfigPath(path)
-pol, err = policy.LoadACLPolicyFromPath(absPath)
-if err != nil {
-return fmt.Errorf("failed to load ACL policy from file: %w", err)
-}

 // Validate and reject configuration that would error when applied
 // when creating a map response. This requires nodes, so there is still
@@ -1031,13 +1026,13 @@ func (h *Headscale) loadACLPolicy() error {
 return fmt.Errorf("loading users from database to validate policy: %w", err)
 }

-_, err = pol.CompileFilterRules(users, nodes)
+pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes)
 if err != nil {
-return fmt.Errorf("verifying policy rules: %w", err)
+return fmt.Errorf("loading policy from file: %w", err)
 }

 if len(nodes) > 0 {
-_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
+_, err = pm.SSHPolicy(nodes[0])
 if err != nil {
 return fmt.Errorf("verifying SSH rules: %w", err)
 }
@@ -1053,9 +1048,17 @@ func (h *Headscale) loadACLPolicy() error {
 return fmt.Errorf("failed to get policy from database: %w", err)
 }

-pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
+nodes, err := h.db.ListNodes()
 if err != nil {
-return fmt.Errorf("failed to parse policy: %w", err)
+return fmt.Errorf("loading nodes from database to validate policy: %w", err)
 }
+users, err := h.db.ListUsers()
+if err != nil {
+return fmt.Errorf("loading users from database to validate policy: %w", err)
+}
+pm, err = policy.NewPolicyManager([]byte(p.Data), users, nodes)
+if err != nil {
+return fmt.Errorf("loading policy from database: %w", err)
+}
 default:
 log.Fatal().
@@ -1063,7 +1066,7 @@ func (h *Headscale) loadACLPolicy() error {
 Msg("Unknown ACL policy mode")
 }

-h.ACLPolicy = pol
+h.polMan = pm

 return nil
 }
@@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) {
 t.Run(tt.name, func(t *testing.T) {
 adb, err := newTestDB()
 assert.NoError(t, err)
-pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
-
-assert.NoError(t, err)
-assert.NotNil(t, pol)

 user, err := adb.CreateUser("test")
 assert.NoError(t, err)
@@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) {
 node0ByID, err := adb.GetNodeByID(0)
 assert.NoError(t, err)

+users, err := adb.ListUsers()
+assert.NoError(t, err)
+
+nodes, err := adb.ListNodes()
+assert.NoError(t, err)
+
+pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
+assert.NoError(t, err)
+
 // TODO(kradalby): Check state update
-err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
+err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
 assert.NoError(t, err)

 enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
@@ -598,18 +598,18 @@ func failoverRoute(
 }

 func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
-aclPolicy *policy.ACLPolicy,
+polMan policy.PolicyManager,
 node *types.Node,
 ) error {
 return hsdb.Write(func(tx *gorm.DB) error {
-return EnableAutoApprovedRoutes(tx, aclPolicy, node)
+return EnableAutoApprovedRoutes(tx, polMan, node)
 })
 }

 // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
 func EnableAutoApprovedRoutes(
 tx *gorm.DB,
-aclPolicy *policy.ACLPolicy,
+polMan policy.PolicyManager,
 node *types.Node,
 ) error {
 if node.IPv4 == nil && node.IPv6 == nil {
@@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
 continue
 }

-routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
-netip.Prefix(advertisedRoute.Prefix),
-)
-if err != nil {
-return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
-}
+routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))

 log.Trace().
 Str("node", node.Hostname).
@@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes(
 if approvedAlias == node.User.Username() {
 approvedRoutes = append(approvedRoutes, advertisedRoute)
 } else {
-users, err := ListUsers(tx)
-if err != nil {
-return fmt.Errorf("looking up users to expand route alias: %w", err)
-}
-
-// TODO(kradalby): figure out how to get this to depend on less stuff
-approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
+approvedIps, err := polMan.IPsForUser(approvedAlias)
 if err != nil {
 return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
 }
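A condensed, hypothetical sketch of the auto-approval lookup after this change: the approvers for an advertised route come from polMan.ApproversForRoute and user aliases are resolved with polMan.IPsForUser, replacing aclPolicy.AutoApprovers.GetRouteApprovers and aclPolicy.ExpandAlias. The policy content, the "alice" user, and the 10.0.0.0/24 route below are made up for illustration.

package main

import (
	"fmt"
	"log"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// Illustrative policy letting "alice" auto-approve routes inside 10.0.0.0/24.
	acl := []byte(`{
		"acls": [{"action": "accept", "src": ["*"], "dst": ["*:*"]}],
		"autoApprovers": {"routes": {"10.0.0.0/24": ["alice"]}}
	}`)

	pm, err := policy.NewPolicyManager(acl, []types.User{}, types.Nodes{})
	if err != nil {
		log.Fatal(err)
	}

	route := netip.MustParsePrefix("10.0.0.0/24")

	// Replaces aclPolicy.AutoApprovers.GetRouteApprovers(route).
	for _, approver := range pm.ApproversForRoute(route) {
		// Replaces aclPolicy.ExpandAlias(...); errors are reported, not fatal, in this sketch.
		ips, err := pm.IPsForUser(approver)
		if err != nil {
			fmt.Printf("expanding %q: %v\n", approver, err)
			continue
		}
		fmt.Printf("approver %q resolved to an IP set: %v\n", approver, ips != nil)
	}
}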
@@ -21,7 +21,6 @@ import (

 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 "github.com/juanfont/headscale/hscontrol/db"
-"github.com/juanfont/headscale/hscontrol/policy"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/util"
 )
@@ -450,10 +449,7 @@ func (api headscaleV1APIServer) ListNodes(
 resp.Online = true
 }

-validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
-node,
-)
-resp.InvalidTags = invalidTags
+validTags := api.h.polMan.Tags(node)
 resp.ValidTags = validTags
 response[index] = resp
 }
@@ -723,11 +719,6 @@ func (api headscaleV1APIServer) SetPolicy(

 p := request.GetPolicy()

-pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
-if err != nil {
-return nil, fmt.Errorf("loading ACL policy file: %w", err)
-}
-
 // Validate and reject configuration that would error when applied
 // when creating a map response. This requires nodes, so there is still
 // a scenario where they might be allowed if the server has no nodes
@@ -742,13 +733,21 @@ func (api headscaleV1APIServer) SetPolicy(
 return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
 }

-_, err = pol.CompileFilterRules(users, nodes)
+err = api.h.polMan.SetNodes(nodes)
 if err != nil {
-return nil, fmt.Errorf("verifying policy rules: %w", err)
+return nil, fmt.Errorf("setting nodes: %w", err)
 }
+err = api.h.polMan.SetUsers(users)
+if err != nil {
+return nil, fmt.Errorf("setting users: %w", err)
+}
+err = api.h.polMan.SetPolicy([]byte(p))
+if err != nil {
+return nil, fmt.Errorf("setting policy: %w", err)
+}

 if len(nodes) > 0 {
-_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
+_, err = api.h.polMan.SSHPolicy(nodes[0])
 if err != nil {
 return nil, fmt.Errorf("verifying SSH rules: %w", err)
 }
@@ -759,8 +758,6 @@ func (api headscaleV1APIServer) SetPolicy(
 return nil, err
 }

-api.h.ACLPolicy = pol
-
 ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 Type: types.StateFullUpdate,
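The SetPolicy handler now pushes state into the shared manager instead of compiling a throwaway ACLPolicy: nodes and users are set first so the filter recompiled by SetPolicy sees them, and the node notification kept in the hunk above follows. Below is a standalone, hypothetical sketch of that ordering, with placeholder data standing in for the h.db.ListNodes()/h.db.ListUsers() results.

package main

import (
	"log"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	allowAll := []byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`)

	pm, err := policy.NewPolicyManager(allowAll, []types.User{}, types.Nodes{})
	if err != nil {
		log.Fatal(err)
	}

	// In the real handler these come from h.db.ListNodes() and h.db.ListUsers().
	var nodes types.Nodes
	var users []types.User

	// Same order as the handler: nodes, then users, then the new policy bytes.
	if err := pm.SetNodes(nodes); err != nil {
		log.Fatalf("setting nodes: %v", err)
	}
	if err := pm.SetUsers(users); err != nil {
		log.Fatalf("setting users: %v", err)
	}
	if err := pm.SetPolicy(allowAll); err != nil {
		log.Fatalf("setting policy: %v", err)
	}

	log.Printf("recompiled filter has %d rules", len(pm.Filter()))
}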
@@ -55,6 +55,7 @@ type Mapper struct {
 cfg *types.Config
 derpMap *tailcfg.DERPMap
 notif *notifier.Notifier
+polMan policy.PolicyManager

 uid string
 created time.Time
@@ -71,6 +72,7 @@ func NewMapper(
 cfg *types.Config,
 derpMap *tailcfg.DERPMap,
 notif *notifier.Notifier,
+polMan policy.PolicyManager,
 ) *Mapper {
 uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)

@@ -79,6 +81,7 @@ func NewMapper(
 cfg: cfg,
 derpMap: derpMap,
 notif: notif,
+polMan: polMan,

 uid: uid,
 created: time.Now(),
@@ -154,10 +157,9 @@ func (m *Mapper) fullMapResponse(
 node *types.Node,
 peers types.Nodes,
 users []types.User,
-pol *policy.ACLPolicy,
 capVer tailcfg.CapabilityVersion,
 ) (*tailcfg.MapResponse, error) {
-resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
+resp, err := m.baseWithConfigMapResponse(node, capVer)
 if err != nil {
 return nil, err
 }
@@ -165,7 +167,7 @@ func (m *Mapper) fullMapResponse(
 err = appendPeerChanges(
 resp,
 true, // full change
-pol,
+m.polMan,
 node,
 capVer,
 users,
@@ -184,7 +186,6 @@ func (m *Mapper) fullMapResponse(
 func (m *Mapper) FullMapResponse(
 mapRequest tailcfg.MapRequest,
 node *types.Node,
-pol *policy.ACLPolicy,
 messages ...string,
 ) ([]byte, error) {
 peers, err := m.ListPeers(node.ID)
@@ -196,7 +197,7 @@ func (m *Mapper) FullMapResponse(
 return nil, err
 }

-resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
+resp, err := m.fullMapResponse(node, peers, users, mapRequest.Version)
 if err != nil {
 return nil, err
 }
@@ -210,10 +211,9 @@ func (m *Mapper) FullMapResponse(
 func (m *Mapper) ReadOnlyMapResponse(
 mapRequest tailcfg.MapRequest,
 node *types.Node,
-pol *policy.ACLPolicy,
 messages ...string,
 ) ([]byte, error) {
-resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
+resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
 if err != nil {
 return nil, err
 }
@@ -249,7 +249,6 @@ func (m *Mapper) PeerChangedResponse(
 node *types.Node,
 changed map[types.NodeID]bool,
 patches []*tailcfg.PeerChange,
-pol *policy.ACLPolicy,
 messages ...string,
 ) ([]byte, error) {
 resp := m.baseMapResponse()
@@ -284,7 +283,7 @@ func (m *Mapper) PeerChangedResponse(
 err = appendPeerChanges(
 &resp,
 false, // partial change
-pol,
+m.polMan,
 node,
 mapRequest.Version,
 users,
@@ -315,7 +314,7 @@ func (m *Mapper) PeerChangedResponse(

 // Add the node itself, it might have changed, and particularly
 // if there are no patches or changes, this is a self update.
-tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
+tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
 if err != nil {
 return nil, err
 }
@@ -330,7 +329,6 @@ func (m *Mapper) PeerChangedPatchResponse(
 mapRequest tailcfg.MapRequest,
 node *types.Node,
 changed []*tailcfg.PeerChange,
-pol *policy.ACLPolicy,
 ) ([]byte, error) {
 resp := m.baseMapResponse()
 resp.PeersChangedPatch = changed
@@ -459,12 +457,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
 // incremental.
 func (m *Mapper) baseWithConfigMapResponse(
 node *types.Node,
-pol *policy.ACLPolicy,
 capVer tailcfg.CapabilityVersion,
 ) (*tailcfg.MapResponse, error) {
 resp := m.baseMapResponse()

-tailnode, err := tailNode(node, capVer, pol, m.cfg)
+tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
 if err != nil {
 return nil, err
 }
@@ -517,7 +514,7 @@ func appendPeerChanges(
 resp *tailcfg.MapResponse,

 fullChange bool,
-pol *policy.ACLPolicy,
+polMan policy.PolicyManager,
 node *types.Node,
 capVer tailcfg.CapabilityVersion,
 users []types.User,
@@ -525,27 +522,24 @@ func appendPeerChanges(
 changed types.Nodes,
 cfg *types.Config,
 ) error {
-packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
-if err != nil {
-return err
-}
+filter := polMan.Filter()

-sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
+sshPolicy, err := polMan.SSHPolicy(node)
 if err != nil {
 return err
 }

 // If there are filter rules present, see if there are any nodes that cannot
 // access each-other at all and remove them from the peers.
-if len(packetFilter) > 0 {
-changed = policy.FilterNodesByACL(node, changed, packetFilter)
+if len(filter) > 0 {
+changed = policy.FilterNodesByACL(node, changed, filter)
 }

 profiles := generateUserProfiles(node, changed)

 dnsConfig := generateDNSConfig(cfg, node)

-tailPeers, err := tailNodes(changed, capVer, pol, cfg)
+tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
 if err != nil {
 return err
 }
@@ -570,7 +564,7 @@ func appendPeerChanges(
 // new PacketFilters field and "base" allows us to send a full update when we
 // have to send an empty list, avoiding the hack in the else block.
 resp.PacketFilters = map[string][]tailcfg.FilterRule{
-"base": policy.ReduceFilterRules(node, packetFilter),
+"base": policy.ReduceFilterRules(node, filter),
 }
 } else {
 // This is a hack to avoid sending an empty list of packet filters.
@@ -578,11 +572,11 @@ func appendPeerChanges(
 // be omitted, causing the client to consider it unchanged, keeping the
 // previous packet filter. Worst case, this can cause a node that previously
 // has access to a node to _not_ loose access if an empty (allow none) is sent.
-reduced := policy.ReduceFilterRules(node, packetFilter)
+reduced := policy.ReduceFilterRules(node, filter)
 if len(reduced) > 0 {
 resp.PacketFilter = reduced
 } else {
-resp.PacketFilter = packetFilter
+resp.PacketFilter = filter
 }
 }
 }
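A hypothetical, condensed sketch of what appendPeerChanges now does with the manager: take the precompiled filter from Filter(), drop peers the node cannot reach at all, and reduce the global rules to the ones relevant for the node being served. The node, peers, and ACL below are illustrative placeholders.

package main

import (
	"fmt"
	"log"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	pm, err := policy.NewPolicyManager(
		[]byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`),
		[]types.User{}, types.Nodes{},
	)
	if err != nil {
		log.Fatal(err)
	}

	node := &types.Node{Hostname: "served-node"}
	peers := types.Nodes{}

	// Was: pol.CompileFilterRules(users, append(peers, node)) on every map response.
	filter := pm.Filter()

	// Drop peers the node cannot talk to at all, as appendPeerChanges does.
	if len(filter) > 0 {
		peers = policy.FilterNodesByACL(node, peers, filter)
	}

	// Reduce the global rules to the ones relevant for this node.
	reduced := policy.ReduceFilterRules(node, filter)

	fmt.Printf("peers kept: %d, rules for node: %d\n", len(peers), len(reduced))
}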
@@ -461,18 +461,20 @@ func Test_fullMapResponse(t *testing.T) {

 for _, tt := range tests {
 t.Run(tt.name, func(t *testing.T) {
+polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
+
 mappy := NewMapper(
 nil,
 tt.cfg,
 tt.derpMap,
 nil,
+polMan,
 )

 got, err := mappy.fullMapResponse(
 tt.node,
 tt.peers,
 []types.User{user1, user2},
-tt.pol,
 0,
 )

@@ -14,7 +14,7 @@ import (
 func tailNodes(
 nodes types.Nodes,
 capVer tailcfg.CapabilityVersion,
-pol *policy.ACLPolicy,
+polMan policy.PolicyManager,
 cfg *types.Config,
 ) ([]*tailcfg.Node, error) {
 tNodes := make([]*tailcfg.Node, len(nodes))
@@ -23,7 +23,7 @@ func tailNodes(
 node, err := tailNode(
 node,
 capVer,
-pol,
+polMan,
 cfg,
 )
 if err != nil {
@@ -40,7 +40,7 @@ func tailNodes(
 func tailNode(
 node *types.Node,
 capVer tailcfg.CapabilityVersion,
-pol *policy.ACLPolicy,
+polMan policy.PolicyManager,
 cfg *types.Config,
 ) (*tailcfg.Node, error) {
 addrs := node.Prefixes()
@@ -81,7 +81,7 @@ func tailNode(
 return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 }

-tags, _ := pol.TagsOfNode(node)
+tags := polMan.Tags(node)
 tags = lo.Uniq(append(tags, node.ForcedTags...))

 tNode := tailcfg.Node{
@@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {

 for _, tt := range tests {
 t.Run(tt.name, func(t *testing.T) {
+polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
 cfg := &types.Config{
 BaseDomain: tt.baseDomain,
 DNSConfig: tt.dnsConfig,
@@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
 got, err := tailNode(
 tt.node,
 0,
-tt.pol,
+polMan,
 cfg,
 )

@@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
 tn, err := tailNode(
 node,
 0,
-&policy.ACLPolicy{},
+&policy.PolicyManagerV1{},
 &types.Config{},
 )
 if err != nil {
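The mapper and tailNode tests above switch to NewPolicyManagerForTest, which wraps an already parsed *ACLPolicy instead of raw policy bytes. A minimal, hypothetical example of that pattern; the ACL and the node are illustrative.

package main

import (
	"fmt"
	"log"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// Parse the policy up front, as the tests do with tt.pol.
	pol, err := policy.LoadACLPolicyFromBytes([]byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`))
	if err != nil {
		log.Fatal(err)
	}

	node := &types.Node{Hostname: "testnode"}

	// Wrap the parsed policy directly instead of re-serialising it to bytes.
	pm, err := policy.NewPolicyManagerForTest(pol, []types.User{}, types.Nodes{node})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(len(pm.Filter()), pm.Tags(node))
}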
hscontrol/policy/pm.go (new file, 164 lines)
@@ -0,0 +1,164 @@
+package policy
+
+import (
+"fmt"
+"io"
+"net/netip"
+"os"
+"sync"
+
+"github.com/juanfont/headscale/hscontrol/types"
+"go4.org/netipx"
+"tailscale.com/tailcfg"
+)
+
+type PolicyManager interface {
+Filter() []tailcfg.FilterRule
+SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
+Tags(*types.Node) []string
+ApproversForRoute(netip.Prefix) []string
+IPsForUser(string) (*netipx.IPSet, error)
+SetPolicy([]byte) error
+SetUsers(users []types.User) error
+SetNodes(nodes types.Nodes) error
+}
+
+func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+policyFile, err := os.Open(path)
+if err != nil {
+return nil, err
+}
+defer policyFile.Close()
+
+policyBytes, err := io.ReadAll(policyFile)
+if err != nil {
+return nil, err
+}
+
+return NewPolicyManager(policyBytes, users, nodes)
+}
+
+func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+pol, err := LoadACLPolicyFromBytes(polB)
+if err != nil {
+return nil, fmt.Errorf("parsing policy: %w", err)
+}
+
+pm := PolicyManagerV1{
+pol: pol,
+users: users,
+nodes: nodes,
+}
+
+err = pm.updateLocked()
+if err != nil {
+return nil, err
+}
+
+return &pm, nil
+}
+
+func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+pm := PolicyManagerV1{
+pol: pol,
+users: users,
+nodes: nodes,
+}
+
+err := pm.updateLocked()
+if err != nil {
+return nil, err
+}
+
+return &pm, nil
+}
+
+type PolicyManagerV1 struct {
+mu sync.Mutex
+pol *ACLPolicy
+users []types.User
+nodes types.Nodes
+filter []tailcfg.FilterRule
+}
+
+// updateLocked updates the filter rules based on the current policy and nodes.
+// It must be called with the lock held.
+func (pm *PolicyManagerV1) updateLocked() error {
+filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
+if err != nil {
+return fmt.Errorf("compiling filter rules: %w", err)
+}
+
+pm.filter = filter
+
+return nil
+}
+
+func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
+pm.mu.Lock()
+defer pm.mu.Unlock()
+return pm.filter
+}
+
+func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
+pm.mu.Lock()
+defer pm.mu.Unlock()
+
+return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
+}
+
+func (pm *PolicyManagerV1) SetPolicy(polB []byte) error {
+pol, err := LoadACLPolicyFromBytes(polB)
+if err != nil {
+return fmt.Errorf("parsing policy: %w", err)
+}
+
+pm.mu.Lock()
+defer pm.mu.Unlock()
+
+pm.pol = pol
+
+return pm.updateLocked()
+}
+
+// SetUsers updates the users in the policy manager and updates the filter rules.
+func (pm *PolicyManagerV1) SetUsers(users []types.User) error {
+pm.mu.Lock()
+defer pm.mu.Unlock()
+
+pm.users = users
+return pm.updateLocked()
+}
+
+// SetNodes updates the nodes in the policy manager and updates the filter rules.
+func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) error {
+pm.mu.Lock()
+defer pm.mu.Unlock()
+pm.nodes = nodes
+return pm.updateLocked()
+}
+
+func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
+if pm == nil {
+return nil
+}
+
+tags, _ := pm.pol.TagsOfNode(node)
+return tags
+}
+
+func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
+// TODO(kradalby): This can be a parse error of the address in the policy,
+// in the new policy this will be typed and not a problem, in this policy
+// we will just return empty list
+approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
+return approvers
+}
+
+func (pm *PolicyManagerV1) IPsForUser(user string) (*netipx.IPSet, error) {
+ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, user)
+if err != nil {
+return nil, err
+}
+return ips, nil
+}
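One motivation for hiding PolicyManagerV1 behind the PolicyManager interface is that callers no longer care which engine produced the filter, so another implementation (a future policy engine, or a stub in tests) can be swapped in. A hypothetical compile-time sketch, not part of the commit.

package main

import (
	"fmt"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"go4.org/netipx"
	"tailscale.com/tailcfg"
)

// denyAll is a stand-in PolicyManager that never allows anything.
type denyAll struct{}

func (denyAll) Filter() []tailcfg.FilterRule                      { return nil }
func (denyAll) SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) { return nil, nil }
func (denyAll) Tags(*types.Node) []string                         { return nil }
func (denyAll) ApproversForRoute(netip.Prefix) []string           { return nil }
func (denyAll) IPsForUser(string) (*netipx.IPSet, error)          { return new(netipx.IPSetBuilder).IPSet() }
func (denyAll) SetPolicy([]byte) error                            { return nil }
func (denyAll) SetUsers([]types.User) error                       { return nil }
func (denyAll) SetNodes(types.Nodes) error                        { return nil }

// Compile-time check that denyAll satisfies the interface from pm.go.
var _ policy.PolicyManager = denyAll{}

func main() {
	var pm policy.PolicyManager = denyAll{}
	fmt.Println("filter rules:", len(pm.Filter()))
}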
@@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
 switch update.Type {
 case types.StateFullUpdate:
 m.tracef("Sending Full MapResponse")
-data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
+data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 case types.StatePeerChanged:
 changed := make(map[types.NodeID]bool, len(update.ChangeNodes))

@@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {

 lastMessage = update.Message
 m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
+data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 updateType = "change"

 case types.StatePeerChangedPatch:
 m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
+data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
 updateType = "patch"
 case types.StatePeerRemoved:
 changed := make(map[types.NodeID]bool, len(update.Removed))
@@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
 changed[nodeID] = false
 }
 m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
+data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 updateType = "remove"
 case types.StateSelfUpdate:
 lastMessage = update.Message
 m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 // create the map so an empty (self) update is sent
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
+data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
 updateType = "remove"
 case types.StateDERPUpdated:
 m.tracef("Sending DERPUpdate MapResponse")
@@ -488,9 +488,9 @@ func (m *mapSession) handleEndpointUpdate() {
 return
 }

-if m.h.ACLPolicy != nil {
+if m.h.polMan != nil {
 // update routes with peer information
-err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
+err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
 if err != nil {
 m.errf(err, "Error running auto approved routes")
 mapResponseEndpointUpdates.WithLabelValues("error").Inc()
@@ -544,7 +544,7 @@ func (m *mapSession) handleEndpointUpdate() {
 func (m *mapSession) handleReadOnlyRequest() {
 m.tracef("Client asked for a lite update, responding without peers")

-mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy)
+mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
 if err != nil {
 m.errf(err, "Failed to create MapResponse")
 http.Error(m.w, "", http.StatusInternalServerError)