diff --git a/machine_test.go b/machine_test.go index 90caf9d2..fb104f7d 100644 --- a/machine_test.go +++ b/machine_test.go @@ -2,11 +2,14 @@ package headscale import ( "fmt" + "reflect" "strconv" + "testing" "time" "gopkg.in/check.v1" "inet.af/netaddr" + "tailscale.com/tailcfg" ) func (s *Suite) TestGetMachine(c *check.C) { @@ -295,3 +298,168 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { c.Assert(deserialized[i], check.Equals, input[i]) } } + +func Test_getFilteredByACLPeers(t *testing.T) { + type args struct { + machines []Machine + rules []tailcfg.FilterRule + machine *Machine + } + tests := []struct { + name string + args args + want Machines + wantErr bool + }{ + { + name: "all hosts can talk to each other", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + {SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, + Namespace: Namespace{Name: "joe"}, + }, + }, + want: Machines{ + { + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + wantErr: false, + }, + { + name: "One host can talk to another, but not all hosts", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + {SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, + Namespace: Namespace{Name: "joe"}, + }, + }, + want: Machines{ + { + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + }, + wantErr: false, + }, + { + name: "host cannot directly talk to destination, but return path is authorized", + args: args{ + machines: []Machine{ // list of all machines in the database + { + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, + Namespace: Namespace{Name: "joe"}, + }, + { + ID: 2, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + {SrcIPs: []string{"100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + machine: &Machine{ // current machine + ID: 1, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, + Namespace: Namespace{Name: "marc"}, + }, + }, + want: Machines{ + { + ID: 3, + IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, + Namespace: Namespace{Name: "mickael"}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getFilteredByACLPeers(tt.args.machines, tt.args.rules, tt.args.machine) + if (err != nil) != tt.wantErr { + t.Errorf("getFilteredByACLPeers() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getFilteredByACLPeers() = %v, want %v", got, tt.want) + } + }) + } +} + +var getFilteredByACLPeersTestRules = []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{{IP: "*"}}, + }, +}