hook up user and node changes to policy

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-26 12:53:04 -05:00
parent 19bc8b6e01
commit 8d5b04f3d3
No known key found for this signature in database
4 changed files with 80 additions and 0 deletions

View file

@ -165,6 +165,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.db,
app.nodeNotifier,
app.ipAlloc,
app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -472,6 +473,48 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
return err
}
changed, err := polMan.SetUsers(users)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
nodes, err := db.ListNodes()
if err != nil {
return err
}
changed, err := polMan.SetNodes(nodes)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
if profilingEnabled {
@ -770,6 +813,7 @@ func (h *Headscale) Serve() error {
Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP
// TODO(kradalby): Only update if we set a new policy
if err := h.loadACLPolicy(); err != nil {
log.Error().Err(err).Msg("failed to reload ACL policy")
}

View file

@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
return
}
err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
}
err = h.db.Write(func(tx *gorm.DB) error {

View file

@ -57,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
return nil, err
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@ -86,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.DeleteUserResponse{}, nil
}
@ -220,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using node: %w", err)
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}

View file

@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
registrationCache *zcache.Cache[string, key.MachinePublic]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
polMan policy.PolicyManager
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
db *db.HSDatabase,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
polMan policy.PolicyManager,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
polMan: polMan,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return nil, fmt.Errorf("creating or updating user: %w", err)
}
err = usersChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return user, nil
}
@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode(
return fmt.Errorf("could not register node: %w", err)
}
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
}
return nil
}