make generateFilterRules take machine and peers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-12 11:18:58 +02:00 committed by Kristoffer Dalby
parent 9c425a1c08
commit 161243c787
2 changed files with 24 additions and 19 deletions

View file

@ -128,7 +128,7 @@ func GenerateFilterRules(
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}
rules, err := policy.generateFilterRules(append(peers, *machine), stripEmailDomain)
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@ -152,10 +152,12 @@ func GenerateFilterRules(
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules(
machines types.Machines,
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
machines := append(peers, *machine)
for index, acl := range pol.ACLs {
if acl.Action != "accept" {

View file

@ -199,7 +199,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
@ -230,7 +230,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
pol, err := LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}
@ -310,7 +310,7 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -366,7 +366,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -401,7 +401,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -428,7 +428,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -459,7 +459,7 @@ acls:
c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(types.Machines{}, false)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@ -1620,7 +1620,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
pol ACLPolicy
}
type args struct {
machines types.Machines
machine types.Machine
peers types.Machines
stripEmailDomain bool
}
tests := []struct {
@ -1651,7 +1652,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
},
},
args: args{
machines: types.Machines{},
machine: types.Machine{},
peers: types.Machines{},
stripEmailDomain: true,
},
want: []tailcfg.FilterRule{
@ -1691,14 +1693,14 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
},
},
args: args{
machines: types.Machines{
{
machine: types.Machine{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
},
User: types.User{Name: "mickael"},
},
peers: types.Machines{
{
IPAddresses: types.MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
@ -1739,7 +1741,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.generateFilterRules(
tt.args.machines,
&tt.args.machine,
tt.args.peers,
tt.args.stripEmailDomain,
)
if (err != nil) != tt.wantErr {