diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 866c3cb2..74950c20 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -338,7 +338,7 @@ func (api headscaleV1APIServer) ListMachines( response := make([]*v1.Machine, len(machines)) for index, machine := range machines { m := machine.Proto() - validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( + validTags, invalidTags := api.h.ACLPolicy.TagsOfMachine( machine, ) m.InvalidTags = invalidTags diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index afc9423d..92bd5c96 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -104,7 +104,7 @@ func tailNode( online := machine.IsOnline() - tags, _ := pol.GetTagsOfMachine(machine) + tags, _ := pol.TagsOfMachine(machine) tags = lo.Uniq(append(tags, machine.ForcedTags...)) node := tailcfg.Node{ diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index bcdbb5d8..d4e24944 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -114,9 +114,6 @@ func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { return &policy, nil } -// TODO(kradalby): This needs to be replace with something that generates -// the rules as needed and not stores it on the global object, rules are -// per node and that should be taken into account. func GenerateFilterAndSSHRules( policy *ACLPolicy, machine *types.Machine, @@ -169,7 +166,7 @@ func (pol *ACLPolicy) generateFilterRules( srcIPs := []string{} for srcIndex, src := range acl.Sources { - srcs, err := pol.getIPsFromSource(src, machines) + srcs, err := pol.expandSource(src, machines) if err != nil { log.Error(). Interface("src", src). @@ -338,7 +335,7 @@ func (pol *ACLPolicy) generateSSHRules( Any: true, }) } else if isGroup(rawSrc) { - users, err := pol.getUsersInGroup(rawSrc) + users, err := pol.expandUsersFromGroup(rawSrc) if err != nil { log.Error(). Msgf("Error parsing SSH %d, Source %d", index, innerIndex) @@ -401,26 +398,6 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { }, nil } -// getIPsFromSource returns a set of Source IPs that would be associated -// with the given src alias. -func (pol *ACLPolicy) getIPsFromSource( - src string, - machines types.Machines, -) ([]string, error) { - ipSet, err := pol.ExpandAlias(machines, src) - if err != nil { - return []string{}, err - } - - prefixes := []string{} - - for _, prefix := range ipSet.Prefixes() { - prefixes = append(prefixes, prefix.String()) - } - - return prefixes, nil -} - func parseDestination(dest string) (string, string, error) { var tokens []string @@ -520,6 +497,26 @@ func parseProtocol(protocol string) ([]int, bool, error) { } } +// expandSource returns a set of Source IPs that would be associated +// with the given src alias. +func (pol *ACLPolicy) expandSource( + src string, + machines types.Machines, +) ([]string, error) { + ipSet, err := pol.ExpandAlias(machines, src) + if err != nil { + return []string{}, err + } + + prefixes := []string{} + + for _, prefix := range ipSet.Prefixes() { + prefixes = append(prefixes, prefix.String()) + } + + return prefixes, nil +} + // expandalias has an input of either // - a user // - a group @@ -544,16 +541,16 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.getIPsFromGroup(alias, machines) + return pol.expandIPsFromGroup(alias, machines) } // if alias is a tag if isTag(alias) { - return pol.getIPsFromTag(alias, machines) + return pol.expandIPsFromTag(alias, machines) } // if alias is a user - if ips, err := pol.getIPsForUser(alias, machines); ips != nil { + if ips, err := pol.expandIPsFromUser(alias, machines); ips != nil { return ips, err } @@ -567,12 +564,12 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is an IP if ip, err := netip.ParseAddr(alias); err == nil { - return pol.getIPsFromSingleIP(ip, machines) + return pol.expandIPsFromSingleIP(ip, machines) } // if alias is an IP Prefix (CIDR) if prefix, err := netip.ParsePrefix(alias); err == nil { - return pol.getIPsFromIPPrefix(prefix, machines) + return pol.expandIPsFromIPPrefix(prefix, machines) } log.Warn().Msgf("No IPs found with the alias %v", alias) @@ -591,7 +588,7 @@ func excludeCorrectlyTaggedNodes( out := types.Machines{} tags := []string{} for tag := range aclPolicy.TagOwners { - owners, _ := getTagOwners(aclPolicy, user) + owners, _ := expandOwnersFromTag(aclPolicy, user) ns := append(owners, user) if util.StringOrPrefixListContains(ns, user) { tags = append(tags, tag) @@ -668,20 +665,9 @@ func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) { return &ports, nil } -func filterMachinesByUser(machines types.Machines, user string) types.Machines { - out := types.Machines{} - for _, machine := range machines { - if machine.User.Name == user { - out = append(out, machine) - } - } - - return out -} - -// getTagOwners will return a list of user. An owner can be either a user or a group +// expandOwnersFromTag will return a list of user. An owner can be either a user or a group // a group cannot be composed of groups. -func getTagOwners( +func expandOwnersFromTag( pol *ACLPolicy, tag string, ) ([]string, error) { @@ -696,7 +682,7 @@ func getTagOwners( } for _, owner := range ows { if isGroup(owner) { - gs, err := pol.getUsersInGroup(owner) + gs, err := pol.expandUsersFromGroup(owner) if err != nil { return []string{}, err } @@ -709,9 +695,9 @@ func getTagOwners( return owners, nil } -// getUsersInGroup will return the list of user inside the group +// expandUsersFromGroup will return the list of user inside the group // after some validation. -func (pol *ACLPolicy) getUsersInGroup( +func (pol *ACLPolicy) expandUsersFromGroup( group string, ) ([]string, error) { users := []string{} @@ -745,13 +731,13 @@ func (pol *ACLPolicy) getUsersInGroup( return users, nil } -func (pol *ACLPolicy) getIPsFromGroup( +func (pol *ACLPolicy) expandIPsFromGroup( group string, machines types.Machines, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} - users, err := pol.getUsersInGroup(group) + users, err := pol.expandUsersFromGroup(group) if err != nil { return &netipx.IPSet{}, err } @@ -765,7 +751,7 @@ func (pol *ACLPolicy) getIPsFromGroup( return build.IPSet() } -func (pol *ACLPolicy) getIPsFromTag( +func (pol *ACLPolicy) expandIPsFromTag( alias string, machines types.Machines, ) (*netipx.IPSet, error) { @@ -779,7 +765,7 @@ func (pol *ACLPolicy) getIPsFromTag( } // find tag owners - owners, err := getTagOwners(pol, alias) + owners, err := expandOwnersFromTag(pol, alias) if err != nil { if errors.Is(err, ErrInvalidTag) { ipSet, _ := build.IPSet() @@ -811,7 +797,7 @@ func (pol *ACLPolicy) getIPsFromTag( return build.IPSet() } -func (pol *ACLPolicy) getIPsForUser( +func (pol *ACLPolicy) expandIPsFromUser( user string, machines types.Machines, ) (*netipx.IPSet, error) { @@ -832,7 +818,7 @@ func (pol *ACLPolicy) getIPsForUser( return build.IPSet() } -func (pol *ACLPolicy) getIPsFromSingleIP( +func (pol *ACLPolicy) expandIPsFromSingleIP( ip netip.Addr, machines types.Machines, ) (*netipx.IPSet, error) { @@ -850,7 +836,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP( return build.IPSet() } -func (pol *ACLPolicy) getIPsFromIPPrefix( +func (pol *ACLPolicy) expandIPsFromIPPrefix( prefix netip.Prefix, machines types.Machines, ) (*netipx.IPSet, error) { @@ -885,10 +871,10 @@ func isTag(str string) bool { return strings.HasPrefix(str, "tag:") } -// getTags will return the tags of the current machine. +// TagsOfMachine will return the tags of the current machine. // Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. -func (pol *ACLPolicy) GetTagsOfMachine( +func (pol *ACLPolicy) TagsOfMachine( machine types.Machine, ) ([]string, []string) { validTags := make([]string, 0) @@ -897,7 +883,7 @@ func (pol *ACLPolicy) GetTagsOfMachine( validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) for _, tag := range machine.HostInfo.RequestTags { - owners, err := getTagOwners(pol, tag) + owners, err := expandOwnersFromTag(pol, tag) if errors.Is(err, ErrInvalidTag) { invalidTagMap[tag] = true @@ -925,6 +911,17 @@ func (pol *ACLPolicy) GetTagsOfMachine( return validTags, invalidTags } +func filterMachinesByUser(machines types.Machines, user string) types.Machines { + out := types.Machines{} + for _, machine := range machines { + if machine.User.Name == user { + out = append(out, machine) + } + } + + return out +} + // FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine. func FilterMachinesByACL( machine *types.Machine, diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 3995935d..a71ef20e 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -690,7 +690,7 @@ func Test_expandGroup(t *testing.T) { t.Run(test.name, func(t *testing.T) { viper.Set("oidc.strip_email_domain", test.args.stripEmail) - got, err := test.field.pol.getUsersInGroup( + got, err := test.field.pol.expandUsersFromGroup( test.args.group, ) @@ -779,7 +779,7 @@ func Test_expandTagOwners(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - got, err := getTagOwners( + got, err := expandOwnersFromTag( test.args.aclPolicy, test.args.tag, ) @@ -2022,7 +2022,7 @@ func Test_getTags(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( + gotValid, gotInvalid := test.args.aclPolicy.TagsOfMachine( test.args.machine, ) for _, valid := range gotValid {