From 1700a747f64076d482723b22df8965eac7a06a63 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 26 Apr 2023 14:04:12 +0200 Subject: [PATCH] outline tests for full filter generate Signed-off-by: Kristoffer Dalby --- acls.go | 7 +-- acls_test.go | 127 +++++++++++++++++++++++++++++++++++++++++++++++++++ flake.nix | 1 + 3 files changed, 132 insertions(+), 3 deletions(-) diff --git a/acls.go b/acls.go index ad5ff3fa..73f437bf 100644 --- a/acls.go +++ b/acls.go @@ -228,7 +228,7 @@ func expandACLPeerAddr(srcIP string) []string { // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *ACLPolicy) generateFilterRules( machines []Machine, - stripEmaildomain bool, + stripEmailDomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} @@ -239,7 +239,7 @@ func (pol *ACLPolicy) generateFilterRules( srcIPs := []string{} for srcIndex, src := range acl.Sources { - srcs, err := pol.getIPsFromSource(src, machines, stripEmaildomain) + srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain) if err != nil { log.Error(). Interface("src", src). @@ -266,7 +266,7 @@ func (pol *ACLPolicy) generateFilterRules( dest, machines, needsWildcard, - stripEmaildomain, + stripEmailDomain, ) if err != nil { log.Error(). @@ -569,6 +569,7 @@ func (pol *ACLPolicy) expandAlias( } // if alias is an host + // Note, this is recursive. if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") diff --git a/acls_test.go b/acls_test.go index 4264f07d..f96ac17b 100644 --- a/acls_test.go +++ b/acls_test.go @@ -6,6 +6,8 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/rs/zerolog/log" "gopkg.in/check.v1" "tailscale.com/envknob" "tailscale.com/tailcfg" @@ -1793,3 +1795,128 @@ func Test_expandACLPeerAddrV6(t *testing.T) { }) } } + +func TestACLPolicy_generateFilterRules(t *testing.T) { + type field struct { + pol ACLPolicy + } + type args struct { + machines []Machine + stripEmailDomain bool + } + tests := []struct { + name string + field field + args args + want []tailcfg.FilterRule + wantErr bool + }{ + { + name: "no-policy", + field: field{}, + args: args{}, + want: []tailcfg.FilterRule{}, + wantErr: false, + }, + { + name: "simple group", + field: field{ + pol: ACLPolicy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + }, + }, + args: args{ + machines: []Machine{}, + stripEmailDomain: true, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "simple host by ipv4 single dual stack", + field: field{ + pol: ACLPolicy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: []string{"100.64.0.1"}, + Destinations: []string{"100.64.0.2:*"}, + }, + }, + }, + }, + args: args{ + machines: []Machine{ + { + IPAddresses: MachineAddresses{ + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + }, + User: User{Name: "mickael"}, + }, + { + IPAddresses: MachineAddresses{ + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), + }, + User: User{Name: "mickael"}, + }, + }, + stripEmailDomain: true, + }, + // [{"SrcIPs":["100.64.0.1"],"DstPorts":[{"IP":"100.64.0.2","Bits":null,"Ports":{"First":0,"Last":65535}}]}] + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.2", + Ports: tailcfg.PortRange{ + First: 0, + Last: 65535, + }, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.field.pol.generateFilterRules( + tt.args.machines, + tt.args.stripEmailDomain, + ) + if (err != nil) != tt.wantErr { + t.Errorf("ACLPolicy.generateFilterRules() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + log.Trace().Interface("got", got).Msg("result") + t.Errorf("ACLPolicy.generateFilterRules() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/flake.nix b/flake.nix index 71cb0d11..ca67df5c 100644 --- a/flake.nix +++ b/flake.nix @@ -99,6 +99,7 @@ goreleaser nfpm gotestsum + gotests # 'dot' is needed for pprof graphs # go tool pprof -http=: