use IPSet in acls instead of string slice

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-04-28 16:11:02 +02:00 committed by Juan Font
parent 1a7ae11697
commit 735b185e7f
5 changed files with 209 additions and 104 deletions

View file

@ -8,6 +8,7 @@
- Profiles are continously generated in our integration tests. - Profiles are continously generated in our integration tests.
- Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391) - Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391)
- Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379) - Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379)
- Replace node filter logic, ensuring nodes with access can see eachother [#1381](https://github.com/juanfont/headscale/pull/1381)
## 0.22.1 (2023-04-20) ## 0.22.1 (2023-04-20)

186
acls.go
View file

@ -13,8 +13,8 @@ import (
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"go4.org/netipx"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -272,21 +272,41 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources { for innerIndex, rawSrc := range sshACL.Sources {
expandedSrcs, err := h.aclPolicy.expandAlias( if isWildcard(rawSrc) {
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, expandedSrc := range expandedSrcs {
principals = append(principals, &tailcfg.SSHPrincipal{ principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc, Any: true,
}) })
} else if isGroup(rawSrc) {
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
} }
} }
@ -295,10 +315,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
userMap[user] = "=" userMap[user] = "="
} }
rules = append(rules, &tailcfg.SSHRule{ rules = append(rules, &tailcfg.SSHRule{
RuleExpires: nil, Principals: principals,
Principals: principals, SSHUsers: userMap,
SSHUsers: userMap, Action: &action,
Action: &action,
}) })
} }
@ -329,7 +348,18 @@ func (pol *ACLPolicy) getIPsFromSource(
machines []Machine, machines []Machine,
stripEmaildomain bool, stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
return pol.expandAlias(machines, src, stripEmaildomain) ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
} }
// getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange // getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange
@ -397,11 +427,11 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
} }
dests := []tailcfg.NetPortRange{} dests := []tailcfg.NetPortRange{}
for _, d := range expanded { for _, dest := range expanded.Prefixes() {
for _, p := range *ports { for _, port := range *ports {
pr := tailcfg.NetPortRange{ pr := tailcfg.NetPortRange{
IP: d, IP: dest.String(),
Ports: p, Ports: port,
} }
dests = append(dests, pr) dests = append(dests, pr)
} }
@ -472,28 +502,30 @@ func (pol *ACLPolicy) expandAlias(
machines Machines, machines Machines,
alias string, alias string,
stripEmailDomain bool, stripEmailDomain bool,
) ([]string, error) { ) (*netipx.IPSet, error) {
if alias == "*" { if isWildcard(alias) {
return []string{"*"}, nil return parseIPSet("*", nil)
} }
build := netipx.IPSetBuilder{}
log.Debug(). log.Debug().
Str("alias", alias). Str("alias", alias).
Msg("Expanding") Msg("Expanding")
// if alias is a group // if alias is a group
if strings.HasPrefix(alias, "group:") { if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain) return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
} }
// if alias is a tag // if alias is a tag
if strings.HasPrefix(alias, "tag:") { if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain) return pol.getIPsFromTag(alias, machines, stripEmailDomain)
} }
// if alias is a user // if alias is a user
if ips := pol.getIPsForUser(alias, machines, stripEmailDomain); len(ips) > 0 { if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
return ips, nil return ips, err
} }
// if alias is an host // if alias is an host
@ -516,7 +548,7 @@ func (pol *ACLPolicy) expandAlias(
log.Warn().Msgf("No IPs found with the alias %v", alias) log.Warn().Msgf("No IPs found with the alias %v", alias)
return []string{}, nil return build.IPSet()
} }
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones // excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
@ -561,7 +593,7 @@ func excludeCorrectlyTaggedNodes(
} }
func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) { func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if portsStr == "*" { if isWildcard(portsStr) {
return &[]tailcfg.PortRange{ return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd}, {First: portRangeBegin, Last: portRangeEnd},
}, nil }, nil
@ -636,7 +668,7 @@ func getTagOwners(
) )
} }
for _, owner := range ows { for _, owner := range ows {
if strings.HasPrefix(owner, "group:") { if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain) gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
@ -667,7 +699,7 @@ func (pol *ACLPolicy) getUsersInGroup(
) )
} }
for _, group := range aclGroups { for _, group := range aclGroups {
if strings.HasPrefix(group, "group:") { if isGroup(group) {
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", "%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
errInvalidGroup, errInvalidGroup,
@ -691,34 +723,34 @@ func (pol *ACLPolicy) getIPsFromGroup(
group string, group string,
machines Machines, machines Machines,
stripEmailDomain bool, stripEmailDomain bool,
) ([]string, error) { ) (*netipx.IPSet, error) {
ips := []string{} build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group, stripEmailDomain) users, err := pol.getUsersInGroup(group, stripEmailDomain)
if err != nil { if err != nil {
return ips, err return &netipx.IPSet{}, err
} }
for _, n := range users { for _, user := range users {
nodes := filterMachinesByUser(machines, n) filteredMachines := filterMachinesByUser(machines, user)
for _, node := range nodes { for _, machine := range filteredMachines {
ips = append(ips, node.IPAddresses.ToStringSlice()...) machine.IPAddresses.AppendToIPSet(&build)
} }
} }
return ips, nil return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsFromTag(
alias string, alias string,
machines Machines, machines Machines,
stripEmailDomain bool, stripEmailDomain bool,
) ([]string, error) { ) (*netipx.IPSet, error) {
ips := []string{} build := netipx.IPSetBuilder{}
// check for forced tags // check for forced tags
for _, machine := range machines { for _, machine := range machines {
if contains(machine.ForcedTags, alias) { if contains(machine.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) machine.IPAddresses.AppendToIPSet(&build)
} }
} }
@ -726,17 +758,18 @@ func (pol *ACLPolicy) getIPsFromTag(
owners, err := getTagOwners(pol, alias, stripEmailDomain) owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil { if err != nil {
if errors.Is(err, errInvalidTag) { if errors.Is(err, errInvalidTag) {
if len(ips) == 0 { ipSet, _ := build.IPSet()
return ips, fmt.Errorf( if len(ipSet.Prefixes()) == 0 {
return ipSet, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined", "%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag, errInvalidTag,
alias, alias,
) )
} }
return ips, nil return build.IPSet()
} else { } else {
return ips, err return nil, err
} }
} }
@ -746,53 +779,62 @@ func (pol *ACLPolicy) getIPsFromTag(
for _, machine := range machines { for _, machine := range machines {
hi := machine.GetHostInfo() hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) { if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) machine.IPAddresses.AppendToIPSet(&build)
} }
} }
} }
return ips, nil return build.IPSet()
} }
func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) getIPsForUser(
user string, user string,
machines Machines, machines Machines,
stripEmailDomain bool, stripEmailDomain bool,
) []string { ) (*netipx.IPSet, error) {
ips := []string{} build := netipx.IPSetBuilder{}
nodes := filterMachinesByUser(machines, user) filteredMachines := filterMachinesByUser(machines, user)
nodes = excludeCorrectlyTaggedNodes(pol, nodes, user, stripEmailDomain) filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
for _, n := range nodes { // shortcurcuit if we have no machines to get ips from.
ips = append(ips, n.IPAddresses.ToStringSlice()...) if len(filteredMachines) == 0 {
return nil, nil //nolint
} }
return ips for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromSingleIP( func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr, ip netip.Addr,
machines Machines, machines Machines,
) ([]string, error) { ) (*netipx.IPSet, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
ips := []string{ip.String()}
matches := machines.FilterByIP(ip) matches := machines.FilterByIP(ip)
build := netipx.IPSetBuilder{}
build.Add(ip)
for _, machine := range matches { for _, machine := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) machine.IPAddresses.AppendToIPSet(&build)
} }
return lo.Uniq(ips), nil return build.IPSet()
} }
func (pol *ACLPolicy) getIPsFromIPPrefix( func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix, prefix netip.Prefix,
machines Machines, machines Machines,
) ([]string, error) { ) (*netipx.IPSet, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
val := []string{prefix.String()} build := netipx.IPSetBuilder{}
build.AddPrefix(prefix)
// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6 // This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers. // addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines { for _, machine := range machines {
@ -800,10 +842,22 @@ func (pol *ACLPolicy) getIPsFromIPPrefix(
// log.Trace(). // log.Trace().
// Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String()) // Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
if prefix.Contains(ip) { if prefix.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...) machine.IPAddresses.AppendToIPSet(&build)
} }
} }
} }
return lo.Uniq(val), nil return build.IPSet()
}
func isWildcard(str string) bool {
return str == "*"
}
func isGroup(str string) bool {
return strings.HasPrefix(str, "group:")
}
func isTag(str string) bool {
return strings.HasPrefix(str, "tag:")
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -144,7 +145,7 @@ func (s *Suite) TestSshRules(c *check.C) {
c.Assert(app.sshPolicy.Rules, check.HasLen, 2) c.Assert(app.sshPolicy.Rules, check.HasLen, 2)
c.Assert(app.sshPolicy.Rules[0].SSHUsers, check.HasLen, 1) c.Assert(app.sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
c.Assert(app.sshPolicy.Rules[0].Principals, check.HasLen, 1) c.Assert(app.sshPolicy.Rules[0].Principals, check.HasLen, 1)
c.Assert(app.sshPolicy.Rules[0].Principals[0].NodeIP, check.Matches, "100.64.0.1") c.Assert(app.sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")
c.Assert(app.sshPolicy.Rules[1].SSHUsers, check.HasLen, 1) c.Assert(app.sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
c.Assert(app.sshPolicy.Rules[1].Principals, check.HasLen, 1) c.Assert(app.sshPolicy.Rules[1].Principals, check.HasLen, 1)
@ -232,7 +233,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(app.aclRules, check.HasLen, 1) c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1") c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
} }
// this test should validate that we can expand a group in a TagOWner section and // this test should validate that we can expand a group in a TagOWner section and
@ -282,7 +283,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(app.aclRules, check.HasLen, 1) c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 1) c.Assert(app.aclRules[0].DstPorts, check.HasLen, 1)
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1") c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
} }
// need a test with: // need a test with:
@ -331,7 +332,7 @@ func (s *Suite) TestInvalidTagValidUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(app.aclRules, check.HasLen, 1) c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1") c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
} }
// tag on a host is owned by a tag owner, the tag is valid. // tag on a host is owned by a tag owner, the tag is valid.
@ -399,14 +400,14 @@ func (s *Suite) TestValidTagInvalidUser(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(app.aclRules, check.HasLen, 1) c.Assert(app.aclRules, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1) c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2") c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2) c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2)
c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80)) c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80)) c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1") c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443)) c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443)) c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1") c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
} }
func (s *Suite) TestPortRange(c *check.C) { func (s *Suite) TestPortRange(c *check.C) {
@ -449,8 +450,8 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(rules[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 2)
c.Assert(rules[0].SrcIPs[0], check.Equals, "*") c.Assert(rules[0].SrcIPs[0], check.Equals, "0.0.0.0/0")
} }
func (s *Suite) TestPortWildcardYAML(c *check.C) { func (s *Suite) TestPortWildcardYAML(c *check.C) {
@ -465,8 +466,8 @@ func (s *Suite) TestPortWildcardYAML(c *check.C) {
c.Assert(rules[0].DstPorts, check.HasLen, 1) c.Assert(rules[0].DstPorts, check.HasLen, 1)
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0)) c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535)) c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
c.Assert(rules[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 2)
c.Assert(rules[0].SrcIPs[0], check.Equals, "*") c.Assert(rules[0].SrcIPs[0], check.Equals, "0.0.0.0/0")
} }
func (s *Suite) TestPortUser(c *check.C) { func (s *Suite) TestPortUser(c *check.C) {
@ -511,7 +512,7 @@ func (s *Suite) TestPortUser(c *check.C) {
c.Assert(rules[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
} }
func (s *Suite) TestPortGroup(c *check.C) { func (s *Suite) TestPortGroup(c *check.C) {
@ -554,7 +555,7 @@ func (s *Suite) TestPortGroup(c *check.C) {
c.Assert(rules[0].SrcIPs, check.HasLen, 1) c.Assert(rules[0].SrcIPs, check.HasLen, 1)
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip") c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()) c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
} }
func Test_expandGroup(t *testing.T) { func Test_expandGroup(t *testing.T) {
@ -920,6 +921,22 @@ func Test_listMachinesInUser(t *testing.T) {
} }
func Test_expandAlias(t *testing.T) { func Test_expandAlias(t *testing.T) {
set := func(ips []string, prefixes []string) *netipx.IPSet {
var builder netipx.IPSetBuilder
for _, ip := range ips {
builder.Add(netip.MustParseAddr(ip))
}
for _, pre := range prefixes {
builder.AddPrefix(netip.MustParsePrefix(pre))
}
s, _ := builder.IPSet()
return s
}
type field struct { type field struct {
pol ACLPolicy pol ACLPolicy
} }
@ -933,7 +950,7 @@ func Test_expandAlias(t *testing.T) {
name string name string
field field field field
args args args args
want []string want *netipx.IPSet
wantErr bool wantErr bool
}{ }{
{ {
@ -953,7 +970,10 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"*"}, want: set([]string{}, []string{
"0.0.0.0/0",
"::/0",
}),
wantErr: false, wantErr: false,
}, },
{ {
@ -993,7 +1013,9 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, want: set([]string{
"100.64.0.1", "100.64.0.2", "100.64.0.3",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1033,7 +1055,7 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{}, want: set([]string{}, []string{}),
wantErr: true, wantErr: true,
}, },
{ {
@ -1046,7 +1068,9 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{}, machines: []Machine{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.3"}, want: set([]string{
"10.0.0.3",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1059,7 +1083,9 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{}, machines: []Machine{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.1"}, want: set([]string{
"10.0.0.1",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1079,7 +1105,9 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.1"}, want: set([]string{
"10.0.0.1",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1100,7 +1128,9 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222"}, want: set([]string{
"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1121,7 +1151,9 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1"}, want: set([]string{
"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1138,7 +1170,7 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{}, machines: []Machine{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.132/32"}, want: set([]string{}, []string{"10.0.0.132/32"}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1155,7 +1187,7 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{}, machines: []Machine{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"192.168.1.0/24"}, want: set([]string{}, []string{"192.168.1.0/24"}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1169,7 +1201,7 @@ func Test_expandAlias(t *testing.T) {
aclPolicy: ACLPolicy{}, aclPolicy: ACLPolicy{},
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"10.0.0.0/16"}, want: set([]string{}, []string{"10.0.0.0/16"}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1219,7 +1251,9 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"100.64.0.1", "100.64.0.2"}, want: set([]string{
"100.64.0.1", "100.64.0.2",
}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1262,7 +1296,7 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{}, want: set([]string{}, []string{}),
wantErr: true, wantErr: true,
}, },
{ {
@ -1302,7 +1336,7 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"100.64.0.1", "100.64.0.2"}, want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1350,7 +1384,7 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"100.64.0.1", "100.64.0.2"}, want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false, wantErr: false,
}, },
{ {
@ -1400,7 +1434,7 @@ func Test_expandAlias(t *testing.T) {
}, },
stripEmailDomain: true, stripEmailDomain: true,
}, },
want: []string{"100.64.0.4"}, want: set([]string{"100.64.0.4"}, []string{}),
wantErr: false, wantErr: false,
}, },
} }
@ -1416,7 +1450,7 @@ func Test_expandAlias(t *testing.T) {
return return
} }
if !reflect.DeepEqual(got, test.want) { if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("expandAlias() = %v, want %v", got, test.want) t.Errorf("expandAlias() = %v, want %v", got, test.want)
} }
}) })
@ -1702,10 +1736,17 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"*"}, SrcIPs: []string{"0.0.0.0/0", "::/0"},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "*", IP: "0.0.0.0/0",
Ports: tailcfg.PortRange{
First: 0,
Last: 65535,
},
},
{
IP: "::/0",
Ports: tailcfg.PortRange{ Ports: tailcfg.PortRange{
First: 0, First: 0,
Last: 65535, Last: 65535,
@ -1750,17 +1791,17 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
SrcIPs: []string{"100.64.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2221"}, SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128"},
DstPorts: []tailcfg.NetPortRange{ DstPorts: []tailcfg.NetPortRange{
{ {
IP: "100.64.0.2", IP: "100.64.0.2/32",
Ports: tailcfg.PortRange{ Ports: tailcfg.PortRange{
First: 0, First: 0,
Last: 65535, Last: 65535,
}, },
}, },
{ {
IP: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", IP: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222/128",
Ports: tailcfg.PortRange{ Ports: tailcfg.PortRange{
First: 0, First: 0,
Last: 65535, Last: 65535,

View file

@ -424,7 +424,7 @@ func TestSSUserOnlyIsolation(t *testing.T) {
// TODO(kradalby,evenh): ACLs do currently not cover reject // TODO(kradalby,evenh): ACLs do currently not cover reject
// cases properly, and currently will accept all incomming connections // cases properly, and currently will accept all incomming connections
// as long as a rule is present. // as long as a rule is present.
//
// for _, client := range ssh1Clients { // for _, client := range ssh1Clients {
// for _, peer := range ssh2Clients { // for _, peer := range ssh2Clients {
// if client.Hostname() == peer.Hostname() { // if client.Hostname() == peer.Hostname() {

View file

@ -13,6 +13,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/samber/lo" "github.com/samber/lo"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -97,6 +98,14 @@ func (ma MachineAddresses) ToStringSlice() []string {
return strSlice return strSlice
} }
// AppendToIPSet adds the individual ips in MachineAddresses to a
// given netipx.IPSetBuilder.
func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
for _, ip := range ma {
build.Add(ip)
}
}
func (ma *MachineAddresses) Scan(destination interface{}) error { func (ma *MachineAddresses) Scan(destination interface{}) error {
switch value := destination.(type) { switch value := destination.(type) {
case string: case string:
@ -1114,7 +1123,7 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error {
} }
// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first
if contains(approvedIps, machine.IPAddresses[0].String()) { if approvedIps.Contains(machine.IPAddresses[0]) {
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} }
} }