mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-19 10:20:05 +09:00
feat(acls): rewrite functions to be testable
Rewrite some function to get rid of the dependency on Headscale object. This allows us to write succinct test that are more easy to review and implement. The improvements of the tests allowed to write the removal of the tagged hosts from the namespace as specified here: https://tailscale.com/kb/1068/acl-tags/
This commit is contained in:
parent
97eac3b938
commit
de59946447
3 changed files with 646 additions and 75 deletions
187
acls.go
187
acls.go
|
@ -2,7 +2,6 @@ package headscale
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
@ -86,6 +85,11 @@ func (h *Headscale) UpdateACLRules() error {
|
|||
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
||||
rules := []tailcfg.FilterRule{}
|
||||
|
||||
machines, err := h.ListAllMachines()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for index, acl := range h.aclPolicy.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, errInvalidAction
|
||||
|
@ -93,7 +97,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
|||
|
||||
srcIPs := []string{}
|
||||
for innerIndex, user := range acl.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(user)
|
||||
srcs, err := h.generateACLPolicySrcIP(machines, *h.aclPolicy, user)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, User %d", index, innerIndex)
|
||||
|
@ -105,7 +109,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
|||
|
||||
destPorts := []tailcfg.NetPortRange{}
|
||||
for innerIndex, ports := range acl.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(ports)
|
||||
dests, err := h.generateACLPolicyDestPorts(machines, *h.aclPolicy, ports)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
|
||||
|
@ -124,11 +128,13 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
|
|||
return rules, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
|
||||
return h.expandAlias(u)
|
||||
func (h *Headscale) generateACLPolicySrcIP(machines []Machine, aclPolicy ACLPolicy, u string) ([]string, error) {
|
||||
return expandAlias(machines, aclPolicy, u)
|
||||
}
|
||||
|
||||
func (h *Headscale) generateACLPolicyDestPorts(
|
||||
machines []Machine,
|
||||
aclPolicy ACLPolicy,
|
||||
d string,
|
||||
) ([]tailcfg.NetPortRange, error) {
|
||||
tokens := strings.Split(d, ":")
|
||||
|
@ -149,11 +155,11 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
|||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||
}
|
||||
|
||||
expanded, err := h.expandAlias(alias)
|
||||
expanded, err := expandAlias(machines, aclPolicy, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ports, err := h.expandPorts(tokens[len(tokens)-1])
|
||||
ports, err := expandPorts(tokens[len(tokens)-1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -177,52 +183,40 @@ func (h *Headscale) generateACLPolicyDestPorts(
|
|||
// - a group
|
||||
// - a tag
|
||||
// and transform these in IPAddresses
|
||||
func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
||||
func expandAlias(machines []Machine, aclPolicy ACLPolicy, alias string) ([]string, error) {
|
||||
ips := []string{}
|
||||
if alias == "*" {
|
||||
return []string{"*"}, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(alias, "group:") {
|
||||
namespaces, err := h.expandGroup(alias)
|
||||
namespaces, err := expandGroup(aclPolicy, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return ips, err
|
||||
}
|
||||
ips := []string{}
|
||||
for _, n := range namespaces {
|
||||
nodes, err := h.ListMachinesInNamespace(n)
|
||||
if err != nil {
|
||||
return nil, errInvalidNamespace
|
||||
}
|
||||
nodes := listMachinesInNamespace(machines, n)
|
||||
for _, node := range nodes {
|
||||
ips = append(ips, node.IPAddresses.ToStringSlice()...)
|
||||
}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(alias, "tag:") {
|
||||
var ips []string
|
||||
owners, err := h.expandTagOwners(alias)
|
||||
owners, err := expandTagOwners(aclPolicy, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return ips, err
|
||||
}
|
||||
for _, namespace := range owners {
|
||||
machines, err := h.ListMachinesInNamespace(namespace)
|
||||
if err != nil {
|
||||
if errors.Is(err, errNamespaceNotFound) {
|
||||
continue
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
machines := listMachinesInNamespace(machines, namespace)
|
||||
for _, machine := range machines {
|
||||
if len(machine.HostInfo) == 0 {
|
||||
continue
|
||||
}
|
||||
hi, err := machine.GetHostInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return ips, err
|
||||
}
|
||||
for _, t := range hi.RequestTags {
|
||||
if alias == t {
|
||||
|
@ -234,75 +228,75 @@ func (h *Headscale) expandAlias(alias string) ([]string, error) {
|
|||
return ips, nil
|
||||
}
|
||||
|
||||
n, err := h.GetNamespace(alias)
|
||||
if err == nil {
|
||||
nodes, err := h.ListMachinesInNamespace(n.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ips := []string{}
|
||||
for _, n := range nodes {
|
||||
ips = append(ips, n.IPAddresses.ToStringSlice()...)
|
||||
}
|
||||
|
||||
// if alias is a namespace
|
||||
nodes := listMachinesInNamespace(machines, alias)
|
||||
nodes, err := excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
for _, n := range nodes {
|
||||
ips = append(ips, n.IPAddresses.ToStringSlice()...)
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
if h, ok := h.aclPolicy.Hosts[alias]; ok {
|
||||
// if alias is an host
|
||||
if h, ok := aclPolicy.Hosts[alias]; ok {
|
||||
return []string{h.String()}, nil
|
||||
}
|
||||
|
||||
// if alias is an IP
|
||||
ip, err := netaddr.ParseIP(alias)
|
||||
if err == nil {
|
||||
return []string{ip.String()}, nil
|
||||
}
|
||||
|
||||
// if alias is an CIDR
|
||||
cidr, err := netaddr.ParseIPPrefix(alias)
|
||||
if err == nil {
|
||||
return []string{cidr.String()}, nil
|
||||
}
|
||||
|
||||
return nil, errInvalidUserSection
|
||||
return ips, errInvalidUserSection
|
||||
}
|
||||
|
||||
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
|
||||
// a group cannot be composed of groups
|
||||
func (h *Headscale) expandTagOwners(owner string) ([]string, error) {
|
||||
var owners []string
|
||||
ows, ok := h.aclPolicy.TagOwners[owner]
|
||||
if !ok {
|
||||
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, owner)
|
||||
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
|
||||
// that are correctly tagged since they should not be listed as being in the namespace
|
||||
// we assume in this function that we only have nodes from 1 namespace.
|
||||
func excludeCorrectlyTaggedNodes(aclPolicy ACLPolicy, nodes []Machine, namespace string) ([]Machine, error) {
|
||||
out := []Machine{}
|
||||
tags := []string{}
|
||||
for tag, ns := range aclPolicy.TagOwners {
|
||||
if containsString(ns, namespace) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
for _, ow := range ows {
|
||||
if strings.HasPrefix(ow, "group:") {
|
||||
gs, err := h.expandGroup(ow)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
// for each machine if tag is in tags list, don't append it.
|
||||
for _, machine := range nodes {
|
||||
if len(machine.HostInfo) == 0 {
|
||||
out = append(out, machine)
|
||||
continue
|
||||
}
|
||||
hi, err := machine.GetHostInfo()
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
found := false
|
||||
for _, t := range hi.RequestTags {
|
||||
if containsString(tags, t) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
owners = append(owners, gs...)
|
||||
} else {
|
||||
owners = append(owners, ow)
|
||||
}
|
||||
if !found {
|
||||
out = append(out, machine)
|
||||
}
|
||||
}
|
||||
return owners, nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// expandGroup will return the list of namespace inside the group
|
||||
// after some validation
|
||||
func (h *Headscale) expandGroup(group string) ([]string, error) {
|
||||
gs, ok := h.aclPolicy.Groups[group]
|
||||
if !ok {
|
||||
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
|
||||
}
|
||||
for _, g := range gs {
|
||||
if strings.HasPrefix(g, "group:") {
|
||||
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
|
||||
}
|
||||
}
|
||||
return gs, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||
func expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
||||
if portsStr == "*" {
|
||||
return &[]tailcfg.PortRange{
|
||||
{First: portRangeBegin, Last: portRangeEnd},
|
||||
|
@ -344,3 +338,50 @@ func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
|
|||
|
||||
return &ports, nil
|
||||
}
|
||||
|
||||
func listMachinesInNamespace(machines []Machine, namespace string) []Machine {
|
||||
out := []Machine{}
|
||||
for _, machine := range machines {
|
||||
if machine.Namespace.Name == namespace {
|
||||
out = append(out, machine)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// expandTagOwners will return a list of namespace. An owner can be either a namespace or a group
|
||||
// a group cannot be composed of groups
|
||||
func expandTagOwners(aclPolicy ACLPolicy, tag string) ([]string, error) {
|
||||
var owners []string
|
||||
ows, ok := aclPolicy.TagOwners[tag]
|
||||
if !ok {
|
||||
return []string{}, fmt.Errorf("%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", errInvalidTag, tag)
|
||||
}
|
||||
for _, ow := range ows {
|
||||
if strings.HasPrefix(ow, "group:") {
|
||||
gs, err := expandGroup(aclPolicy, ow)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
owners = append(owners, gs...)
|
||||
} else {
|
||||
owners = append(owners, ow)
|
||||
}
|
||||
}
|
||||
return owners, nil
|
||||
}
|
||||
|
||||
// expandGroup will return the list of namespace inside the group
|
||||
// after some validation
|
||||
func expandGroup(aclPolicy ACLPolicy, group string) ([]string, error) {
|
||||
gs, ok := aclPolicy.Groups[group]
|
||||
if !ok {
|
||||
return []string{}, fmt.Errorf("group %v isn't registered. %w", group, errInvalidGroup)
|
||||
}
|
||||
for _, g := range gs {
|
||||
if strings.HasPrefix(g, "group:") {
|
||||
return []string{}, fmt.Errorf("%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups", errInvalidGroup)
|
||||
}
|
||||
}
|
||||
return gs, nil
|
||||
}
|
||||
|
|
521
acls_test.go
521
acls_test.go
|
@ -2,10 +2,13 @@ package headscale
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/datatypes"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (s *Suite) TestWrongPath(c *check.C) {
|
||||
|
@ -267,9 +270,16 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
|
|||
}
|
||||
err = app.UpdateACLRules()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Logf("Rules: %v", app.aclRules)
|
||||
c.Assert(app.aclRules, check.HasLen, 1)
|
||||
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 0)
|
||||
c.Assert(app.aclRules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(app.aclRules[0].SrcIPs[0], check.Equals, "100.64.0.2")
|
||||
c.Assert(app.aclRules[0].DstPorts, check.HasLen, 2)
|
||||
c.Assert(app.aclRules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||
c.Assert(app.aclRules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||
c.Assert(app.aclRules[0].DstPorts[0].IP, check.Equals, "100.64.0.1")
|
||||
c.Assert(app.aclRules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||
c.Assert(app.aclRules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||
c.Assert(app.aclRules[0].DstPorts[1].IP, check.Equals, "100.64.0.1")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortRange(c *check.C) {
|
||||
|
@ -385,3 +395,510 @@ func (s *Suite) TestPortGroup(c *check.C) {
|
|||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String())
|
||||
}
|
||||
|
||||
func Test_expandGroup(t *testing.T) {
|
||||
type args struct {
|
||||
aclPolicy ACLPolicy
|
||||
group string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple test",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
|
||||
},
|
||||
group: "group:test",
|
||||
},
|
||||
want: []string{"g1", "foo", "test"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "InexistantGroup",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:test": []string{"g1", "foo", "test"}, "group:foo": []string{"foo", "test"}},
|
||||
},
|
||||
group: "group:bar",
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := expandGroup(tt.args.aclPolicy, tt.args.group)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("expandGroup() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("expandGroup() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_expandTagOwners(t *testing.T) {
|
||||
type args struct {
|
||||
aclPolicy ACLPolicy
|
||||
tag string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple tag",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:test": []string{"namespace1"}},
|
||||
},
|
||||
tag: "tag:test",
|
||||
},
|
||||
want: []string{"namespace1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tag and group",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:foo": []string{"n1", "bar"}},
|
||||
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
||||
},
|
||||
tag: "tag:test",
|
||||
},
|
||||
want: []string{"n1", "bar"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "namespace and group",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:foo": []string{"n1", "bar"}},
|
||||
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
|
||||
},
|
||||
tag: "tag:test",
|
||||
},
|
||||
want: []string{"n1", "bar", "home"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid tag",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "home"}},
|
||||
},
|
||||
tag: "tag:test",
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid group",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:bar": []string{"n1", "foo"}},
|
||||
TagOwners: TagOwners{"tag:test": []string{"group:foo", "home"}},
|
||||
},
|
||||
tag: "tag:test",
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := expandTagOwners(tt.args.aclPolicy, tt.args.tag)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("expandTagOwners() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_expandPorts(t *testing.T) {
|
||||
type args struct {
|
||||
portsStr string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *[]tailcfg.PortRange
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard",
|
||||
args: args{portsStr: "*"},
|
||||
want: &[]tailcfg.PortRange{
|
||||
{First: portRangeBegin, Last: portRangeEnd},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "two ports",
|
||||
args: args{portsStr: "80,443"},
|
||||
want: &[]tailcfg.PortRange{
|
||||
{First: 80, Last: 80},
|
||||
{First: 443, Last: 443},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "a range and a port",
|
||||
args: args{portsStr: "80-1024,443"},
|
||||
want: &[]tailcfg.PortRange{
|
||||
{First: 80, Last: 1024},
|
||||
{First: 443, Last: 443},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "out of bounds",
|
||||
args: args{portsStr: "854038"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong port",
|
||||
args: args{portsStr: "85a38"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong port in first",
|
||||
args: args{portsStr: "a-80"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong port in last",
|
||||
args: args{portsStr: "80-85a38"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong port format",
|
||||
args: args{portsStr: "80-85a38-3"},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := expandPorts(tt.args.portsStr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("expandPorts() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("expandPorts() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_listMachinesInNamespace(t *testing.T) {
|
||||
type args struct {
|
||||
machines []Machine
|
||||
namespace string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []Machine
|
||||
}{
|
||||
{
|
||||
name: "1 machine in namespace",
|
||||
args: args{
|
||||
machines: []Machine{
|
||||
{Namespace: Namespace{Name: "test"}},
|
||||
},
|
||||
namespace: "test",
|
||||
},
|
||||
want: []Machine{
|
||||
{Namespace: Namespace{Name: "test"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "3 machines, 2 in namespace",
|
||||
args: args{
|
||||
machines: []Machine{
|
||||
{ID: 1, Namespace: Namespace{Name: "test"}},
|
||||
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
namespace: "foo",
|
||||
},
|
||||
want: []Machine{
|
||||
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "5 machines, 0 in namespace",
|
||||
args: args{
|
||||
machines: []Machine{
|
||||
{ID: 1, Namespace: Namespace{Name: "test"}},
|
||||
{ID: 2, Namespace: Namespace{Name: "foo"}},
|
||||
{ID: 3, Namespace: Namespace{Name: "foo"}},
|
||||
{ID: 4, Namespace: Namespace{Name: "foo"}},
|
||||
{ID: 5, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
namespace: "bar",
|
||||
},
|
||||
want: []Machine{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := listMachinesInNamespace(tt.args.machines, tt.args.namespace); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("listMachinesInNamespace() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_expandAlias(t *testing.T) {
|
||||
type args struct {
|
||||
machines []Machine
|
||||
aclPolicy ACLPolicy
|
||||
alias string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "wildcard",
|
||||
args: args{
|
||||
alias: "*",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.78.84.227")}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{},
|
||||
},
|
||||
want: []string{"*"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "simple group",
|
||||
args: args{
|
||||
alias: "group:foo",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||
},
|
||||
},
|
||||
want: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong group",
|
||||
args: args{
|
||||
alias: "group:test",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||
},
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "simple ipaddress",
|
||||
args: args{
|
||||
alias: "10.0.0.3",
|
||||
machines: []Machine{},
|
||||
aclPolicy: ACLPolicy{},
|
||||
},
|
||||
want: []string{"10.0.0.3"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "private network",
|
||||
args: args{
|
||||
alias: "homeNetwork",
|
||||
machines: []Machine{},
|
||||
aclPolicy: ACLPolicy{
|
||||
Hosts: Hosts{"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24")},
|
||||
},
|
||||
},
|
||||
want: []string{"192.168.1.0/24"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "simple host",
|
||||
args: args{
|
||||
alias: "10.0.0.1",
|
||||
machines: []Machine{},
|
||||
aclPolicy: ACLPolicy{},
|
||||
},
|
||||
want: []string{"10.0.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "simple CIDR",
|
||||
args: args{
|
||||
alias: "10.0.0.0/16",
|
||||
machines: []Machine{},
|
||||
aclPolicy: ACLPolicy{},
|
||||
},
|
||||
want: []string{"10.0.0.0/16"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "simple tag",
|
||||
args: args{
|
||||
alias: "tag:test",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||
},
|
||||
},
|
||||
want: []string{"100.64.0.1", "100.64.0.2"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "No tag defined",
|
||||
args: args{
|
||||
alias: "tag:foo",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "test"}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{
|
||||
Groups: Groups{"group:foo": []string{"foo", "bar"}},
|
||||
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
||||
},
|
||||
},
|
||||
want: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "list host in namespace without correctly tagged servers",
|
||||
args: args{
|
||||
alias: "foo",
|
||||
machines: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")}, Namespace: Namespace{Name: "bar"}},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||
},
|
||||
},
|
||||
want: []string{"100.64.0.4"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := expandAlias(tt.args.machines, tt.args.aclPolicy, tt.args.alias)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("expandAlias() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("expandAlias() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||
type args struct {
|
||||
aclPolicy ACLPolicy
|
||||
nodes []Machine
|
||||
namespace string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []Machine
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "exclude nodes with valid tags",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:test": []string{"foo"}},
|
||||
},
|
||||
nodes: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
namespace: "foo",
|
||||
},
|
||||
want: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all nodes have invalid tags, don't exclude them",
|
||||
args: args{
|
||||
aclPolicy: ACLPolicy{
|
||||
TagOwners: TagOwners{"tag:foo": []string{"foo"}},
|
||||
},
|
||||
nodes: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
namespace: "foo",
|
||||
},
|
||||
want: []Machine{
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")}, Namespace: Namespace{Name: "foo"}, HostInfo: []byte("{\"OS\":\"centos\",\"Hostname\":\"foo\",\"RequestTags\":[\"tag:test\"]}")},
|
||||
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")}, Namespace: Namespace{Name: "foo"}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := excludeCorrectlyTaggedNodes(tt.args.aclPolicy, tt.args.nodes, tt.args.namespace)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("excludeCorrectlyTaggedNodes() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
13
machine.go
13
machine.go
|
@ -119,6 +119,19 @@ func (machine Machine) isExpired() bool {
|
|||
return time.Now().UTC().After(*machine.Expiry)
|
||||
}
|
||||
|
||||
func (h *Headscale) ListAllMachines() ([]Machine, error) {
|
||||
machines := []Machine{}
|
||||
if err := h.db.Preload("AuthKey").
|
||||
Preload("AuthKey.Namespace").
|
||||
Preload("Namespace").
|
||||
Where("registered").
|
||||
Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
func containsAddresses(inputs []string, addrs MachineAddresses) bool {
|
||||
for _, addr := range addrs.ToStringSlice() {
|
||||
if containsString(inputs, addr) {
|
||||
|
|
Loading…
Reference in a new issue