From 2732e7a017d114fe2887455e12507355fbdca93e Mon Sep 17 00:00:00 2001 From: Gabe Cook Date: Sat, 12 Oct 2024 02:48:13 -0500 Subject: [PATCH] chore(policy): ACL code cleanups and lint fixes --- hscontrol/policy/acls.go | 192 +++++++++++++++++++-------------------- 1 file changed, 92 insertions(+), 100 deletions(-) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index f16169d6..9c14a43f 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -28,12 +28,22 @@ var ( ErrInvalidTag = errors.New("invalid tag") ErrInvalidPortFormat = errors.New("invalid port format") ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") + ErrUnknownAutogroup = errors.New("unknown autogroup") + ErrAutogroupSelf = errors.New(`dst "autogroup:self" only works with one src "autogroup:member" or "autogroup:self"`) ) const ( portRangeBegin = 0 portRangeEnd = 65535 expectedTokenItems = 2 + + autogroupPrefix = "autogroup:" + autogroupInternet = "autogroup:internet" + autogroupSelf = "autogroup:self" + autogroupMember = "autogroup:member" + autogroupTagged = "autogroup:tagged" + autogroupNonRoot = "autogroup:nonroot" + autogroupDangerAll = "autogroup:danger-all" ) var theInternetSet *netipx.IPSet @@ -68,12 +78,11 @@ func theInternet() *netipx.IPSet { return theInternetSet } -// vinhjaxt -var allTheIps *netipx.IPSet +var allIPSet *netipx.IPSet -func getAllTheIps() *netipx.IPSet { - if allTheIps != nil { - return allTheIps +func allIPs() *netipx.IPSet { + if allIPSet != nil { + return allIPSet } var build netipx.IPSetBuilder @@ -81,11 +90,10 @@ func getAllTheIps() *netipx.IPSet { build.AddPrefix(netip.MustParsePrefix("0.0.0.0/0")) allTheIps, _ := build.IPSet() + return allTheIps } -// !vinhjaxt - // For some reason golang.org/x/net/internal/iana is an internal package. const ( protocolICMP = 1 // Internet Control Message @@ -187,11 +195,10 @@ func (pol *ACLPolicy) CompileFilterRules( var rules []tailcfg.FilterRule - // vinhjaxt - polACLs := pol.ACLs - for index := 0; index < len(polACLs); index++ { - acl := polACLs[index] - aclDestinations := acl.Destinations + acls := pol.ACLs + for index := 0; index < len(acls); index++ { + acl := acls[index] + destinations := acl.Destinations if acl.Action != "accept" { return nil, ErrInvalidAction @@ -199,38 +206,36 @@ func (pol *ACLPolicy) CompileFilterRules( var srcIPs []string for srcIndex, src := range acl.Sources { - // vinhjaxt - if strings.HasPrefix(src, "autogroup:member") { + if strings.HasPrefix(src, autogroupMember) { // split all autogroup:self and others var oldDst []string var newDst []string - for _, dst := range aclDestinations { - if strings.HasPrefix(dst, "autogroup:self") { + for _, dst := range destinations { + if strings.HasPrefix(dst, autogroupSelf) { newDst = append(newDst, dst) } else { oldDst = append(oldDst, dst) } } - if len(oldDst) == 0 { + switch { + case len(oldDst) == 0: // all moved to new, only need to change source - src = "autogroup:self" - } else if len(newDst) != 0 { + src = autogroupSelf + case len(newDst) != 0: // apart moved to new - aclDestinations = oldDst + destinations = oldDst - splitAcl := ACL{ + splitACL := ACL{ Action: acl.Action, - Sources: []string{"autogroup:self"}, + Sources: []string{autogroupSelf}, Destinations: newDst, } - polACLs = append(polACLs, splitAcl) - // Don't do pol.ACLs = polACLs, race condition + acls = append(acls, splitACL) } } - // !vinhjaxt srcs, err := pol.expandSource(src, nodes) if err != nil { @@ -245,16 +250,15 @@ func (pol *ACLPolicy) CompileFilterRules( } destPorts := []tailcfg.NetPortRange{} - for _, dest := range aclDestinations { + for _, dest := range destinations { alias, port, err := parseDestination(dest) if err != nil { return nil, err } - // vinhjaxt - if strings.HasPrefix(alias, "autogroup:self") { - if len(acl.Sources) != 1 || !(acl.Sources[0] == "autogroup:self" || acl.Sources[0] == "autogroup:member") { - return nil, errors.New(`dst "autogroup:self" only works with one src "autogroup:member" or "autogroup:self"`) + if strings.HasPrefix(alias, autogroupSelf) { + if len(acl.Sources) != 1 || acl.Sources[0] != autogroupSelf && acl.Sources[0] != autogroupMember { + return nil, ErrAutogroupSelf } } @@ -372,18 +376,16 @@ func (pol *ACLPolicy) CompileSSHPolicy( AllowLocalPortForwarding: false, } - // vinhjaxt - polSSHs := pol.SSHs - for index := 0; index < len(polSSHs); index++ { - sshACL := polSSHs[index] - sshACLDestinations := sshACL.Destinations + sshs := pol.SSHs + for index := 0; index < len(sshs); index++ { + sshACL := sshs[index] + destinations := sshACL.Destinations var dest netipx.IPSetBuilder - for _, src := range sshACLDestinations { - // vinhjaxt - if strings.HasPrefix(src, "autogroup:self") { - if len(sshACL.Sources) != 1 || !(sshACL.Sources[0] == "autogroup:self" || sshACL.Sources[0] == "autogroup:member") { - return nil, errors.New(`dst "autogroup:self" only works with one src "autogroup:member" or "autogroup:self"`) + for _, src := range destinations { + if strings.HasPrefix(src, autogroupSelf) { + if len(sshACL.Sources) != 1 || sshACL.Sources[0] != autogroupSelf && sshACL.Sources[0] != autogroupMember { + return nil, ErrAutogroupSelf } } @@ -436,40 +438,38 @@ func (pol *ACLPolicy) CompileSSHPolicy( }) } } else { - // vinhjaxt - if strings.HasPrefix(rawSrc, "autogroup:member") { + if strings.HasPrefix(rawSrc, autogroupMember) { // split all autogroup:self and others var oldDst []string var newDst []string - for _, dst := range sshACLDestinations { - if strings.HasPrefix(dst, "autogroup:self") { + for _, dst := range destinations { + if strings.HasPrefix(dst, autogroupSelf) { newDst = append(newDst, dst) } else { oldDst = append(oldDst, dst) } } - if len(oldDst) == 0 { + switch { + case len(oldDst) == 0: // all moved to new, only need to change source - rawSrc = "autogroup:self" - } else if len(newDst) != 0 { + rawSrc = autogroupSelf + case len(newDst) != 0: // apart moved to new - sshACLDestinations = oldDst + destinations = oldDst - splitAcl := SSH{ + splitACL := SSH{ Action: sshACL.Action, - Sources: []string{"autogroup:self"}, + Sources: []string{autogroupSelf}, Destinations: newDst, Users: sshACL.Users, CheckPeriod: sshACL.CheckPeriod, } - polSSHs = append(polSSHs, splitAcl) - // Don't do pol.SSHs = polSSHs, race condition + sshs = append(sshs, splitACL) } } - // !vinhjaxt expandedSrcs, err := pol.ExpandAlias( peers, @@ -671,7 +671,6 @@ func (pol *ACLPolicy) ExpandAlias( } if isAutoGroup(alias) { - // vinhjaxt return expandAutoGroup(pol, alias, nodes) } @@ -991,67 +990,60 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix( return build.IPSet() } -// vinhjaxt func expandAutoGroup(pol *ACLPolicy, alias string, nodes types.Nodes) (*netipx.IPSet, error) { switch { - case strings.HasPrefix(alias, "autogroup:internet"): + case strings.HasPrefix(alias, autogroupInternet): return theInternet(), nil - // vinhjaxt - case strings.HasPrefix(alias, "autogroup:self"): + case strings.HasPrefix(alias, autogroupSelf): // all user's devices, not tagged devices - { - var build netipx.IPSetBuilder - if len(nodes) == 0 { - return build.IPSet() - } - - currentNode := nodes[len(nodes)-1] // /mapper/mapper.go#L544 + var build netipx.IPSetBuilder + if len(nodes) != 0 { + currentNode := nodes[len(nodes)-1] for _, node := range nodes { - if node.User.Name == currentNode.User.Name { - // same user name + if node.User.ID == currentNode.User.ID { node.AppendToIPSet(&build) } } - return build.IPSet() } - case strings.HasPrefix(alias, "autogroup:member"): - // all users (not tagged devices) - { - var build netipx.IPSetBuilder - for _, node := range nodes { - if len(node.ForcedTags) != 0 { // auto tag - continue - } - if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) - continue - } + return build.IPSet() + + case strings.HasPrefix(alias, autogroupMember): + // all users (not tagged devices) + var build netipx.IPSetBuilder + + for _, node := range nodes { + if len(node.ForcedTags) != 0 { // auto tag + continue + } + if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) + continue + } + node.AppendToIPSet(&build) + } + + return build.IPSet() + + case strings.HasPrefix(alias, autogroupTagged): + // all tagged devices + var build netipx.IPSetBuilder + + for _, node := range nodes { + if len(node.ForcedTags) != 0 { // auto tag + node.AppendToIPSet(&build) + } else if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) node.AppendToIPSet(&build) } - return build.IPSet() - } - case strings.HasPrefix(alias, "autogroup:tagged"): - // all tagged devices - { - var build netipx.IPSetBuilder - - for _, node := range nodes { - if len(node.ForcedTags) != 0 { // auto tag - node.AppendToIPSet(&build) - } else if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) - node.AppendToIPSet(&build) - } - } - return build.IPSet() } - case strings.HasPrefix(alias, "autogroup:danger-all"): - // all ips - return getAllTheIps(), nil - // !vinhjaxt + return build.IPSet() + + case strings.HasPrefix(alias, autogroupDangerAll): + return allIPs(), nil + default: - return nil, fmt.Errorf("unknown autogroup %q", alias) + return nil, fmt.Errorf("%w: %q", ErrUnknownAutogroup, alias) } } @@ -1068,7 +1060,7 @@ func isTag(str string) bool { } func isAutoGroup(str string) bool { - return strings.HasPrefix(str, "autogroup:") + return strings.HasPrefix(str, autogroupPrefix) } // TagsOfNode will return the tags of the current node.