rename acl "get" funcs to "expand" for consistency

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-19 09:17:50 +02:00 committed by Kristoffer Dalby
parent 155cc072f7
commit 19dc0ac702
4 changed files with 59 additions and 62 deletions

View file

@ -338,7 +338,7 @@ func (api headscaleV1APIServer) ListMachines(
response := make([]*v1.Machine, len(machines))
for index, machine := range machines {
m := machine.Proto()
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
validTags, invalidTags := api.h.ACLPolicy.TagsOfMachine(
machine,
)
m.InvalidTags = invalidTags

View file

@ -104,7 +104,7 @@ func tailNode(
online := machine.IsOnline()
tags, _ := pol.GetTagsOfMachine(machine)
tags, _ := pol.TagsOfMachine(machine)
tags = lo.Uniq(append(tags, machine.ForcedTags...))
node := tailcfg.Node{

View file

@ -114,9 +114,6 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
return &policy, nil
}
// TODO(kradalby): This needs to be replace with something that generates
// the rules as needed and not stores it on the global object, rules are
// per node and that should be taken into account.
func GenerateFilterAndSSHRules(
policy *ACLPolicy,
machine *types.Machine,
@ -169,7 +166,7 @@ func (pol *ACLPolicy) generateFilterRules(
srcIPs := []string{}
for srcIndex, src := range acl.Sources {
srcs, err := pol.getIPsFromSource(src, machines)
srcs, err := pol.expandSource(src, machines)
if err != nil {
log.Error().
Interface("src", src).
@ -338,7 +335,7 @@ func (pol *ACLPolicy) generateSSHRules(
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.getUsersInGroup(rawSrc)
users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
@ -401,26 +398,6 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
}, nil
}
// getIPsFromSource returns a set of Source IPs that would be associated
// with the given src alias.
func (pol *ACLPolicy) getIPsFromSource(
src string,
machines types.Machines,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
func parseDestination(dest string) (string, string, error) {
var tokens []string
@ -520,6 +497,26 @@ func parseProtocol(protocol string) ([]int, bool, error) {
}
}
// expandSource returns a set of Source IPs that would be associated
// with the given src alias.
func (pol *ACLPolicy) expandSource(
src string,
machines types.Machines,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
// expandalias has an input of either
// - a user
// - a group
@ -544,16 +541,16 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group
if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines)
return pol.expandIPsFromGroup(alias, machines)
}
// if alias is a tag
if isTag(alias) {
return pol.getIPsFromTag(alias, machines)
return pol.expandIPsFromTag(alias, machines)
}
// if alias is a user
if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
if ips, err := pol.expandIPsFromUser(alias, machines); ips != nil {
return ips, err
}
@ -567,12 +564,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil {
return pol.getIPsFromSingleIP(ip, machines)
return pol.expandIPsFromSingleIP(ip, machines)
}
// if alias is an IP Prefix (CIDR)
if prefix, err := netip.ParsePrefix(alias); err == nil {
return pol.getIPsFromIPPrefix(prefix, machines)
return pol.expandIPsFromIPPrefix(prefix, machines)
}
log.Warn().Msgf("No IPs found with the alias %v", alias)
@ -591,7 +588,7 @@ func excludeCorrectlyTaggedNodes(
out := types.Machines{}
tags := []string{}
for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user)
owners, _ := expandOwnersFromTag(aclPolicy, user)
ns := append(owners, user)
if util.StringOrPrefixListContains(ns, user) {
tags = append(tags, tag)
@ -668,20 +665,9 @@ func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) {
return &ports, nil
}
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
out := types.Machines{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
}
}
return out
}
// getTagOwners will return a list of user. An owner can be either a user or a group
// expandOwnersFromTag will return a list of user. An owner can be either a user or a group
// a group cannot be composed of groups.
func getTagOwners(
func expandOwnersFromTag(
pol *ACLPolicy,
tag string,
) ([]string, error) {
@ -696,7 +682,7 @@ func getTagOwners(
}
for _, owner := range ows {
if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner)
gs, err := pol.expandUsersFromGroup(owner)
if err != nil {
return []string{}, err
}
@ -709,9 +695,9 @@ func getTagOwners(
return owners, nil
}
// getUsersInGroup will return the list of user inside the group
// expandUsersFromGroup will return the list of user inside the group
// after some validation.
func (pol *ACLPolicy) getUsersInGroup(
func (pol *ACLPolicy) expandUsersFromGroup(
group string,
) ([]string, error) {
users := []string{}
@ -745,13 +731,13 @@ func (pol *ACLPolicy) getUsersInGroup(
return users, nil
}
func (pol *ACLPolicy) getIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup(
group string,
machines types.Machines,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group)
users, err := pol.expandUsersFromGroup(group)
if err != nil {
return &netipx.IPSet{}, err
}
@ -765,7 +751,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromTag(
func (pol *ACLPolicy) expandIPsFromTag(
alias string,
machines types.Machines,
) (*netipx.IPSet, error) {
@ -779,7 +765,7 @@ func (pol *ACLPolicy) getIPsFromTag(
}
// find tag owners
owners, err := getTagOwners(pol, alias)
owners, err := expandOwnersFromTag(pol, alias)
if err != nil {
if errors.Is(err, ErrInvalidTag) {
ipSet, _ := build.IPSet()
@ -811,7 +797,7 @@ func (pol *ACLPolicy) getIPsFromTag(
return build.IPSet()
}
func (pol *ACLPolicy) getIPsForUser(
func (pol *ACLPolicy) expandIPsFromUser(
user string,
machines types.Machines,
) (*netipx.IPSet, error) {
@ -832,7 +818,7 @@ func (pol *ACLPolicy) getIPsForUser(
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromSingleIP(
func (pol *ACLPolicy) expandIPsFromSingleIP(
ip netip.Addr,
machines types.Machines,
) (*netipx.IPSet, error) {
@ -850,7 +836,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromIPPrefix(
func (pol *ACLPolicy) expandIPsFromIPPrefix(
prefix netip.Prefix,
machines types.Machines,
) (*netipx.IPSet, error) {
@ -885,10 +871,10 @@ func isTag(str string) bool {
return strings.HasPrefix(str, "tag:")
}
// getTags will return the tags of the current machine.
// TagsOfMachine will return the tags of the current machine.
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
func (pol *ACLPolicy) GetTagsOfMachine(
func (pol *ACLPolicy) TagsOfMachine(
machine types.Machine,
) ([]string, []string) {
validTags := make([]string, 0)
@ -897,7 +883,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
for _, tag := range machine.HostInfo.RequestTags {
owners, err := getTagOwners(pol, tag)
owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true
@ -925,6 +911,17 @@ func (pol *ACLPolicy) GetTagsOfMachine(
return validTags, invalidTags
}
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
out := types.Machines{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
}
}
return out
}
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
func FilterMachinesByACL(
machine *types.Machine,

View file

@ -690,7 +690,7 @@ func Test_expandGroup(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
viper.Set("oidc.strip_email_domain", test.args.stripEmail)
got, err := test.field.pol.getUsersInGroup(
got, err := test.field.pol.expandUsersFromGroup(
test.args.group,
)
@ -779,7 +779,7 @@ func Test_expandTagOwners(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := getTagOwners(
got, err := expandOwnersFromTag(
test.args.aclPolicy,
test.args.tag,
)
@ -2022,7 +2022,7 @@ func Test_getTags(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
gotValid, gotInvalid := test.args.aclPolicy.TagsOfMachine(
test.args.machine,
)
for _, valid := range gotValid {