generateACLPolicy() no longer a Headscale method

This commit is contained in:
Juan Font 2022-11-30 23:37:58 +00:00
parent eb072a1a74
commit a7c608f0cb
2 changed files with 35 additions and 26 deletions

41
acls.go
View file

@ -117,7 +117,16 @@ func (h *Headscale) LoadACLPolicy(path string) error {
} }
func (h *Headscale) UpdateACLRules() error { func (h *Headscale) UpdateACLRules() error {
rules, err := h.generateACLRules() machines, err := h.ListMachines()
if err != nil {
return err
}
if h.aclPolicy == nil {
return errEmptyPolicy
}
rules, err := generateACLRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain)
if err != nil { if err != nil {
return err return err
} }
@ -141,26 +150,17 @@ func (h *Headscale) UpdateACLRules() error {
return nil return nil
} }
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { func generateACLRules(machines []Machine, aclPolicy ACLPolicy, stripEmaildomain bool) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
if h.aclPolicy == nil { for index, acl := range aclPolicy.ACLs {
return nil, errEmptyPolicy
}
machines, err := h.ListMachines()
if err != nil {
return nil, err
}
for index, acl := range h.aclPolicy.ACLs {
if acl.Action != "accept" { if acl.Action != "accept" {
return nil, errInvalidAction return nil, errInvalidAction
} }
srcIPs := []string{} srcIPs := []string{}
for innerIndex, src := range acl.Sources { for innerIndex, src := range acl.Sources {
srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, src) srcs, err := generateACLPolicySrcIP(machines, aclPolicy, src, stripEmaildomain)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Source %d", index, innerIndex) Msgf("Error parsing ACL %d, Source %d", index, innerIndex)
@ -180,11 +180,12 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations { for innerIndex, dest := range acl.Destinations {
dests, err := h.generateACLPolicyDest( dests, err := generateACLPolicyDest(
machines, machines,
*h.aclPolicy, aclPolicy,
dest, dest,
needsWildcard, needsWildcard,
stripEmaildomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
@ -310,19 +311,21 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
}, nil }, nil
} }
func (h *Headscale) generateACLPolicySrcIP( func generateACLPolicySrcIP(
machines []Machine, machines []Machine,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
src string, src string,
stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
return expandAlias(machines, aclPolicy, src, h.cfg.OIDC.StripEmaildomain) return expandAlias(machines, aclPolicy, src, stripEmaildomain)
} }
func (h *Headscale) generateACLPolicyDest( func generateACLPolicyDest(
machines []Machine, machines []Machine,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
dest string, dest string,
needsWildcard bool, needsWildcard bool,
stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(dest, ":") tokens := strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 { if len(tokens) < expectedTokenItems || len(tokens) > 3 {
@ -346,7 +349,7 @@ func (h *Headscale) generateACLPolicyDest(
machines, machines,
aclPolicy, aclPolicy,
alias, alias,
h.cfg.OIDC.StripEmaildomain, stripEmaildomain,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -54,7 +54,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
} }
@ -411,7 +411,7 @@ func (s *Suite) TestPortRange(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -425,7 +425,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_protocols.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -439,7 +439,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -455,7 +455,7 @@ func (s *Suite) TestPortWildcardYAML(c *check.C) {
err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml") err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.yaml")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() rules, err := generateACLRules([]Machine{}, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -495,7 +495,10 @@ func (s *Suite) TestPortNamespace(c *check.C) {
) )
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() machines, err := app.ListMachines()
c.Assert(err, check.IsNil)
rules, err := generateACLRules(machines, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -535,7 +538,10 @@ func (s *Suite) TestPortGroup(c *check.C) {
err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := app.generateACLRules() machines, err := app.ListMachines()
c.Assert(err, check.IsNil)
rules, err := generateACLRules(machines, *app.aclPolicy, false)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)