refact: use generics for contains functions

This commit is contained in:
Adrien Raffin-Caboisse 2022-04-25 21:50:40 +02:00
parent ea9aaa6022
commit 8061abe279
5 changed files with 14 additions and 22 deletions

View file

@ -332,7 +332,7 @@ func excludeCorrectlyTaggedNodes(
out := []Machine{} out := []Machine{}
tags := []string{} tags := []string{}
for tag, ns := range aclPolicy.TagOwners { for tag, ns := range aclPolicy.TagOwners {
if containsString(ns, namespace) { if contains(ns, namespace) {
tags = append(tags, tag) tags = append(tags, tag)
} }
} }
@ -342,7 +342,7 @@ func excludeCorrectlyTaggedNodes(
found := false found := false
for _, t := range hi.RequestTags { for _, t := range hi.RequestTags {
if containsString(tags, t) { if contains(tags, t) {
found = true found = true
break break

View file

@ -372,12 +372,12 @@ func nodesToPtables(
tags += "," + tag tags += "," + tag
} }
for _, tag := range machine.InvalidTags { for _, tag := range machine.InvalidTags {
if !containsString(machine.ForcedTags, tag) { if !contains(machine.ForcedTags, tag) {
tags += "," + pterm.LightRed(tag) tags += "," + pterm.LightRed(tag)
} }
} }
for _, tag := range machine.ValidTags { for _, tag := range machine.ValidTags {
if !containsString(machine.ForcedTags, tag) { if !contains(machine.ForcedTags, tag) {
tags += "," + pterm.LightGreen(tag) tags += "," + pterm.LightGreen(tag)
} }
} }

View file

@ -10,6 +10,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -565,9 +566,9 @@ func GetFileMode(key string) fs.FileMode {
return fs.FileMode(mode) return fs.FileMode(mode)
} }
func containsString(ss []string, s string) bool { func contains[T string](ts []T, t T) bool {
for _, v := range ss { for _, v := range ts {
if v == s { if reflect.DeepEqual(v,t) {
return true return true
} }
} }

View file

@ -125,7 +125,7 @@ func (machine Machine) isExpired() bool {
func containsAddresses(inputs []string, addrs []string) bool { func containsAddresses(inputs []string, addrs []string) bool {
for _, addr := range addrs { for _, addr := range addrs {
if containsString(inputs, addr) { if contains(inputs, addr) {
return true return true
} }
} }
@ -803,7 +803,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
} }
for _, newRoute := range newRoutes { for _, newRoute := range newRoutes {
if !containsIPPrefix(machine.GetAdvertisedRoutes(), newRoute) { if !contains(machine.GetAdvertisedRoutes(), newRoute) {
return fmt.Errorf( return fmt.Errorf(
"route (%s) is not available on node %s: %w", "route (%s) is not available on node %s: %w",
machine.Name, machine.Name,

View file

@ -12,6 +12,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"reflect"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -223,16 +224,6 @@ func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
return ipSet, nil return ipSet, nil
} }
func containsString(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}
func tailNodesToString(nodes []*tailcfg.Node) string { func tailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes)) temp := make([]string, len(nodes))
@ -282,9 +273,9 @@ func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil return result, nil
} }
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { func contains[T string | netaddr.IPPrefix](ts []T, t T) bool {
for _, p := range prefixes { for _, v := range ts {
if prefix == p { if reflect.DeepEqual(v, t) {
return true return true
} }
} }