mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-19 10:20:05 +09:00
Add tests to verify "Hosts" aliases in ACL (#1304)
This commit is contained in:
parent
681c86cc95
commit
ceeef40cdf
4 changed files with 337 additions and 16 deletions
50
acls.go
50
acls.go
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/tailscale/hujson"
|
"github.com/tailscale/hujson"
|
||||||
|
"go4.org/netipx"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -165,16 +166,22 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s
|
||||||
aclCachePeerMap := make(map[string]map[string]struct{})
|
aclCachePeerMap := make(map[string]map[string]struct{})
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
for _, srcIP := range rule.SrcIPs {
|
for _, srcIP := range rule.SrcIPs {
|
||||||
if data, ok := aclCachePeerMap[srcIP]; ok {
|
for _, ip := range expandACLPeerAddr(srcIP) {
|
||||||
|
if data, ok := aclCachePeerMap[ip]; ok {
|
||||||
for _, dstPort := range rule.DstPorts {
|
for _, dstPort := range rule.DstPorts {
|
||||||
data[dstPort.IP] = struct{}{}
|
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||||
|
data[dstIP] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
|
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
|
||||||
for _, dstPort := range rule.DstPorts {
|
for _, dstPort := range rule.DstPorts {
|
||||||
dstPortsMap[dstPort.IP] = struct{}{}
|
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
|
||||||
|
dstPortsMap[dstIP] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
aclCachePeerMap[ip] = dstPortsMap
|
||||||
}
|
}
|
||||||
aclCachePeerMap[srcIP] = dstPortsMap
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,6 +191,41 @@ func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]s
|
||||||
return aclCachePeerMap
|
return aclCachePeerMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expandACLPeerAddr takes a "tailcfg.FilterRule" "IP" and expands it into
|
||||||
|
// something our cache logic can look up, which is "*" or single IP addresses.
|
||||||
|
// This is probably quite inefficient, but it is a result of
|
||||||
|
// "make it work, then make it fast", and a lot of the ACL stuff does not
|
||||||
|
// work, but people have tried to make it fast.
|
||||||
|
func expandACLPeerAddr(srcIP string) []string {
|
||||||
|
if ip, err := netip.ParseAddr(srcIP); err == nil {
|
||||||
|
return []string{ip.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cidr, err := netip.ParsePrefix(srcIP); err == nil {
|
||||||
|
addrs := []string{}
|
||||||
|
|
||||||
|
ipRange := netipx.RangeOfPrefix(cidr)
|
||||||
|
|
||||||
|
from := ipRange.From()
|
||||||
|
too := ipRange.To()
|
||||||
|
|
||||||
|
if from == too {
|
||||||
|
return []string{from.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
for from != too {
|
||||||
|
addrs = append(addrs, from.String())
|
||||||
|
|
||||||
|
from = from.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// probably "*" or other string based "IP"
|
||||||
|
return []string{srcIP}
|
||||||
|
}
|
||||||
|
|
||||||
func generateACLRules(
|
func generateACLRules(
|
||||||
machines []Machine,
|
machines []Machine,
|
||||||
aclPolicy ACLPolicy,
|
aclPolicy ACLPolicy,
|
||||||
|
|
64
acls_test.go
64
acls_test.go
|
@ -1556,3 +1556,67 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_expandACLPeerAddr(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
srcIP string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "asterix",
|
||||||
|
args: args{
|
||||||
|
srcIP: "*",
|
||||||
|
},
|
||||||
|
want: []string{"*"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip",
|
||||||
|
args: args{
|
||||||
|
srcIP: "10.0.0.1",
|
||||||
|
},
|
||||||
|
want: []string{"10.0.0.1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip/32",
|
||||||
|
args: args{
|
||||||
|
srcIP: "10.0.0.1/32",
|
||||||
|
},
|
||||||
|
want: []string{"10.0.0.1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip/30",
|
||||||
|
args: args{
|
||||||
|
srcIP: "10.0.0.1/30",
|
||||||
|
},
|
||||||
|
want: []string{
|
||||||
|
"10.0.0.0",
|
||||||
|
"10.0.0.1",
|
||||||
|
"10.0.0.2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ip/28",
|
||||||
|
args: args{
|
||||||
|
srcIP: "192.168.0.128/28",
|
||||||
|
},
|
||||||
|
want: []string{
|
||||||
|
"192.168.0.128", "192.168.0.129", "192.168.0.130",
|
||||||
|
"192.168.0.131", "192.168.0.132", "192.168.0.133",
|
||||||
|
"192.168.0.134", "192.168.0.135", "192.168.0.136",
|
||||||
|
"192.168.0.137", "192.168.0.138", "192.168.0.139",
|
||||||
|
"192.168.0.140", "192.168.0.141", "192.168.0.142",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := expandACLPeerAddr(tt.args.srcIP); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("expandACLPeerAddr() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -439,3 +440,214 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||||
err = scenario.Shutdown()
|
err = scenario.Shutdown()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This test aims to cover cases where individual hosts are allowed and denied
|
||||||
|
// access based on their assigned hostname
|
||||||
|
// https://github.com/juanfont/headscale/issues/941
|
||||||
|
|
||||||
|
// ACL = [{
|
||||||
|
// "DstPorts": [{
|
||||||
|
// "Bits": null,
|
||||||
|
// "IP": "100.64.0.3/32",
|
||||||
|
// "Ports": {
|
||||||
|
// "First": 0,
|
||||||
|
// "Last": 65535
|
||||||
|
// }
|
||||||
|
// }],
|
||||||
|
// "SrcIPs": ["*"]
|
||||||
|
// }, {
|
||||||
|
//
|
||||||
|
// "DstPorts": [{
|
||||||
|
// "Bits": null,
|
||||||
|
// "IP": "100.64.0.2/32",
|
||||||
|
// "Ports": {
|
||||||
|
// "First": 0,
|
||||||
|
// "Last": 65535
|
||||||
|
// }
|
||||||
|
// }],
|
||||||
|
// "SrcIPs": ["100.64.0.1/32"]
|
||||||
|
// }]
|
||||||
|
//
|
||||||
|
// ACL Cache Map= {
|
||||||
|
// "*": {
|
||||||
|
// "100.64.0.3/32": {}
|
||||||
|
// },
|
||||||
|
// "100.64.0.1/32": {
|
||||||
|
// "100.64.0.2/32": {}
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
scenario := aclScenario(t,
|
||||||
|
headscale.ACLPolicy{
|
||||||
|
Hosts: headscale.Hosts{
|
||||||
|
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||||
|
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||||
|
"test3": netip.MustParsePrefix("100.64.0.3/32"),
|
||||||
|
},
|
||||||
|
ACLs: []headscale.ACL{
|
||||||
|
// Everyone can curl test3
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"test3:*"},
|
||||||
|
},
|
||||||
|
// test1 can curl test2
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"test1"},
|
||||||
|
Destinations: []string{"test2:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Since user/users dont matter here, we basically expect that some clients
|
||||||
|
// will be assigned these ips and that we can pick them up for our own use.
|
||||||
|
test1ip := netip.MustParseAddr("100.64.0.1")
|
||||||
|
test1, err := scenario.FindTailscaleClientByIP(test1ip)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test1fqdn, err := test1.FQDN()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String())
|
||||||
|
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
||||||
|
|
||||||
|
test2ip := netip.MustParseAddr("100.64.0.2")
|
||||||
|
test2, err := scenario.FindTailscaleClientByIP(test2ip)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test2fqdn, err := test2.FQDN()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String())
|
||||||
|
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
||||||
|
|
||||||
|
test3ip := netip.MustParseAddr("100.64.0.3")
|
||||||
|
test3, err := scenario.FindTailscaleClientByIP(test3ip)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
test3fqdn, err := test3.FQDN()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
test3ipURL := fmt.Sprintf("http://%s/etc/hostname", test3ip.String())
|
||||||
|
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
|
||||||
|
|
||||||
|
// test1 can query test3
|
||||||
|
result, err := test1.Curl(test3ipURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result, err = test1.Curl(test3fqdnURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// test2 can query test3
|
||||||
|
result, err = test2.Curl(test3ipURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result, err = test2.Curl(test3fqdnURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// test3 cannot query test1
|
||||||
|
result, err = test3.Curl(test1ipURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
result, err = test3.Curl(test1fqdnURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// test3 cannot query test2
|
||||||
|
result, err = test3.Curl(test2ipURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
result, err = test3.Curl(test2fqdnURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// test1 can query test2
|
||||||
|
result, err = test1.Curl(test2ipURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result, err = test1.Curl(test2fqdnURL)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// test2 cannot query test1
|
||||||
|
result, err = test2.Curl(test1ipURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
result, err = test2.Curl(test1fqdnURL)
|
||||||
|
assert.Empty(t, result)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
err = scenario.Shutdown()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestACLNamedHostsCanReachBySubnet is the same as
|
||||||
|
// TestACLNamedHostsCanReach, but it tests if we expand a
|
||||||
|
// full CIDR correctly. All routes should work.
|
||||||
|
func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
scenario := aclScenario(t,
|
||||||
|
headscale.ACLPolicy{
|
||||||
|
Hosts: headscale.Hosts{
|
||||||
|
"all": netip.MustParsePrefix("100.64.0.0/24"),
|
||||||
|
},
|
||||||
|
ACLs: []headscale.ACL{
|
||||||
|
// Everyone can curl test3
|
||||||
|
{
|
||||||
|
Action: "accept",
|
||||||
|
Sources: []string{"*"},
|
||||||
|
Destinations: []string{"all:*"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Test that user1 can visit all user2
|
||||||
|
for _, client := range user1Clients {
|
||||||
|
for _, peer := range user2Clients {
|
||||||
|
fqdn, err := peer.FQDN()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
|
result, err := client.Curl(url)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that user2 can visit all user1
|
||||||
|
for _, client := range user2Clients {
|
||||||
|
for _, peer := range user1Clients {
|
||||||
|
fqdn, err := peer.FQDN()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
|
result, err := client.Curl(url)
|
||||||
|
assert.Len(t, result, 13)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.Shutdown()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
17
machine.go
17
machine.go
|
@ -170,13 +170,14 @@ func (h *Headscale) filterMachinesByACL(currentMachine *Machine, peers Machines)
|
||||||
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
// filterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||||
func filterMachinesByACL(
|
func filterMachinesByACL(
|
||||||
machine *Machine,
|
machine *Machine,
|
||||||
machines []Machine,
|
machines Machines,
|
||||||
lock *sync.RWMutex,
|
lock *sync.RWMutex,
|
||||||
aclPeerCacheMap map[string]map[string]struct{},
|
aclPeerCacheMap map[string]map[string]struct{},
|
||||||
) Machines {
|
) Machines {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", machine.Hostname).
|
Str("self", machine.Hostname).
|
||||||
|
Str("input", machines.String()).
|
||||||
Msg("Finding peers filtered by ACLs")
|
Msg("Finding peers filtered by ACLs")
|
||||||
|
|
||||||
peers := make(map[uint64]Machine)
|
peers := make(map[uint64]Machine)
|
||||||
|
@ -263,7 +264,7 @@ func filterMachinesByACL(
|
||||||
|
|
||||||
lock.RUnlock()
|
lock.RUnlock()
|
||||||
|
|
||||||
authorizedPeers := make([]Machine, 0, len(peers))
|
authorizedPeers := make(Machines, 0, len(peers))
|
||||||
for _, m := range peers {
|
for _, m := range peers {
|
||||||
authorizedPeers = append(authorizedPeers, m)
|
authorizedPeers = append(authorizedPeers, m)
|
||||||
}
|
}
|
||||||
|
@ -274,8 +275,9 @@ func filterMachinesByACL(
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", machine.Hostname).
|
Str("self", machine.Hostname).
|
||||||
Msgf("Found some machines: %v", machines)
|
Str("peers", authorizedPeers.String()).
|
||||||
|
Msg("Authorized peers")
|
||||||
|
|
||||||
return authorizedPeers
|
return authorizedPeers
|
||||||
}
|
}
|
||||||
|
@ -335,8 +337,9 @@ func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("machine", machine.Hostname).
|
Str("self", machine.Hostname).
|
||||||
Msgf("Found total peers: %s", peers.String())
|
Str("peers", peers.String()).
|
||||||
|
Msg("Peers returned to caller")
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue