diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index ff73985b..f16169d6 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -68,6 +68,24 @@ func theInternet() *netipx.IPSet { return theInternetSet } +// vinhjaxt +var allTheIps *netipx.IPSet + +func getAllTheIps() *netipx.IPSet { + if allTheIps != nil { + return allTheIps + } + + var build netipx.IPSetBuilder + build.AddPrefix(netip.MustParsePrefix("::/0")) + build.AddPrefix(netip.MustParsePrefix("0.0.0.0/0")) + + allTheIps, _ := build.IPSet() + return allTheIps +} + +// !vinhjaxt + // For some reason golang.org/x/net/internal/iana is an internal package. const ( protocolICMP = 1 // Internet Control Message @@ -169,13 +187,51 @@ func (pol *ACLPolicy) CompileFilterRules( var rules []tailcfg.FilterRule - for index, acl := range pol.ACLs { + // vinhjaxt + polACLs := pol.ACLs + for index := 0; index < len(polACLs); index++ { + acl := polACLs[index] + aclDestinations := acl.Destinations + if acl.Action != "accept" { return nil, ErrInvalidAction } var srcIPs []string for srcIndex, src := range acl.Sources { + // vinhjaxt + if strings.HasPrefix(src, "autogroup:member") { + // split all autogroup:self and others + var oldDst []string + var newDst []string + + for _, dst := range aclDestinations { + if strings.HasPrefix(dst, "autogroup:self") { + newDst = append(newDst, dst) + } else { + oldDst = append(oldDst, dst) + } + } + + if len(oldDst) == 0 { + // all moved to new, only need to change source + src = "autogroup:self" + } else if len(newDst) != 0 { + // apart moved to new + + aclDestinations = oldDst + + splitAcl := ACL{ + Action: acl.Action, + Sources: []string{"autogroup:self"}, + Destinations: newDst, + } + polACLs = append(polACLs, splitAcl) + // Don't do pol.ACLs = polACLs, race condition + } + } + // !vinhjaxt + srcs, err := pol.expandSource(src, nodes) if err != nil { return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) @@ -189,12 +245,19 @@ func (pol *ACLPolicy) CompileFilterRules( } destPorts := []tailcfg.NetPortRange{} - for _, dest := range acl.Destinations { + for _, dest := range aclDestinations { alias, port, err := parseDestination(dest) if err != nil { return nil, err } + // vinhjaxt + if strings.HasPrefix(alias, "autogroup:self") { + if len(acl.Sources) != 1 || !(acl.Sources[0] == "autogroup:self" || acl.Sources[0] == "autogroup:member") { + return nil, errors.New(`dst "autogroup:self" only works with one src "autogroup:member" or "autogroup:self"`) + } + } + expanded, err := pol.ExpandAlias( nodes, alias, @@ -309,9 +372,21 @@ func (pol *ACLPolicy) CompileSSHPolicy( AllowLocalPortForwarding: false, } - for index, sshACL := range pol.SSHs { + // vinhjaxt + polSSHs := pol.SSHs + for index := 0; index < len(polSSHs); index++ { + sshACL := polSSHs[index] + sshACLDestinations := sshACL.Destinations + var dest netipx.IPSetBuilder - for _, src := range sshACL.Destinations { + for _, src := range sshACLDestinations { + // vinhjaxt + if strings.HasPrefix(src, "autogroup:self") { + if len(sshACL.Sources) != 1 || !(sshACL.Sources[0] == "autogroup:self" || sshACL.Sources[0] == "autogroup:member") { + return nil, errors.New(`dst "autogroup:self" only works with one src "autogroup:member" or "autogroup:self"`) + } + } + expanded, err := pol.ExpandAlias(append(peers, node), src) if err != nil { return nil, err @@ -361,6 +436,41 @@ func (pol *ACLPolicy) CompileSSHPolicy( }) } } else { + // vinhjaxt + if strings.HasPrefix(rawSrc, "autogroup:member") { + // split all autogroup:self and others + var oldDst []string + var newDst []string + + for _, dst := range sshACLDestinations { + if strings.HasPrefix(dst, "autogroup:self") { + newDst = append(newDst, dst) + } else { + oldDst = append(oldDst, dst) + } + } + + if len(oldDst) == 0 { + // all moved to new, only need to change source + rawSrc = "autogroup:self" + } else if len(newDst) != 0 { + // apart moved to new + + sshACLDestinations = oldDst + + splitAcl := SSH{ + Action: sshACL.Action, + Sources: []string{"autogroup:self"}, + Destinations: newDst, + Users: sshACL.Users, + CheckPeriod: sshACL.CheckPeriod, + } + polSSHs = append(polSSHs, splitAcl) + // Don't do pol.SSHs = polSSHs, race condition + } + } + // !vinhjaxt + expandedSrcs, err := pol.ExpandAlias( peers, rawSrc, @@ -561,7 +671,8 @@ func (pol *ACLPolicy) ExpandAlias( } if isAutoGroup(alias) { - return expandAutoGroup(alias) + // vinhjaxt + return expandAutoGroup(pol, alias, nodes) } // if alias is a user @@ -880,11 +991,65 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix( return build.IPSet() } -func expandAutoGroup(alias string) (*netipx.IPSet, error) { +// vinhjaxt +func expandAutoGroup(pol *ACLPolicy, alias string, nodes types.Nodes) (*netipx.IPSet, error) { switch { case strings.HasPrefix(alias, "autogroup:internet"): return theInternet(), nil + // vinhjaxt + case strings.HasPrefix(alias, "autogroup:self"): + // all user's devices, not tagged devices + { + var build netipx.IPSetBuilder + if len(nodes) == 0 { + return build.IPSet() + } + + currentNode := nodes[len(nodes)-1] // /mapper/mapper.go#L544 + for _, node := range nodes { + if node.User.Name == currentNode.User.Name { + // same user name + node.AppendToIPSet(&build) + } + } + return build.IPSet() + } + case strings.HasPrefix(alias, "autogroup:member"): + // all users (not tagged devices) + { + var build netipx.IPSetBuilder + + for _, node := range nodes { + if len(node.ForcedTags) != 0 { // auto tag + continue + } + if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) + continue + } + node.AppendToIPSet(&build) + } + return build.IPSet() + } + case strings.HasPrefix(alias, "autogroup:tagged"): + // all tagged devices + { + var build netipx.IPSetBuilder + + for _, node := range nodes { + if len(node.ForcedTags) != 0 { // auto tag + node.AppendToIPSet(&build) + } else if tags, _ := pol.TagsOfNode(node); len(tags) != 0 { // valid tag manual add by user (tagOwner) + node.AppendToIPSet(&build) + } + } + return build.IPSet() + } + + case strings.HasPrefix(alias, "autogroup:danger-all"): + // all ips + return getAllTheIps(), nil + // !vinhjaxt default: return nil, fmt.Errorf("unknown autogroup %q", alias) }