diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index 7e730aa8..56bd8f1b 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -30,8 +30,7 @@ jobs: - TestPreAuthKeyCorrectUserLoggedInCommand - TestApiKeyCommand - TestNodeTagCommand - - TestNodeAdvertiseTagNoACLCommand - - TestNodeAdvertiseTagWithACLCommand + - TestNodeAdvertiseTagCommand - TestNodeCommand - TestNodeExpireCommand - TestNodeRenameCommand diff --git a/hscontrol/app.go b/hscontrol/app.go index da20b1ae..1651b8f2 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -88,7 +88,8 @@ type Headscale struct { DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + polManOnce sync.Once + polMan policy.PolicyManager mapper *mapper.Mapper nodeNotifier *notifier.Notifier @@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } }) + if err = app.loadPolicyManager(); err != nil { + return nil, fmt.Errorf("failed to load ACL policy: %w", err) + } + var authProvider AuthProvider authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { @@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.db, app.nodeNotifier, app.ipAlloc, + app.polMan, ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { @@ -475,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { return router } +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? +func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + users, err := db.ListUsers() + if err != nil { + return err + } + + changed, err := polMan.SetUsers(users) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + +// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed. +// Maybe we should attempt a new in memory state and not go via the DB? +func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error { + nodes, err := db.ListNodes() + if err != nil { + return err + } + + changed, err := polMan.SetNodes(nodes) + if err != nil { + return err + } + + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all") + notif.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } + + return nil +} + // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { if profilingEnabled { @@ -490,19 +542,13 @@ func (h *Headscale) Serve() error { } } - var err error - - if err = h.loadACLPolicy(); err != nil { - return fmt.Errorf("failed to load ACL policy: %w", err) - } - if dumpConfig { spew.Dump(h.cfg) } // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -772,12 +818,21 @@ func (h *Headscale) Serve() error { Str("signal", sig.String()). Msg("Received SIGHUP, reloading ACL and Config") - // TODO(kradalby): Reload config on SIGHUP - if err := h.loadACLPolicy(); err != nil { - log.Error().Err(err).Msg("failed to reload ACL policy") + if err := h.loadPolicyManager(); err != nil { + log.Error().Err(err).Msg("failed to reload Policy") } - if h.ACLPolicy != nil { + pol, err := h.policyBytes() + if err != nil { + log.Error().Err(err).Msg("failed to get policy blob") + } + + changed, err := h.polMan.SetPolicy(pol) + if err != nil { + log.Error().Err(err).Msg("failed to set new policy") + } + + if changed { log.Info(). Msg("ACL policy successfully reloaded, notifying nodes of change") @@ -996,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } -func (h *Headscale) loadACLPolicy() error { - var ( - pol *policy.ACLPolicy - err error - ) - +// policyBytes returns the appropriate policy for the +// current configuration as a []byte array. +func (h *Headscale) policyBytes() ([]byte, error) { switch h.cfg.Policy.Mode { case types.PolicyModeFile: path := h.cfg.Policy.Path // It is fine to start headscale without a policy file. if len(path) == 0 { - return nil + return nil, nil } absPath := util.AbsolutePathFromConfigPath(path) - pol, err = policy.LoadACLPolicyFromPath(absPath) + policyFile, err := os.Open(absPath) if err != nil { - return fmt.Errorf("failed to load ACL policy from file: %w", err) + return nil, err + } + defer policyFile.Close() + + return io.ReadAll(policyFile) + + case types.PolicyModeDB: + p, err := h.db.GetPolicy() + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err } + return []byte(p.Data), err + } + + return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode) +} + +func (h *Headscale) loadPolicyManager() error { + var errOut error + h.polManOnce.Do(func() { // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -1027,42 +1101,35 @@ func (h *Headscale) loadACLPolicy() error { // allowed to be written to the database. nodes, err := h.db.ListNodes() if err != nil { - return fmt.Errorf("loading nodes from database to validate policy: %w", err) + errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) + return + } + users, err := h.db.ListUsers() + if err != nil { + errOut = fmt.Errorf("loading users from database to validate policy: %w", err) + return } - _, err = pol.CompileFilterRules(nodes) + pol, err := h.policyBytes() if err != nil { - return fmt.Errorf("verifying policy rules: %w", err) + errOut = fmt.Errorf("loading policy bytes: %w", err) + return + } + + h.polMan, err = policy.NewPolicyManager(pol, users, nodes) + if err != nil { + errOut = fmt.Errorf("creating policy manager: %w", err) + return } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = h.polMan.SSHPolicy(nodes[0]) if err != nil { - return fmt.Errorf("verifying SSH rules: %w", err) + errOut = fmt.Errorf("verifying SSH rules: %w", err) + return } } + }) - case types.PolicyModeDB: - p, err := h.db.GetPolicy() - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil - } - - return fmt.Errorf("failed to get policy from database: %w", err) - } - - pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data)) - if err != nil { - return fmt.Errorf("failed to parse policy: %w", err) - } - default: - log.Fatal(). - Str("mode", string(h.cfg.Policy.Mode)). - Msg("Unknown ACL policy mode") - } - - h.ACLPolicy = pol - - return nil + return errOut } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 67545031..2b23aad3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey( return } + + err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier) + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } + } err = h.db.Write(func(tx *gorm.DB) error { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 888f48db..46dce68b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(testPeers), check.Equals, 9) - adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) + adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) + testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user}) c.Assert(err, check.IsNil) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) @@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { adb, err := newTestDB() assert.NoError(t, err) - pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl)) - - assert.NoError(t, err) - assert.NotNil(t, pol) user, err := adb.CreateUser("test") assert.NoError(t, err) @@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) { node0ByID, err := adb.GetNodeByID(0) assert.NoError(t, err) + users, err := adb.ListUsers() + assert.NoError(t, err) + + nodes, err := adb.ListNodes() + assert.NoError(t, err) + + pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes) + assert.NoError(t, err) + // TODO(kradalby): Check state update - err = adb.EnableAutoApprovedRoutes(pol, node0ByID) + err = adb.EnableAutoApprovedRoutes(pm, node0ByID) assert.NoError(t, err) enabledRoutes, err := adb.GetEnabledRoutes(node0ByID) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 086261aa..dcf238ab 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -598,18 +598,18 @@ func failoverRoute( } func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { return hsdb.Write(func(tx *gorm.DB) error { - return EnableAutoApprovedRoutes(tx, aclPolicy, node) + return EnableAutoApprovedRoutes(tx, polMan, node) }) } // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. func EnableAutoApprovedRoutes( tx *gorm.DB, - aclPolicy *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, ) error { if node.IPv4 == nil && node.IPv6 == nil { @@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes( continue } - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) - } + routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix)) log.Trace(). Str("node", node.Hostname). @@ -649,7 +644,7 @@ func EnableAutoApprovedRoutes( approvedRoutes = append(approvedRoutes, advertisedRoute) } else { // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) + approvedIps, err := polMan.ExpandAlias(approvedAlias) if err != nil { return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..51134e7e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -21,7 +21,6 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" ) @@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -87,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser( return nil, err } + err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return &v1.DeleteUserResponse{}, nil } @@ -221,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode( return nil, err } + err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier) + if err != nil { + return nil, fmt.Errorf("updating resources using node: %w", err) + } + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -450,10 +464,7 @@ func (api headscaleV1APIServer) ListNodes( resp.Online = true } - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - node, - ) - resp.InvalidTags = invalidTags + validTags := api.h.polMan.Tags(node) resp.ValidTags = validTags response[index] = resp } @@ -723,11 +734,6 @@ func (api headscaleV1APIServer) SetPolicy( p := request.GetPolicy() - pol, err := policy.LoadACLPolicyFromBytes([]byte(p)) - if err != nil { - return nil, fmt.Errorf("loading ACL policy file: %w", err) - } - // Validate and reject configuration that would error when applied // when creating a map response. This requires nodes, so there is still // a scenario where they might be allowed if the server has no nodes @@ -737,14 +743,13 @@ func (api headscaleV1APIServer) SetPolicy( if err != nil { return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) } - - _, err = pol.CompileFilterRules(nodes) + changed, err := api.h.polMan.SetPolicy([]byte(p)) if err != nil { - return nil, fmt.Errorf("verifying policy rules: %w", err) + return nil, fmt.Errorf("setting policy: %w", err) } if len(nodes) > 0 { - _, err = pol.CompileSSHPolicy(nodes[0], nodes) + _, err = api.h.polMan.SSHPolicy(nodes[0]) if err != nil { return nil, fmt.Errorf("verifying SSH rules: %w", err) } @@ -755,12 +760,13 @@ func (api headscaleV1APIServer) SetPolicy( return nil, err } - api.h.ACLPolicy = pol - - ctx := types.NotifyCtx(context.Background(), "acl-update", "na") - api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ - Type: types.StateFullUpdate, - }) + // Only send update if the packet filter has changed. + if changed { + ctx := types.NotifyCtx(context.Background(), "acl-update", "na") + api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{ + Type: types.StateFullUpdate, + }) + } response := &v1.SetPolicyResponse{ Policy: updated.Data, diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 3db1e159..51c96f8c 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -55,6 +55,7 @@ type Mapper struct { cfg *types.Config derpMap *tailcfg.DERPMap notif *notifier.Notifier + polMan policy.PolicyManager uid string created time.Time @@ -71,6 +72,7 @@ func NewMapper( cfg *types.Config, derpMap *tailcfg.DERPMap, notif *notifier.Notifier, + polMan policy.PolicyManager, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) @@ -79,6 +81,7 @@ func NewMapper( cfg: cfg, derpMap: derpMap, notif: notif, + polMan: polMan, uid: uid, created: time.Now(), @@ -153,10 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { func (m *Mapper) fullMapResponse( node *types.Node, peers types.Nodes, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + resp, err := m.baseWithConfigMapResponse(node, capVer) if err != nil { return nil, err } @@ -164,11 +166,10 @@ func (m *Mapper) fullMapResponse( err = appendPeerChanges( resp, true, // full change - pol, + m.polMan, node, capVer, peers, - peers, m.cfg, ) if err != nil { @@ -182,7 +183,6 @@ func (m *Mapper) fullMapResponse( func (m *Mapper) FullMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { peers, err := m.ListPeers(node.ID) @@ -190,7 +190,7 @@ func (m *Mapper) FullMapResponse( return nil, err } - resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) + resp, err := m.fullMapResponse(node, peers, mapRequest.Version) if err != nil { return nil, err } @@ -204,10 +204,9 @@ func (m *Mapper) FullMapResponse( func (m *Mapper) ReadOnlyMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) if err != nil { return nil, err } @@ -243,7 +242,6 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed map[types.NodeID]bool, patches []*tailcfg.PeerChange, - pol *policy.ACLPolicy, messages ...string, ) ([]byte, error) { resp := m.baseMapResponse() @@ -273,10 +271,9 @@ func (m *Mapper) PeerChangedResponse( err = appendPeerChanges( &resp, false, // partial change - pol, + m.polMan, node, mapRequest.Version, - peers, changedNodes, m.cfg, ) @@ -303,7 +300,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) if err != nil { return nil, err } @@ -318,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse( mapRequest tailcfg.MapRequest, node *types.Node, changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, ) ([]byte, error) { resp := m.baseMapResponse() resp.PeersChangedPatch = changed @@ -447,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { // incremental. func (m *Mapper) baseWithConfigMapResponse( node *types.Node, - pol *policy.ACLPolicy, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, pol, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) if err != nil { return nil, err } @@ -505,34 +500,30 @@ func appendPeerChanges( resp *tailcfg.MapResponse, fullChange bool, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, node *types.Node, capVer tailcfg.CapabilityVersion, - peers types.Nodes, changed types.Nodes, cfg *types.Config, ) error { - packetFilter, err := pol.CompileFilterRules(append(peers, node)) - if err != nil { - return err - } + filter := polMan.Filter() - sshPolicy, err := pol.CompileSSHPolicy(node, peers) + sshPolicy, err := polMan.SSHPolicy(node) if err != nil { return err } // If there are filter rules present, see if there are any nodes that cannot // access each-other at all and remove them from the peers. - if len(packetFilter) > 0 { - changed = policy.FilterNodesByACL(node, changed, packetFilter) + if len(filter) > 0 { + changed = policy.FilterNodesByACL(node, changed, filter) } profiles := generateUserProfiles(node, changed) dnsConfig := generateDNSConfig(cfg, node) - tailPeers, err := tailNodes(changed, capVer, pol, cfg) + tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { return err } @@ -557,7 +548,7 @@ func appendPeerChanges( // new PacketFilters field and "base" allows us to send a full update when we // have to send an empty list, avoiding the hack in the else block. resp.PacketFilters = map[string][]tailcfg.FilterRule{ - "base": policy.ReduceFilterRules(node, packetFilter), + "base": policy.ReduceFilterRules(node, filter), } } else { // This is a hack to avoid sending an empty list of packet filters. @@ -565,11 +556,11 @@ func appendPeerChanges( // be omitted, causing the client to consider it unchanged, keeping the // previous packet filter. Worst case, this can cause a node that previously // has access to a node to _not_ loose access if an empty (allow none) is sent. - reduced := policy.ReduceFilterRules(node, packetFilter) + reduced := policy.ReduceFilterRules(node, filter) if len(reduced) > 0 { resp.PacketFilter = reduced } else { - resp.PacketFilter = packetFilter + resp.PacketFilter = filter } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 37ed5c42..4ee8c644 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) { lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) + user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"} + user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"} + mini := &types.Node{ ID: 0, MachineKey: mustMK( @@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, AuthKey: &types.PreAuthKey{}, LastSeen: &lastSeen, @@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.2"), Hostname: "peer1", GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, + UserID: user1.ID, + User: user1, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) { IPv4: iap("100.64.0.3"), Hostname: "peer2", GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, + UserID: user2.ID, + User: user2, ForcedTags: []string{}, LastSeen: &lastSeen, Expiry: &expire, @@ -458,17 +461,19 @@ func Test_fullMapResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node)) + mappy := NewMapper( nil, tt.cfg, tt.derpMap, nil, + polMan, ) got, err := mappy.fullMapResponse( tt.node, tt.peers, - tt.pol, 0, ) diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 24c521dc..4082df2b 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -14,7 +14,7 @@ import ( func tailNodes( nodes types.Nodes, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) ([]*tailcfg.Node, error) { tNodes := make([]*tailcfg.Node, len(nodes)) @@ -23,7 +23,7 @@ func tailNodes( node, err := tailNode( node, capVer, - pol, + polMan, cfg, ) if err != nil { @@ -40,7 +40,7 @@ func tailNodes( func tailNode( node *types.Node, capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, + polMan policy.PolicyManager, cfg *types.Config, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -81,7 +81,7 @@ func tailNode( return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - tags, _ := pol.TagsOfNode(node) + tags := polMan.Tags(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) tNode := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index b6692c16..9d7f1fed 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node}) cfg := &types.Config{ BaseDomain: tt.baseDomain, DNSConfig: tt.dnsConfig, @@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) { got, err := tailNode( tt.node, 0, - tt.pol, + polMan, cfg, ) @@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) { tn, err := tailNode( node, 0, - &policy.ACLPolicy{}, + &policy.PolicyManagerV1{}, &types.Config{}, ) if err != nil { diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 84267b41..5028e244 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -53,6 +54,7 @@ type AuthProviderOIDC struct { registrationCache *zcache.Cache[string, key.MachinePublic] notifier *notifier.Notifier ipAlloc *db.IPAllocator + polMan policy.PolicyManager oidcProvider *oidc.Provider oauth2Config *oauth2.Config @@ -65,6 +67,7 @@ func NewAuthProviderOIDC( db *db.HSDatabase, notif *notifier.Notifier, ipAlloc *db.IPAllocator, + polMan policy.PolicyManager, ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already @@ -96,6 +99,7 @@ func NewAuthProviderOIDC( registrationCache: registrationCache, notifier: notif, ipAlloc: ipAlloc, + polMan: polMan, oidcProvider: oidcProvider, oauth2Config: oauth2Config, @@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return nil, fmt.Errorf("creating or updating user: %w", err) } + err = usersChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return nil, fmt.Errorf("updating resources using user: %w", err) + } + return user, nil } @@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode( return fmt.Errorf("could not register node: %w", err) } + err = nodesChangedHook(a.db, a.polMan, a.notifier) + if err != nil { + return fmt.Errorf("updating resources using node: %w", err) + } + return nil } diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..8e2d1961 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests( policy *ACLPolicy, node *types.Node, peers types.Nodes, + users []types.User, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.CompileFilterRules(append(peers, node)) + rules, err := policy.CompileFilterRules(users, append(peers, node)) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - sshPolicy, err := policy.CompileSSHPolicy(node, peers) + sshPolicy, err := policy.CompileSSHPolicy(node, users, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests( // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) CompileFilterRules( + users []types.User, nodes types.Nodes, ) ([]tailcfg.FilterRule, error) { if pol == nil { @@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) + srcs, err := pol.expandSource(src, users, nodes) if err != nil { - return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) + return nil, fmt.Errorf( + "parsing policy, acl index: %d->%d: %w", + index, + srcIndex, + err, + ) } srcIPs = append(srcIPs, srcs...) } @@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules( expanded, err := pol.ExpandAlias( nodes, + users, alias, ) if err != nil { @@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F func (pol *ACLPolicy) CompileSSHPolicy( node *types.Node, + users []types.User, peers types.Nodes, ) (*tailcfg.SSHPolicy, error) { if pol == nil { @@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) + expanded, err := pol.ExpandAlias(append(peers, node), users, src) if err != nil { return nil, err } @@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( case "check": checkAction, err := sshCheckAction(sshACL.CheckPeriod) if err != nil { - return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) + return nil, fmt.Errorf( + "parsing SSH policy, parsing check duration, index: %d: %w", + index, + err, + ) } else { action = *checkAction } default: - return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) + return nil, fmt.Errorf( + "parsing SSH policy, unknown action %q, index: %d: %w", + sshACL.Action, + index, + err, + ) } principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) @@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy( } else { expandedSrcs, err := pol.ExpandAlias( peers, + users, rawSrc, ) if err != nil { @@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) { // with the given src alias. func (pol *ACLPolicy) expandSource( src string, + users []types.User, nodes types.Nodes, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) + ipSet, err := pol.ExpandAlias(nodes, users, src) if err != nil { return []string{}, err } @@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource( // and transform these in IPAddresses. func (pol *ACLPolicy) ExpandAlias( nodes types.Nodes, + users []types.User, alias string, ) (*netipx.IPSet, error) { if isWildcard(alias) { @@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) + return pol.expandIPsFromGroup(alias, users, nodes) } // if alias is a tag if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) + return pol.expandIPsFromTag(alias, users, nodes) } if isAutoGroup(alias) { @@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias( } // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil { return ips, err } @@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(nodes, h.String()) + return pol.ExpandAlias(nodes, users, h.String()) } // if alias is an IP @@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup( func (pol *ACLPolicy) expandIPsFromGroup( group string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - users, err := pol.expandUsersFromGroup(group) + userTokens, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) + for _, user := range userTokens { + filteredNodes := filterNodesByUser(nodes, users, user) for _, node := range filteredNodes { node.AppendToIPSet(&build) } @@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromTag( alias string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder @@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag( // filter out nodes per tag owner for _, user := range owners { - nodes := filterNodesByUser(nodes, user) + nodes := filterNodesByUser(nodes, users, user) for _, node := range nodes { if node.Hostinfo == nil { continue @@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromUser( user string, + users []types.User, nodes types.Nodes, ) (*netipx.IPSet, error) { var build netipx.IPSetBuilder - filteredNodes := filterNodesByUser(nodes, user) + filteredNodes := filterNodesByUser(nodes, users, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) // shortcurcuit if we have no nodes to get ips from. @@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode( return validTags, invalidTags } -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { +// filterNodesByUser returns a list of nodes that match the given userToken from a +// policy. +// Matching nodes are determined by first matching the user token to a user by checking: +// - If it is an ID that mactches the user database ID +// - It is the Provider Identifier from OIDC +// - It matches the username or email of a user +// +// If the token matches more than one user, zero nodes will returned. +func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes { var out types.Nodes + + var potentialUsers []types.User + for _, user := range users { + if user.ProviderIdentifier == userToken { + potentialUsers = append(potentialUsers, user) + + break + } + if user.Email == userToken { + potentialUsers = append(potentialUsers, user) + } + if user.Name == userToken { + potentialUsers = append(potentialUsers, user) + } + } + + if len(potentialUsers) != 1 { + return nil + } + + user := potentialUsers[0] + for _, node := range nodes { - if node.User.Username() == user { + if node.User.ID == user.ID { out = append(out, node) } } diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 1c6e4de8..f13d7f42 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -2,8 +2,10 @@ package policy import ( "errors" + "math/rand/v2" "net/netip" "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -14,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) @@ -375,18 +378,24 @@ func TestParsing(t *testing.T) { return } - rules, err := pol.CompileFilterRules(types.Nodes{ - &types.Node{ - IPv4: iap("100.100.100.100"), + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + rules, err := pol.CompileFilterRules( + []types.User{ + user, }, - &types.Node{ - IPv4: iap("200.200.200.200"), - User: types.User{ - Name: "testuser", + types.Nodes{ + &types.Node{ + IPv4: iap("100.100.100.100"), }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) if (err != nil) != tt.wantErr { t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) @@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.CompileFilterRules(types.Nodes{}) + rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) + _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) { } func Test_listNodesInUser(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "marc"}, + {Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"}, + {Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"}, + {Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"}, + } + type args struct { nodes types.Nodes user string @@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) { name: "1 node in user", args: args{ nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, user: "joe", }, want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, + &types.Node{User: users[1]}, }, }, { name: "3 nodes, 2 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, user: "marc", }, want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, }, }, { name: "5 nodes, 0 in user", args: args{ nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[0]}, + &types.Node{ID: 3, User: users[0]}, + &types.Node{ID: 4, User: users[0]}, + &types.Node{ID: 5, User: users[0]}, }, user: "mickael", }, want: nil, }, + { + name: "match-by-provider-ident", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[2]}, + }, + }, + { + name: "match-by-email", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + }, + }, + { + name: "multi-match-is-zero", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[1]}, + &types.Node{ID: 2, User: users[2]}, + &types.Node{ID: 3, User: users[3]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-email-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match email, then provider id + &types.Node{ID: 3, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "multi-username-first-match-is-zero", + args: args{ + nodes: types.Nodes{ + // First match username, then provider id + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 2, User: users[2]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-duplicate-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael", + }, + want: nil, + }, + { + name: "all-users-unique-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "marc", + }, + want: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + }, + }, + { + name: "all-users-no-username-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "mikael@headscale.net", + }, + want: nil, + }, + { + name: "all-users-duplicate-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "joe@headscale.net", + }, + want: types.Nodes{ + &types.Node{ID: 2, User: users[1]}, + }, + }, + { + name: "all-users-no-email-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "not-working@headscale.net", + }, + want: nil, + }, + { + name: "all-users-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/1234", + }, + want: types.Nodes{ + &types.Node{ID: 3, User: users[2]}, + }, + }, + { + name: "all-users-no-provider-id-random-order", + args: args{ + nodes: types.Nodes{ + &types.Node{ID: 1, User: users[0]}, + &types.Node{ID: 2, User: users[1]}, + &types.Node{ID: 3, User: users[2]}, + &types.Node{ID: 4, User: users[3]}, + &types.Node{ID: 5, User: users[4]}, + }, + user: "http://oidc.org/4321", + }, + want: nil, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) + for range 1000 { + ns := test.args.nodes + rand.Shuffle(len(ns), func(i, j int) { + ns[i], ns[j] = ns[j], ns[i] + }) + got := filterNodesByUser(ns, users, test.args.user) + sort.Slice(got, func(i, j int) bool { + return got[i].ID < got[j].ID + }) - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { + t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff) + } } }) } @@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) { return s } + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "joe"}, + {Model: gorm.Model{ID: 2}, Name: "marc"}, + {Model: gorm.Model{ID: 3}, Name: "mickael"}, + } + type field struct { pol ACLPolicy } @@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], }, }, }, @@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) { nodes: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.1"), - User: types.User{Name: "joe"}, + User: users[0], ForcedTags: []string{"tag:hr-webserver"}, }, &types.Node{ IPv4: iap("100.64.0.2"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{ OS: "centos", Hostname: "foo", @@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "mickael"}, + User: users[2], }, }, }, @@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) { }, &types.Node{ IPv4: iap("100.64.0.3"), - User: types.User{Name: "marc"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{}, }, &types.Node{ IPv4: iap("100.64.0.4"), - User: types.User{Name: "joe"}, + User: users[0], Hostinfo: &tailcfg.Hostinfo{}, }, }, @@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) { t.Run(test.name, func(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.nodes, + users, test.args.alias, ) if (err != nil) != test.wantErr { @@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.field.pol.CompileFilterRules( + []types.User{}, tt.args.nodes, ) if (err != nil) != tt.wantErr { @@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) { } func TestReduceFilterRules(t *testing.T) { + users := []types.User{ + {Model: gorm.Model{ID: 1}, Name: "mickael"}, + {Model: gorm.Model{ID: 2}, Name: "user1"}, + {Model: gorm.Model{ID: 3}, Name: "user2"}, + {Model: gorm.Model{ID: 4}, Name: "user100"}, + } + tests := []struct { name string node *types.Node @@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: types.User{Name: "mickael"}, + User: users[0], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: types.User{Name: "mickael"}, + User: users[0], }, }, want: []tailcfg.FilterRule{}, @@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, peers: types.Nodes{ &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, // "internal" exit node &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.2"), IPv6: iap("fd7a:115c:a1e0::2"), - User: types.User{Name: "user2"}, + User: users[2], }, &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: iap("100.64.0.100"), IPv6: iap("fd7a:115c:a1e0::100"), - User: types.User{Name: "user100"}, + User: users[3], Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: iap("100.64.0.1"), IPv6: iap("fd7a:115c:a1e0::1"), - User: types.User{Name: "user1"}, + User: users[1], }, }, want: []tailcfg.FilterRule{ @@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, _ := tt.pol.CompileFilterRules( + users, append(tt.peers, tt.node), ) @@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) + got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { RequestTags: []string{"tag:test"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPv4: iap("100.64.0.1"), - UserID: 0, - User: types.User{ - Name: "user1", - }, + ID: 0, + Hostname: "testnodes", + IPv4: iap("100.64.0.1"), + UserID: 0, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { IPv4: iap("100.64.0.1"), UserID: 1, User: types.User{ - Name: "user1", + Model: gorm.Model{ID: 1}, + Name: "user1", }, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, @@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) { Hostname: "webserver", RequestTags: []string{"tag:webapp"}, } + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPv4: iap("100.64.0.1"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 1, + Hostname: "webserver", + IPv4: iap("100.64.0.1"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo, } @@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) { } nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPv4: iap("100.64.0.2"), - UserID: 1, - User: types.User{ - Name: "user1", - }, + ID: 2, + Hostname: "user", + IPv4: iap("100.64.0.2"), + UserID: 1, + User: user, RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &hostInfo2, } @@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) + got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user}) assert.NoError(t, err) want := []tailcfg.FilterRule{ diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new file mode 100644 index 00000000..7dbaed33 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,181 @@ +package policy + +import ( + "fmt" + "io" + "net/netip" + "os" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/deephash" +) + +type PolicyManager interface { + Filter() []tailcfg.FilterRule + SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + Tags(*types.Node) []string + ApproversForRoute(netip.Prefix) []string + ExpandAlias(string) (*netipx.IPSet, error) + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes types.Nodes) (bool, error) +} + +func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) { + policyFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer policyFile.Close() + + policyBytes, err := io.ReadAll(policyFile) + if err != nil { + return nil, err + } + + return NewPolicyManager(policyBytes, users, nodes) +} + +func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { + var pol *ACLPolicy + var err error + if polB != nil && len(polB) > 0 { + pol, err = LoadACLPolicyFromBytes(polB) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + } + + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) { + pm := PolicyManagerV1{ + pol: pol, + users: users, + nodes: nodes, + } + + _, err := pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +type PolicyManagerV1 struct { + mu sync.Mutex + pol *ACLPolicy + + users []types.User + nodes types.Nodes + + filterHash deephash.Sum + filter []tailcfg.FilterRule +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManagerV1) updateLocked() (bool, error) { + filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + filterHash := deephash.Hash(&filter) + if filterHash == pm.filterHash { + return false, nil + } + + pm.filter = filter + pm.filterHash = filterHash + + return true, nil +} + +func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule { + pm.mu.Lock() + defer pm.mu.Unlock() + return pm.filter +} + +func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes) +} + +func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { + pol, err := LoadACLPolicyFromBytes(polB) + if err != nil { + return false, fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.pol = pol + + return pm.updateLocked() +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + pm.users = users + return pm.updateLocked() +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + pm.nodes = nodes + return pm.updateLocked() +} + +func (pm *PolicyManagerV1) Tags(node *types.Node) []string { + if pm == nil { + return nil + } + + tags, _ := pm.pol.TagsOfNode(node) + return tags +} + +func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string { + // TODO(kradalby): This can be a parse error of the address in the policy, + // in the new policy this will be typed and not a problem, in this policy + // we will just return empty list + if pm.pol == nil { + return nil + } + approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route) + return approvers +} + +func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) { + ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias) + if err != nil { + return nil, err + } + return ips, nil +} diff --git a/hscontrol/policy/pm_test.go b/hscontrol/policy/pm_test.go new file mode 100644 index 00000000..24b78e4d --- /dev/null +++ b/hscontrol/policy/pm_test.go @@ -0,0 +1,158 @@ +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" +) + +func TestPolicySetChange(t *testing.T) { + users := []types.User{ + { + Model: gorm.Model{ID: 1}, + Name: "testuser", + }, + } + tests := []struct { + name string + users []types.User + nodes types.Nodes + policy []byte + wantUsersChange bool + wantNodesChange bool + wantPolicyChange bool + wantFilter []tailcfg.FilterRule + }{ + { + name: "set-nodes", + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantNodesChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users", + users: users, + wantUsersChange: false, + wantFilter: []tailcfg.FilterRule{ + { + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-users-and-node", + users: users, + nodes: types.Nodes{ + { + IPv4: iap("100.64.0.2"), + User: users[0], + }, + }, + wantUsersChange: false, + wantNodesChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + name: "set-policy", + policy: []byte(` +{ +"acls": [ + { + "action": "accept", + "src": [ + "100.64.0.61", + ], + "dst": [ + "100.64.0.62:*", + ], + }, + ], +} + `), + wantPolicyChange: true, + wantFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.61/32"}, + DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol := ` +{ + "groups": { + "group:example": [ + "testuser", + ], + }, + + "hosts": { + "host-1": "100.64.0.1", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +` + pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{}) + require.NoError(t, err) + + if tt.policy != nil { + change, err := pm.SetPolicy(tt.policy) + require.NoError(t, err) + + assert.Equal(t, tt.wantPolicyChange, change) + } + + if tt.users != nil { + change, err := pm.SetUsers(tt.users) + require.NoError(t, err) + + assert.Equal(t, tt.wantUsersChange, change) + } + + if tt.nodes != nil { + change, err := pm.SetNodes(tt.nodes) + require.NoError(t, err) + + assert.Equal(t, tt.wantNodesChange, change) + } + + if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" { + t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a8ae01f4..e6047d45 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() { switch update.Type { case types.StateFullUpdate: m.tracef("Sending Full MapResponse") - data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) + data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming())) case types.StatePeerChanged: changed := make(map[types.NodeID]bool, len(update.ChangeNodes)) @@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() { lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "change" case types.StatePeerChangedPatch: m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy) + data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches) updateType = "patch" case types.StatePeerRemoved: changed := make(map[types.NodeID]bool, len(update.Removed)) @@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() { changed[nodeID] = false } m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) - data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage) updateType = "remove" case types.StateSelfUpdate: lastMessage = update.Message m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) // create the map so an empty (self) update is sent - data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage) + data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage) updateType = "remove" case types.StateDERPUpdated: m.tracef("Sending DERPUpdate MapResponse") @@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() { return } - if m.h.ACLPolicy != nil { + // TODO(kradalby): Only update the node that has actually changed + nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier) + + if m.h.polMan != nil { // update routes with peer information - err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) + err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node) if err != nil { m.errf(err, "Error running auto approved routes") mapResponseEndpointUpdates.WithLabelValues("error").Inc() @@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() { func (m *mapSession) handleReadOnlyRequest() { m.tracef("Client asked for a lite update, responding without peers") - mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy) + mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node) if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) diff --git a/integration/cli_test.go b/integration/cli_test.go index 2b81e814..19e865cd 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -13,6 +13,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -786,117 +787,85 @@ func TestNodeTagCommand(t *testing.T) { ) } -func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { +func TestNodeAdvertiseTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags")) - assertNoErr(t, err) - - headscale, err := scenario.Headscale() - assertNoErr(t, err) - - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", + tests := []struct { + name string + policy *policy.ACLPolicy + wantTag bool + }{ + { + name: "no-policy", + wantTag: false, }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetInvalidTags() != nil { - for _, tag := range node.GetInvalidTags() { - if tag == "tag:test" { - found = true - } - } - } - } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:test' in the list of nodes", - ) -} - -func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - scenario, err := NewScenario(dockertestMaxWait()) - assertNoErr(t, err) - defer scenario.ShutdownAssertNoPanics(t) - - spec := map[string]int{ - "user1": 1, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy( - &policy.ACLPolicy{ - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + { + name: "with-policy", + policy: &policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + TagOwners: map[string][]string{ + "tag:test": {"user1"}, }, }, - TagOwners: map[string][]string{ - "tag:exists": {"user1"}, - }, + wantTag: true, }, - )) - assertNoErr(t, err) + } - headscale, err := scenario.Headscale() - assertNoErr(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + // defer scenario.ShutdownAssertNoPanics(t) - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, spec["user1"]) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--tags", - "--output", "json", - }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetValidTags() != nil { - for _, tag := range node.GetValidTags() { - if tag == "tag:exists" { - found = true + spec := map[string]int{ + "user1": 1, + } + + err = scenario.CreateHeadscaleEnv(spec, + []tsic.Option{tsic.WithTags([]string{"tag:test"})}, + hsic.WithTestName("cliadvtags"), + hsic.WithACLPolicy(tt.policy), + ) + assertNoErr(t, err) + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + // Test list all nodes after added seconds + resultMachines := make([]*v1.Node, spec["user1"]) + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--tags", + "--output", "json", + }, + &resultMachines, + ) + assert.Nil(t, err) + found := false + for _, node := range resultMachines { + if tags := node.GetValidTags(); tags != nil { + found = slices.Contains(tags, "tag:test") } } - } + assert.Equalf( + t, + tt.wantTag, + found, + "'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag, + ) + }) } - assert.Equal( - t, - true, - found, - "should not find a node with the tag 'tag:exists' in the list of nodes", - ) } func TestNodeCommand(t *testing.T) { @@ -1732,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, "verifying policy rules: invalid action") + assert.ErrorContains(t, err, "compiling filter rules: invalid action") // The new policy was invalid, the old one should still be in place, which // is none. diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 8c379dc8..0d40e741 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -78,6 +78,10 @@ type Option = func(c *HeadscaleInContainer) // HeadscaleInContainer instance. func WithACLPolicy(acl *policy.ACLPolicy) Option { return func(hsic *HeadscaleInContainer) { + if acl == nil { + return + } + // TODO(kradalby): Move somewhere appropriate hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath