Compare commits

..

1 commit

Author SHA1 Message Date
Kristoffer Dalby
0030afbe05
Merge 8d5b04f3d3 into e2d5ee0927 2024-10-26 18:26:09 +00:00
2 changed files with 54 additions and 71 deletions

View file

@ -88,7 +88,6 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer DERPServer *derpServer.DERPServer
polManOnce sync.Once
polMan policy.PolicyManager polMan policy.PolicyManager
mapper *mapper.Mapper mapper *mapper.Mapper
@ -532,7 +531,8 @@ func (h *Headscale) Serve() error {
} }
var err error var err error
if err = h.loadPolicyManager(); err != nil {
if err = h.loadACLPolicy(); err != nil {
return fmt.Errorf("failed to load ACL policy: %w", err) return fmt.Errorf("failed to load ACL policy: %w", err)
} }
@ -812,21 +812,13 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()). Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config") Msg("Received SIGHUP, reloading ACL and Config")
if err := h.loadPolicyManager(); err != nil { // TODO(kradalby): Reload config on SIGHUP
log.Error().Err(err).Msg("failed to reload Policy") // TODO(kradalby): Only update if we set a new policy
if err := h.loadACLPolicy(); err != nil {
log.Error().Err(err).Msg("failed to reload ACL policy")
} }
pol, err := h.policyBytes() if h.polMan != nil {
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}
changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
}
if changed {
log.Info(). log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
@ -1045,46 +1037,22 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil return &machineKey, nil
} }
// policyBytes returns the appropriate policy for the func (h *Headscale) loadACLPolicy() error {
// current configuration as a []byte array. var (
func (h *Headscale) policyBytes() ([]byte, error) { pm policy.PolicyManager
)
switch h.cfg.Policy.Mode { switch h.cfg.Policy.Mode {
case types.PolicyModeFile: case types.PolicyModeFile:
path := h.cfg.Policy.Path path := h.cfg.Policy.Path
// It is fine to start headscale without a policy file. // It is fine to start headscale without a policy file.
if len(path) == 0 { if len(path) == 0 {
return nil, nil return nil
} }
absPath := util.AbsolutePathFromConfigPath(path) absPath := util.AbsolutePathFromConfigPath(path)
policyFile, err := os.Open(absPath)
if err != nil {
return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
}
return []byte(p.Data), err
}
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}
func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied // Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still // when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes // a scenario where they might be allowed if the server has no nodes
@ -1095,35 +1063,54 @@ func (h *Headscale) loadPolicyManager() error {
// allowed to be written to the database. // allowed to be written to the database.
nodes, err := h.db.ListNodes() nodes, err := h.db.ListNodes()
if err != nil { if err != nil {
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err) return fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
} }
users, err := h.db.ListUsers() users, err := h.db.ListUsers()
if err != nil { if err != nil {
errOut = fmt.Errorf("loading users from database to validate policy: %w", err) return fmt.Errorf("loading users from database to validate policy: %w", err)
return
} }
pol, err := h.policyBytes() pm, err = policy.NewPolicyManagerFromPath(absPath, users, nodes)
if err != nil { if err != nil {
errOut = fmt.Errorf("loading policy bytes: %w", err) return fmt.Errorf("loading policy from file: %w", err)
return
}
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
errOut = fmt.Errorf("creating policy manager: %w", err)
return
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = h.polMan.SSHPolicy(nodes[0]) _, err = pm.SSHPolicy(nodes[0])
if err != nil { if err != nil {
errOut = fmt.Errorf("verifying SSH rules: %w", err) return fmt.Errorf("verifying SSH rules: %w", err)
return
} }
} }
})
return errOut case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil
}
return fmt.Errorf("failed to get policy from database: %w", err)
}
nodes, err := h.db.ListNodes()
if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
users, err := h.db.ListUsers()
if err != nil {
return fmt.Errorf("loading users from database to validate policy: %w", err)
}
pm, err = policy.NewPolicyManager([]byte(p.Data), users, nodes)
if err != nil {
return fmt.Errorf("loading policy from database: %w", err)
}
default:
log.Fatal().
Str("mode", string(h.cfg.Policy.Mode)).
Msg("Unknown ACL policy mode")
}
h.polMan = pm
return nil
} }

View file

@ -40,14 +40,10 @@ func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes
} }
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) { func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
var pol *ACLPolicy pol, err := LoadACLPolicyFromBytes(polB)
var err error
if polB != nil && len(polB) > 0 {
pol, err = LoadACLPolicyFromBytes(polB)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err) return nil, fmt.Errorf("parsing policy: %w", err)
} }
}
pm := PolicyManagerV1{ pm := PolicyManagerV1{
pol: pol, pol: pol,