Compare commits

...

14 commits

Author SHA1 Message Date
Kristoffer Dalby
712bb375bf
Merge 014ee87066 into 6275399327 2024-11-18 21:25:44 +01:00
Kristoffer Dalby
014ee87066
update integration test build
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-29 15:02:53 -04:00
Kristoffer Dalby
a942fcf50a
fix autoapprove
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-29 08:40:15 -04:00
Kristoffer Dalby
24f3895b2b
update error string
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 13:58:55 -04:00
Kristoffer Dalby
85a038cfca
tags approved via acl
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 13:53:45 -04:00
Kristoffer Dalby
dbf2faa4bf
fix nil in router
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 12:56:56 -04:00
Kristoffer Dalby
7f665023d8
fix nil pointer in oidc for policy
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 11:50:47 -04:00
Kristoffer Dalby
f2ab5e05c9
split out reading policy and applying
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 18:29:27 -04:00
Kristoffer Dalby
50b62ddfb3
fix loading policy manager
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 18:19:14 -04:00
Kristoffer Dalby
8d5b04f3d3
hook up user and node changes to policy
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 12:53:04 -05:00
Kristoffer Dalby
19bc8b6e01
report if filter has changed
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 12:27:12 -05:00
Kristoffer Dalby
8ecba121cc
remove unused args
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 11:43:18 -05:00
Kristoffer Dalby
6afb554e20
wrap policy in policy manager interface
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 11:42:14 -05:00
Kristoffer Dalby
50165ce9e1
resolve user identifier to stable ID
currently, the policy approach node to user matching
with a quite naive approach looking at the username
provided in the policy and matched it with the username
on the nodes. This worked ok as long as usernames were
unique and did not change.

As usernames are no longer guarenteed to be unique in
an OIDC environment we cant rely on this.

This changes the mechanism that matches the user string
(now user token) with nodes:

- first find all potential users by looking up:
  - database ID
  - provider ID (OIDC)
  - username/email

If more than one user is matching, then the query is
rejected, and zero matching nodes are returned.

When a single user is found, the node is matched against
the User database ID, which are also present on the actual
node.

This means that from this commit, users can use the following
to identify users in the policy:
- provider identity (iss + sub)
- username
- email
- database id

There are more changes coming to this, so it is not recommended
to start using any of these new abilities, with the exception
of email, which will not change since it includes an @.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-23 22:31:37 -05:00
18 changed files with 1041 additions and 355 deletions

View file

@ -30,8 +30,7 @@ jobs:
- TestPreAuthKeyCorrectUserLoggedInCommand - TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand - TestApiKeyCommand
- TestNodeTagCommand - TestNodeTagCommand
- TestNodeAdvertiseTagNoACLCommand - TestNodeAdvertiseTagCommand
- TestNodeAdvertiseTagWithACLCommand
- TestNodeCommand - TestNodeCommand
- TestNodeExpireCommand - TestNodeExpireCommand
- TestNodeRenameCommand - TestNodeRenameCommand

View file

@ -88,7 +88,8 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer DERPServer *derpServer.DERPServer
ACLPolicy *policy.ACLPolicy polManOnce sync.Once
polMan policy.PolicyManager
mapper *mapper.Mapper mapper *mapper.Mapper
nodeNotifier *notifier.Notifier nodeNotifier *notifier.Notifier
@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
} }
}) })
if err = app.loadPolicyManager(); err != nil {
return nil, fmt.Errorf("failed to load ACL policy: %w", err)
}
var authProvider AuthProvider var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL) authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" { if cfg.OIDC.Issuer != "" {
@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.db, app.db,
app.nodeNotifier, app.nodeNotifier,
app.ipAlloc, app.ipAlloc,
app.polMan,
) )
if err != nil { if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable { if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -473,6 +479,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router return router
} }
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
return err
}
changed, err := polMan.SetUsers(users)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
nodes, err := db.ListNodes()
if err != nil {
return err
}
changed, err := polMan.SetNodes(nodes)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
// Serve launches the HTTP and gRPC server service Headscale and the API. // Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
if profilingEnabled { if profilingEnabled {
@ -488,19 +540,13 @@ func (h *Headscale) Serve() error {
} }
} }
var err error
if err = h.loadACLPolicy(); err != nil {
return fmt.Errorf("failed to load ACL policy: %w", err)
}
if dumpConfig { if dumpConfig {
spew.Dump(h.cfg) spew.Dump(h.cfg)
} }
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP) h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server // When embedded DERP is enabled we always need a STUN server
@ -770,12 +816,21 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config") Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP if err := h.loadPolicyManager(); err != nil {
if err := h.loadACLPolicy(); err != nil { log.Error().Err(err).Msg("failed to reload Policy")
log.Error().Err(err).Msg("failed to reload ACL policy")
} }
if h.ACLPolicy != nil { pol, err := h.policyBytes()
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}
changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
}
if changed {
log.Info(). log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
@ -994,27 +1049,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil return &machineKey, nil
} }
func (h *Headscale) loadACLPolicy() error { // policyBytes returns the appropriate policy for the
var ( // current configuration as a []byte array.
pol *policy.ACLPolicy func (h *Headscale) policyBytes() ([]byte, error) {
err error
)
switch h.cfg.Policy.Mode { switch h.cfg.Policy.Mode {
case types.PolicyModeFile: case types.PolicyModeFile:
path := h.cfg.Policy.Path path := h.cfg.Policy.Path
// It is fine to start headscale without a policy file. // It is fine to start headscale without a policy file.
if len(path) == 0 { if len(path) == 0 {
return nil return nil, nil
} }
absPath := util.AbsolutePathFromConfigPath(path) absPath := util.AbsolutePathFromConfigPath(path)
pol, err = policy.LoadACLPolicyFromPath(absPath) policyFile, err := os.Open(absPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to load ACL policy from file: %w", err) return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
} }
return []byte(p.Data), err
}
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}
func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied // Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still // when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes // a scenario where they might be allowed if the server has no nodes
@ -1025,42 +1099,35 @@ func (h *Headscale) loadACLPolicy() error {
// allowed to be written to the database. // allowed to be written to the database.
nodes, err := h.db.ListNodes() nodes, err := h.db.ListNodes()
if err != nil { if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err) errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
}
users, err := h.db.ListUsers()
if err != nil {
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
return
} }
_, err = pol.CompileFilterRules(nodes) pol, err := h.policyBytes()
if err != nil { if err != nil {
return fmt.Errorf("verifying policy rules: %w", err) errOut = fmt.Errorf("loading policy bytes: %w", err)
return
}
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
errOut = fmt.Errorf("creating policy manager: %w", err)
return
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = h.polMan.SSHPolicy(nodes[0])
if err != nil { if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err) errOut = fmt.Errorf("verifying SSH rules: %w", err)
return
} }
} }
})
case types.PolicyModeDB: return errOut
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil
}
return fmt.Errorf("failed to get policy from database: %w", err)
}
pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
if err != nil {
return fmt.Errorf("failed to parse policy: %w", err)
}
default:
log.Fatal().
Str("mode", string(h.cfg.Policy.Mode)).
Msg("Unknown ACL policy mode")
}
h.ACLPolicy = pol
return nil
} }

View file

@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
return return
} }
err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
} }
err = h.db.Write(func(tx *gorm.DB) error { err = h.db.Write(func(tx *gorm.DB) error {

View file

@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9) c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
adb, err := newTestDB() adb, err := newTestDB()
assert.NoError(t, err) assert.NoError(t, err)
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
assert.NoError(t, err)
assert.NotNil(t, pol)
user, err := adb.CreateUser("test") user, err := adb.CreateUser("test")
assert.NoError(t, err) assert.NoError(t, err)
@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) {
node0ByID, err := adb.GetNodeByID(0) node0ByID, err := adb.GetNodeByID(0)
assert.NoError(t, err) assert.NoError(t, err)
users, err := adb.ListUsers()
assert.NoError(t, err)
nodes, err := adb.ListNodes()
assert.NoError(t, err)
pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
assert.NoError(t, err)
// TODO(kradalby): Check state update // TODO(kradalby): Check state update
err = adb.EnableAutoApprovedRoutes(pol, node0ByID) err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
assert.NoError(t, err) assert.NoError(t, err)
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)

View file

@ -598,18 +598,18 @@ func failoverRoute(
} }
func (hsdb *HSDatabase) EnableAutoApprovedRoutes( func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy, polMan policy.PolicyManager,
node *types.Node, node *types.Node,
) error { ) error {
return hsdb.Write(func(tx *gorm.DB) error { return hsdb.Write(func(tx *gorm.DB) error {
return EnableAutoApprovedRoutes(tx, aclPolicy, node) return EnableAutoApprovedRoutes(tx, polMan, node)
}) })
} }
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes( func EnableAutoApprovedRoutes(
tx *gorm.DB, tx *gorm.DB,
aclPolicy *policy.ACLPolicy, polMan policy.PolicyManager,
node *types.Node, node *types.Node,
) error { ) error {
if node.IPv4 == nil && node.IPv6 == nil { if node.IPv4 == nil && node.IPv6 == nil {
@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
continue continue
} }
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
netip.Prefix(advertisedRoute.Prefix),
)
if err != nil {
return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
}
log.Trace(). log.Trace().
Str("node", node.Hostname). Str("node", node.Hostname).
@ -649,7 +644,7 @@ func EnableAutoApprovedRoutes(
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} else { } else {
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) approvedIps, err := polMan.ExpandAlias(approvedAlias)
if err != nil { if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
} }

View file

@ -21,7 +21,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
) )
@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
return nil, err return nil, err
} }
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.CreateUserResponse{User: user.Proto()}, nil return &v1.CreateUserResponse{User: user.Proto()}, nil
} }
@ -87,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err return nil, err
} }
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.DeleteUserResponse{}, nil return &v1.DeleteUserResponse{}, nil
} }
@ -221,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err return nil, err
} }
err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using node: %w", err)
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
} }
@ -450,10 +464,7 @@ func (api headscaleV1APIServer) ListNodes(
resp.Online = true resp.Online = true
} }
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( validTags := api.h.polMan.Tags(node)
node,
)
resp.InvalidTags = invalidTags
resp.ValidTags = validTags resp.ValidTags = validTags
response[index] = resp response[index] = resp
} }
@ -723,11 +734,6 @@ func (api headscaleV1APIServer) SetPolicy(
p := request.GetPolicy() p := request.GetPolicy()
pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
if err != nil {
return nil, fmt.Errorf("loading ACL policy file: %w", err)
}
// Validate and reject configuration that would error when applied // Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still // when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes // a scenario where they might be allowed if the server has no nodes
@ -737,14 +743,13 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil { if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
} }
changed, err := api.h.polMan.SetPolicy([]byte(p))
_, err = pol.CompileFilterRules(nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying policy rules: %w", err) return nil, fmt.Errorf("setting policy: %w", err)
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = api.h.polMan.SSHPolicy(nodes[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err) return nil, fmt.Errorf("verifying SSH rules: %w", err)
} }
@ -755,12 +760,13 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err return nil, err
} }
api.h.ACLPolicy = pol // Only send update if the packet filter has changed.
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na") ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate, Type: types.StateFullUpdate,
}) })
}
response := &v1.SetPolicyResponse{ response := &v1.SetPolicyResponse{
Policy: updated.Data, Policy: updated.Data,

View file

@ -55,6 +55,7 @@ type Mapper struct {
cfg *types.Config cfg *types.Config
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
notif *notifier.Notifier notif *notifier.Notifier
polMan policy.PolicyManager
uid string uid string
created time.Time created time.Time
@ -71,6 +72,7 @@ func NewMapper(
cfg *types.Config, cfg *types.Config,
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
notif *notifier.Notifier, notif *notifier.Notifier,
polMan policy.PolicyManager,
) *Mapper { ) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
@ -79,6 +81,7 @@ func NewMapper(
cfg: cfg, cfg: cfg,
derpMap: derpMap, derpMap: derpMap,
notif: notif, notif: notif,
polMan: polMan,
uid: uid, uid: uid,
created: time.Now(), created: time.Now(),
@ -153,10 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse( func (m *Mapper) fullMapResponse(
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, capVer) resp, err := m.baseWithConfigMapResponse(node, capVer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,11 +166,10 @@ func (m *Mapper) fullMapResponse(
err = appendPeerChanges( err = appendPeerChanges(
resp, resp,
true, // full change true, // full change
pol, m.polMan,
node, node,
capVer, capVer,
peers, peers,
peers,
m.cfg, m.cfg,
) )
if err != nil { if err != nil {
@ -182,7 +183,6 @@ func (m *Mapper) fullMapResponse(
func (m *Mapper) FullMapResponse( func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
pol *policy.ACLPolicy,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
peers, err := m.ListPeers(node.ID) peers, err := m.ListPeers(node.ID)
@ -190,7 +190,7 @@ func (m *Mapper) FullMapResponse(
return nil, err return nil, err
} }
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -204,10 +204,9 @@ func (m *Mapper) FullMapResponse(
func (m *Mapper) ReadOnlyMapResponse( func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
pol *policy.ACLPolicy,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -243,7 +242,6 @@ func (m *Mapper) PeerChangedResponse(
node *types.Node, node *types.Node,
changed map[types.NodeID]bool, changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange, patches []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
messages ...string, messages ...string,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
@ -273,10 +271,9 @@ func (m *Mapper) PeerChangedResponse(
err = appendPeerChanges( err = appendPeerChanges(
&resp, &resp,
false, // partial change false, // partial change
pol, m.polMan,
node, node,
mapRequest.Version, mapRequest.Version,
peers,
changedNodes, changedNodes,
m.cfg, m.cfg,
) )
@ -303,7 +300,7 @@ func (m *Mapper) PeerChangedResponse(
// Add the node itself, it might have changed, and particularly // Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update. // if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -318,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
node *types.Node, node *types.Node,
changed []*tailcfg.PeerChange, changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
) ([]byte, error) { ) ([]byte, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
resp.PeersChangedPatch = changed resp.PeersChangedPatch = changed
@ -447,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
// incremental. // incremental.
func (m *Mapper) baseWithConfigMapResponse( func (m *Mapper) baseWithConfigMapResponse(
node *types.Node, node *types.Node,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse() resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, pol, m.cfg) tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -505,34 +500,30 @@ func appendPeerChanges(
resp *tailcfg.MapResponse, resp *tailcfg.MapResponse,
fullChange bool, fullChange bool,
pol *policy.ACLPolicy, polMan policy.PolicyManager,
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
peers types.Nodes,
changed types.Nodes, changed types.Nodes,
cfg *types.Config, cfg *types.Config,
) error { ) error {
packetFilter, err := pol.CompileFilterRules(append(peers, node)) filter := polMan.Filter()
if err != nil {
return err
}
sshPolicy, err := pol.CompileSSHPolicy(node, peers) sshPolicy, err := polMan.SSHPolicy(node)
if err != nil { if err != nil {
return err return err
} }
// If there are filter rules present, see if there are any nodes that cannot // If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers. // access each-other at all and remove them from the peers.
if len(packetFilter) > 0 { if len(filter) > 0 {
changed = policy.FilterNodesByACL(node, changed, packetFilter) changed = policy.FilterNodesByACL(node, changed, filter)
} }
profiles := generateUserProfiles(node, changed) profiles := generateUserProfiles(node, changed)
dnsConfig := generateDNSConfig(cfg, node) dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(changed, capVer, pol, cfg) tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
if err != nil { if err != nil {
return err return err
} }
@ -557,7 +548,7 @@ func appendPeerChanges(
// new PacketFilters field and "base" allows us to send a full update when we // new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block. // have to send an empty list, avoiding the hack in the else block.
resp.PacketFilters = map[string][]tailcfg.FilterRule{ resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, packetFilter), "base": policy.ReduceFilterRules(node, filter),
} }
} else { } else {
// This is a hack to avoid sending an empty list of packet filters. // This is a hack to avoid sending an empty list of packet filters.
@ -565,11 +556,11 @@ func appendPeerChanges(
// be omitted, causing the client to consider it unchanged, keeping the // be omitted, causing the client to consider it unchanged, keeping the
// previous packet filter. Worst case, this can cause a node that previously // previous packet filter. Worst case, this can cause a node that previously
// has access to a node to _not_ loose access if an empty (allow none) is sent. // has access to a node to _not_ loose access if an empty (allow none) is sent.
reduced := policy.ReduceFilterRules(node, packetFilter) reduced := policy.ReduceFilterRules(node, filter)
if len(reduced) > 0 { if len(reduced) > 0 {
resp.PacketFilter = reduced resp.PacketFilter = reduced
} else { } else {
resp.PacketFilter = packetFilter resp.PacketFilter = filter
} }
} }

View file

@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
mini := &types.Node{ mini := &types.Node{
ID: 0, ID: 0,
MachineKey: mustMK( MachineKey: mustMK(
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
Hostname: "mini", Hostname: "mini",
GivenName: "mini", GivenName: "mini",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
AuthKey: &types.PreAuthKey{}, AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
Hostname: "peer1", Hostname: "peer1",
GivenName: "peer1", GivenName: "peer1",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
Hostname: "peer2", Hostname: "peer2",
GivenName: "peer2", GivenName: "peer2",
UserID: 1, UserID: user2.ID,
User: types.User{Name: "peer2"}, User: user2,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -458,17 +461,19 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
mappy := NewMapper( mappy := NewMapper(
nil, nil,
tt.cfg, tt.cfg,
tt.derpMap, tt.derpMap,
nil, nil,
polMan,
) )
got, err := mappy.fullMapResponse( got, err := mappy.fullMapResponse(
tt.node, tt.node,
tt.peers, tt.peers,
tt.pol,
0, 0,
) )

View file

@ -14,7 +14,7 @@ import (
func tailNodes( func tailNodes(
nodes types.Nodes, nodes types.Nodes,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy, polMan policy.PolicyManager,
cfg *types.Config, cfg *types.Config,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes)) tNodes := make([]*tailcfg.Node, len(nodes))
@ -23,7 +23,7 @@ func tailNodes(
node, err := tailNode( node, err := tailNode(
node, node,
capVer, capVer,
pol, polMan,
cfg, cfg,
) )
if err != nil { if err != nil {
@ -40,7 +40,7 @@ func tailNodes(
func tailNode( func tailNode(
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy, polMan policy.PolicyManager,
cfg *types.Config, cfg *types.Config,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
addrs := node.Prefixes() addrs := node.Prefixes()
@ -81,7 +81,7 @@ func tailNode(
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
} }
tags, _ := pol.TagsOfNode(node) tags := polMan.Tags(node)
tags = lo.Uniq(append(tags, node.ForcedTags...)) tags = lo.Uniq(append(tags, node.ForcedTags...))
tNode := tailcfg.Node{ tNode := tailcfg.Node{

View file

@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
cfg := &types.Config{ cfg := &types.Config{
BaseDomain: tt.baseDomain, BaseDomain: tt.baseDomain,
DNSConfig: tt.dnsConfig, DNSConfig: tt.dnsConfig,
@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
got, err := tailNode( got, err := tailNode(
tt.node, tt.node,
0, 0,
tt.pol, polMan,
cfg, cfg,
) )
@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
tn, err := tailNode( tn, err := tailNode(
node, node,
0, 0,
&policy.ACLPolicy{}, &policy.PolicyManagerV1{},
&types.Config{}, &types.Config{},
) )
if err != nil { if err != nil {

View file

@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
registrationCache *zcache.Cache[string, key.MachinePublic] registrationCache *zcache.Cache[string, key.MachinePublic]
notifier *notifier.Notifier notifier *notifier.Notifier
ipAlloc *db.IPAllocator ipAlloc *db.IPAllocator
polMan policy.PolicyManager
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
db *db.HSDatabase, db *db.HSDatabase,
notif *notifier.Notifier, notif *notifier.Notifier,
ipAlloc *db.IPAllocator, ipAlloc *db.IPAllocator,
polMan policy.PolicyManager,
) (*AuthProviderOIDC, error) { ) (*AuthProviderOIDC, error) {
var err error var err error
// grab oidc config if it hasn't been already // grab oidc config if it hasn't been already
@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
registrationCache: registrationCache, registrationCache: registrationCache,
notifier: notif, notifier: notif,
ipAlloc: ipAlloc, ipAlloc: ipAlloc,
polMan: polMan,
oidcProvider: oidcProvider, oidcProvider: oidcProvider,
oauth2Config: oauth2Config, oauth2Config: oauth2Config,
@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return nil, fmt.Errorf("creating or updating user: %w", err) return nil, fmt.Errorf("creating or updating user: %w", err)
} }
err = usersChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return user, nil return user, nil
} }
@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode(
return fmt.Errorf("could not register node: %w", err) return fmt.Errorf("could not register node: %w", err)
} }
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
}
return nil return nil
} }

View file

@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy, policy *ACLPolicy,
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
users []types.User,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all // If there is no policy defined, we default to allow all
if policy == nil { if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.CompileFilterRules(append(peers, node)) rules, err := policy.CompileFilterRules(users, append(peers, node))
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
sshPolicy, err := policy.CompileSSHPolicy(node, peers) sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) CompileFilterRules( func (pol *ACLPolicy) CompileFilterRules(
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil { if pol == nil {
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
var srcIPs []string var srcIPs []string
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, nodes) srcs, err := pol.expandSource(src, users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) return nil, fmt.Errorf(
"parsing policy, acl index: %d->%d: %w",
index,
srcIndex,
err,
)
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
expanded, err := pol.ExpandAlias( expanded, err := pol.ExpandAlias(
nodes, nodes,
users,
alias, alias,
) )
if err != nil { if err != nil {
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
func (pol *ACLPolicy) CompileSSHPolicy( func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node, node *types.Node,
users []types.User,
peers types.Nodes, peers types.Nodes,
) (*tailcfg.SSHPolicy, error) { ) (*tailcfg.SSHPolicy, error) {
if pol == nil { if pol == nil {
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
for index, sshACL := range pol.SSHs { for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations { for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, node), src) expanded, err := pol.ExpandAlias(append(peers, node), users, src)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
case "check": case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod) checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) return nil, fmt.Errorf(
"parsing SSH policy, parsing check duration, index: %d: %w",
index,
err,
)
} else { } else {
action = *checkAction action = *checkAction
} }
default: default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
sshACL.Action,
index,
err,
)
} }
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
} else { } else {
expandedSrcs, err := pol.ExpandAlias( expandedSrcs, err := pol.ExpandAlias(
peers, peers,
users,
rawSrc, rawSrc,
) )
if err != nil { if err != nil {
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias. // with the given src alias.
func (pol *ACLPolicy) expandSource( func (pol *ACLPolicy) expandSource(
src string, src string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.ExpandAlias(nodes, src) ipSet, err := pol.ExpandAlias(nodes, users, src)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
// and transform these in IPAddresses. // and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias( func (pol *ACLPolicy) ExpandAlias(
nodes types.Nodes, nodes types.Nodes,
users []types.User,
alias string, alias string,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.expandIPsFromGroup(alias, nodes) return pol.expandIPsFromGroup(alias, users, nodes)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.expandIPsFromTag(alias, nodes) return pol.expandIPsFromTag(alias, users, nodes)
} }
if isAutoGroup(alias) { if isAutoGroup(alias) {
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
} }
// if alias is a user // if alias is a user
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
return ips, err return ips, err
} }
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(nodes, h.String()) return pol.ExpandAlias(nodes, users, h.String())
} }
// if alias is an IP // if alias is an IP
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromGroup(
group string, group string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
users, err := pol.expandUsersFromGroup(group) userTokens, err := pol.expandUsersFromGroup(group)
if err != nil { if err != nil {
return &netipx.IPSet{}, err return &netipx.IPSet{}, err
} }
for _, user := range users { for _, user := range userTokens {
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
for _, node := range filteredNodes { for _, node := range filteredNodes {
node.AppendToIPSet(&build) node.AppendToIPSet(&build)
} }
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromTag(
alias string, alias string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
// filter out nodes per tag owner // filter out nodes per tag owner
for _, user := range owners { for _, user := range owners {
nodes := filterNodesByUser(nodes, user) nodes := filterNodesByUser(nodes, users, user)
for _, node := range nodes { for _, node := range nodes {
if node.Hostinfo == nil { if node.Hostinfo == nil {
continue continue
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
func (pol *ACLPolicy) expandIPsFromUser( func (pol *ACLPolicy) expandIPsFromUser(
user string, user string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
// shortcurcuit if we have no nodes to get ips from. // shortcurcuit if we have no nodes to get ips from.
@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode(
return validTags, invalidTags return validTags, invalidTags
} }
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { // filterNodesByUser returns a list of nodes that match the given userToken from a
// policy.
// Matching nodes are determined by first matching the user token to a user by checking:
// - If it is an ID that mactches the user database ID
// - It is the Provider Identifier from OIDC
// - It matches the username or email of a user
//
// If the token matches more than one user, zero nodes will returned.
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
var out types.Nodes var out types.Nodes
var potentialUsers []types.User
for _, user := range users {
if user.ProviderIdentifier == userToken {
potentialUsers = append(potentialUsers, user)
break
}
if user.Email == userToken {
potentialUsers = append(potentialUsers, user)
}
if user.Name == userToken {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) != 1 {
return nil
}
user := potentialUsers[0]
for _, node := range nodes { for _, node := range nodes {
if node.User.Username() == user { if node.User.ID == user.ID {
out = append(out, node) out = append(out, node)
} }
} }

View file

@ -2,8 +2,10 @@ package policy
import ( import (
"errors" "errors"
"math/rand/v2"
"net/netip" "net/netip"
"slices" "slices"
"sort"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -14,6 +16,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -375,18 +378,24 @@ func TestParsing(t *testing.T) {
return return
} }
rules, err := pol.CompileFilterRules(types.Nodes{ user := types.User{
&types.Node{ Model: gorm.Model{ID: 1},
IPv4: iap("100.100.100.100"), Name: "testuser",
}
rules, err := pol.CompileFilterRules(
[]types.User{
user,
}, },
&types.Node{ types.Nodes{
IPv4: iap("200.200.200.200"), &types.Node{
User: types.User{ IPv4: iap("100.100.100.100"),
Name: "testuser",
}, },
Hostinfo: &tailcfg.Hostinfo{}, &types.Node{
}, IPv4: iap("200.200.200.200"),
}) User: user,
Hostinfo: &tailcfg.Hostinfo{},
},
})
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.CompileFilterRules(types.Nodes{}) rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }
@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) {
} }
func Test_listNodesInUser(t *testing.T) { func Test_listNodesInUser(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "marc"},
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
}
type args struct { type args struct {
nodes types.Nodes nodes types.Nodes
user string user string
@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) {
name: "1 node in user", name: "1 node in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
user: "joe", user: "joe",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
}, },
{ {
name: "3 nodes, 2 in user", name: "3 nodes, 2 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
user: "marc", user: "marc",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
}, },
{ {
name: "5 nodes, 0 in user", name: "5 nodes, 0 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
&types.Node{ID: 4, User: types.User{Name: "marc"}}, &types.Node{ID: 4, User: users[0]},
&types.Node{ID: 5, User: types.User{Name: "marc"}}, &types.Node{ID: 5, User: users[0]},
}, },
user: "mickael", user: "mickael",
}, },
want: nil, want: nil,
}, },
{
name: "match-by-provider-ident",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[2]},
},
},
{
name: "match-by-email",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[1]},
},
},
{
name: "multi-match-is-zero",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 3, User: users[3]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-email-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match email, then provider id
&types.Node{ID: 3, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-username-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match username, then provider id
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-duplicate-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-unique-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[0]},
},
},
{
name: "all-users-no-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[1]},
},
},
{
name: "all-users-no-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working@headscale.net",
},
want: nil,
},
{
name: "all-users-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 3, User: users[2]},
},
},
{
name: "all-users-no-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "http://oidc.org/4321",
},
want: nil,
},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user) for range 1000 {
ns := test.args.nodes
rand.Shuffle(len(ns), func(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
})
got := filterNodesByUser(ns, users, test.args.user)
sort.Slice(got, func(i, j int) bool {
return got[i].ID < got[j].ID
})
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
}
} }
}) })
} }
@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) {
return s return s
} }
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "joe"},
{Model: gorm.Model{ID: 2}, Name: "marc"},
{Model: gorm.Model{ID: 3}, Name: "mickael"},
}
type field struct { type field struct {
pol ACLPolicy pol ACLPolicy
} }
@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
}, },
}, },
}, },
@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}, },
@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias( got, err := test.field.pol.ExpandAlias(
test.args.nodes, test.args.nodes,
users,
test.args.alias, test.args.alias,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.CompileFilterRules( got, err := tt.field.pol.CompileFilterRules(
[]types.User{},
tt.args.nodes, tt.args.nodes,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) {
} }
func TestReduceFilterRules(t *testing.T) { func TestReduceFilterRules(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "mickael"},
{Model: gorm.Model{ID: 2}, Name: "user1"},
{Model: gorm.Model{ID: 3}, Name: "user2"},
{Model: gorm.Model{ID: 4}, Name: "user100"},
}
tests := []struct { tests := []struct {
name string name string
node *types.Node node *types.Node
@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
}, },
want: []tailcfg.FilterRule{}, want: []tailcfg.FilterRule{},
@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{ RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"), netip.MustParsePrefix("10.33.0.0/16"),
@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
// "internal" exit node // "internal" exit node
&types.Node{ &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
}, },
@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
}, },
@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
}, },
@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules( got, _ := tt.pol.CompileFilterRules(
users,
append(tt.peers, tt.node), append(tt.peers, tt.node),
) )
@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
assert.NoError(t, err) assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
RequestTags: []string{"tag:test"}, RequestTags: []string{"tag:test"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
Hostname: "testnodes", Hostname: "testnodes",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 0, UserID: 0,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Name: "user1", Model: gorm.Model{ID: 1},
Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Name: "user1", Model: gorm.Model{ID: 1},
Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "webserver", Hostname: "webserver",
RequestTags: []string{"tag:webapp"}, RequestTags: []string{"tag:webapp"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 1, ID: 1,
Hostname: "webserver", Hostname: "webserver",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) {
} }
nodes2 := &types.Node{ nodes2 := &types.Node{
ID: 2, ID: 2,
Hostname: "user", Hostname: "user",
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{

181
hscontrol/policy/pm.go Normal file
View file

@ -0,0 +1,181 @@
package policy
import (
"fmt"
"io"
"net/netip"
"os"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
type PolicyManager interface {
Filter() []tailcfg.FilterRule
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
Tags(*types.Node) []string
ApproversForRoute(netip.Prefix) []string
ExpandAlias(string) (*netipx.IPSet, error)
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
policyFile, err := os.Open(path)
if err != nil {
return nil, err
}
defer policyFile.Close()
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return nil, err
}
return NewPolicyManager(policyBytes, users, nodes)
}
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
var pol *ACLPolicy
var err error
if polB != nil && len(polB) > 0 {
pol, err = LoadACLPolicyFromBytes(polB)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
}
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
_, err := pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
type PolicyManagerV1 struct {
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
if filterHash == pm.filterHash {
return false, nil
}
pm.filter = filter
pm.filterHash = filterHash
return true, nil
}
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
pol, err := LoadACLPolicyFromBytes(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
defer pm.mu.Unlock()
pm.pol = pol
return pm.updateLocked()
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
if pm == nil {
return nil
}
tags, _ := pm.pol.TagsOfNode(node)
return tags
}
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
// TODO(kradalby): This can be a parse error of the address in the policy,
// in the new policy this will be typed and not a problem, in this policy
// we will just return empty list
if pm.pol == nil {
return nil
}
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
return approvers
}
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
if err != nil {
return nil, err
}
return ips, nil
}

158
hscontrol/policy/pm_test.go Normal file
View file

@ -0,0 +1,158 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func TestPolicySetChange(t *testing.T) {
users := []types.User{
{
Model: gorm.Model{ID: 1},
Name: "testuser",
},
}
tests := []struct {
name string
users []types.User
nodes types.Nodes
policy []byte
wantUsersChange bool
wantNodesChange bool
wantPolicyChange bool
wantFilter []tailcfg.FilterRule
}{
{
name: "set-nodes",
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantNodesChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users",
users: users,
wantUsersChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users-and-node",
users: users,
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantUsersChange: false,
wantNodesChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-policy",
policy: []byte(`
{
"acls": [
{
"action": "accept",
"src": [
"100.64.0.61",
],
"dst": [
"100.64.0.62:*",
],
},
],
}
`),
wantPolicyChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.61/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol := `
{
"groups": {
"group:example": [
"testuser",
],
},
"hosts": {
"host-1": "100.64.0.1",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
require.NoError(t, err)
if tt.policy != nil {
change, err := pm.SetPolicy(tt.policy)
require.NoError(t, err)
assert.Equal(t, tt.wantPolicyChange, change)
}
if tt.users != nil {
change, err := pm.SetUsers(tt.users)
require.NoError(t, err)
assert.Equal(t, tt.wantUsersChange, change)
}
if tt.nodes != nil {
change, err := pm.SetNodes(tt.nodes)
require.NoError(t, err)
assert.Equal(t, tt.wantNodesChange, change)
}
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
switch update.Type { switch update.Type {
case types.StateFullUpdate: case types.StateFullUpdate:
m.tracef("Sending Full MapResponse") m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged: case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change" updateType = "change"
case types.StatePeerChangedPatch: case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch" updateType = "patch"
case types.StatePeerRemoved: case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed)) changed := make(map[types.NodeID]bool, len(update.Removed))
@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false changed[nodeID] = false
} }
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateSelfUpdate: case types.StateSelfUpdate:
lastMessage = update.Message lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent // create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove" updateType = "remove"
case types.StateDERPUpdated: case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse") m.tracef("Sending DERPUpdate MapResponse")
@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() {
return return
} }
if m.h.ACLPolicy != nil { // TODO(kradalby): Only update the node that has actually changed
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
if m.h.polMan != nil {
// update routes with peer information // update routes with peer information
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
if err != nil { if err != nil {
m.errf(err, "Error running auto approved routes") m.errf(err, "Error running auto approved routes")
mapResponseEndpointUpdates.WithLabelValues("error").Inc() mapResponseEndpointUpdates.WithLabelValues("error").Inc()
@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() { func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers") m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil { if err != nil {
m.errf(err, "Failed to create MapResponse") m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)

View file

@ -13,6 +13,7 @@ import (
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
) )
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -786,117 +787,85 @@ func TestNodeTagCommand(t *testing.T) {
) )
} }
func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario(dockertestMaxWait()) tests := []struct {
assertNoErr(t, err) name string
defer scenario.ShutdownAssertNoPanics(t) policy *policy.ACLPolicy
wantTag bool
spec := map[string]int{ }{
"user1": 1, {
} name: "no-policy",
wantTag: false,
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"])
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--tags",
"--output", "json",
}, },
&resultMachines, {
) name: "with-policy",
assert.Nil(t, err) policy: &policy.ACLPolicy{
found := false ACLs: []policy.ACL{
for _, node := range resultMachines { {
if node.GetInvalidTags() != nil { Action: "accept",
for _, tag := range node.GetInvalidTags() { Sources: []string{"*"},
if tag == "tag:test" { Destinations: []string{"*:*"},
found = true },
} },
} TagOwners: map[string][]string{
} "tag:test": {"user1"},
}
assert.Equal(
t,
true,
found,
"should not find a node with the tag 'tag:test' in the list of nodes",
)
}
func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(
&policy.ACLPolicy{
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
}, },
}, },
TagOwners: map[string][]string{ wantTag: true,
"tag:exists": {"user1"},
},
}, },
)) }
assertNoErr(t, err)
headscale, err := scenario.Headscale() for _, tt := range tests {
assertNoErr(t, err) t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
// defer scenario.ShutdownAssertNoPanics(t)
// Test list all nodes after added seconds spec := map[string]int{
resultMachines := make([]*v1.Node, spec["user1"]) "user1": 1,
err = executeAndUnmarshal( }
headscale,
[]string{ err = scenario.CreateHeadscaleEnv(spec,
"headscale", []tsic.Option{tsic.WithTags([]string{"tag:test"})},
"nodes", hsic.WithTestName("cliadvtags"),
"list", hsic.WithACLPolicy(tt.policy),
"--tags", )
"--output", "json", assertNoErr(t, err)
},
&resultMachines, headscale, err := scenario.Headscale()
) assertNoErr(t, err)
assert.Nil(t, err)
found := false // Test list all nodes after added seconds
for _, node := range resultMachines { resultMachines := make([]*v1.Node, spec["user1"])
if node.GetValidTags() != nil { err = executeAndUnmarshal(
for _, tag := range node.GetValidTags() { headscale,
if tag == "tag:exists" { []string{
found = true "headscale",
"nodes",
"list",
"--tags",
"--output", "json",
},
&resultMachines,
)
assert.Nil(t, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
found = slices.Contains(tags, "tag:test")
} }
} }
} assert.Equalf(
t,
tt.wantTag,
found,
"'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag,
)
})
} }
assert.Equal(
t,
true,
found,
"should not find a node with the tag 'tag:exists' in the list of nodes",
)
} }
func TestNodeCommand(t *testing.T) { func TestNodeCommand(t *testing.T) {
@ -1732,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath, policyFilePath,
}, },
) )
assert.ErrorContains(t, err, "verifying policy rules: invalid action") assert.ErrorContains(t, err, "compiling filter rules: invalid action")
// The new policy was invalid, the old one should still be in place, which // The new policy was invalid, the old one should still be in place, which
// is none. // is none.

View file

@ -81,6 +81,10 @@ type Option = func(c *HeadscaleInContainer)
// HeadscaleInContainer instance. // HeadscaleInContainer instance.
func WithACLPolicy(acl *policy.ACLPolicy) Option { func WithACLPolicy(acl *policy.ACLPolicy) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
if acl == nil {
return
}
// TODO(kradalby): Move somewhere appropriate // TODO(kradalby): Move somewhere appropriate
hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath