From 8061abe279e1c5c9dda8a2870edd29fd120b647b Mon Sep 17 00:00:00 2001 From: Adrien Raffin-Caboisse Date: Mon, 25 Apr 2022 21:50:40 +0200 Subject: [PATCH] refact: use generics for contains functions --- acls.go | 4 ++-- cmd/headscale/cli/nodes.go | 4 ++-- cmd/headscale/cli/utils.go | 7 ++++--- machine.go | 4 ++-- utils.go | 17 ++++------------- 5 files changed, 14 insertions(+), 22 deletions(-) diff --git a/acls.go b/acls.go index 26368362..80660e5a 100644 --- a/acls.go +++ b/acls.go @@ -332,7 +332,7 @@ func excludeCorrectlyTaggedNodes( out := []Machine{} tags := []string{} for tag, ns := range aclPolicy.TagOwners { - if containsString(ns, namespace) { + if contains(ns, namespace) { tags = append(tags, tag) } } @@ -342,7 +342,7 @@ func excludeCorrectlyTaggedNodes( found := false for _, t := range hi.RequestTags { - if containsString(tags, t) { + if contains(tags, t) { found = true break diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 7a6b818a..1f8ed4bd 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -372,12 +372,12 @@ func nodesToPtables( tags += "," + tag } for _, tag := range machine.InvalidTags { - if !containsString(machine.ForcedTags, tag) { + if !contains(machine.ForcedTags, tag) { tags += "," + pterm.LightRed(tag) } } for _, tag := range machine.ValidTags { - if !containsString(machine.ForcedTags, tag) { + if !contains(machine.ForcedTags, tag) { tags += "," + pterm.LightGreen(tag) } } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index de369baa..7ea61df5 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strconv" "strings" "time" @@ -565,9 +566,9 @@ func GetFileMode(key string) fs.FileMode { return fs.FileMode(mode) } -func containsString(ss []string, s string) bool { - for _, v := range ss { - if v == s { +func contains[T string](ts []T, t T) bool { + for _, v := range ts { + if reflect.DeepEqual(v,t) { return true } } diff --git a/machine.go b/machine.go index 2b5da37c..d23f61f4 100644 --- a/machine.go +++ b/machine.go @@ -125,7 +125,7 @@ func (machine Machine) isExpired() bool { func containsAddresses(inputs []string, addrs []string) bool { for _, addr := range addrs { - if containsString(inputs, addr) { + if contains(inputs, addr) { return true } } @@ -803,7 +803,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { } for _, newRoute := range newRoutes { - if !containsIPPrefix(machine.GetAdvertisedRoutes(), newRoute) { + if !contains(machine.GetAdvertisedRoutes(), newRoute) { return fmt.Errorf( "route (%s) is not available on node %s: %w", machine.Name, diff --git a/utils.go b/utils.go index af267eb3..44110428 100644 --- a/utils.go +++ b/utils.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "net" + "reflect" "strings" "github.com/rs/zerolog/log" @@ -223,16 +224,6 @@ func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) { return ipSet, nil } -func containsString(ss []string, s string) bool { - for _, v := range ss { - if v == s { - return true - } - } - - return false -} - func tailNodesToString(nodes []*tailcfg.Node) string { temp := make([]string, len(nodes)) @@ -282,9 +273,9 @@ func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { return result, nil } -func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { - for _, p := range prefixes { - if prefix == p { +func contains[T string | netaddr.IPPrefix](ts []T, t T) bool { + for _, v := range ts { + if reflect.DeepEqual(v, t) { return true } }