This commit is contained in:
Kristoffer Dalby 2024-11-17 13:56:48 -07:00 committed by GitHub
commit 4729f61480
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 3191 additions and 211 deletions

View file

@ -30,6 +30,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policyv2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
@ -89,6 +90,7 @@ type Headscale struct {
DERPServer *derpServer.DERPServer
ACLPolicy *policy.ACLPolicy
PolicyManager *policyv2.PolicyManager
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier

View file

@ -1764,9 +1764,9 @@ var tsExitNodeDest = []tailcfg.NetPortRange{
},
}
// hsExitNodeDest is the list of destination IP ranges that are allowed when
// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet".
var hsExitNodeDest = []tailcfg.NetPortRange{
var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
@ -1823,13 +1823,13 @@ func TestTheInternet(t *testing.T) {
internetPrefs := internetSet.Prefixes()
for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDest[i].IP {
t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDest[i].IP)
if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDestForTest[i].IP)
}
}
if len(internetPrefs) != len(hsExitNodeDest) {
t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDest))
if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDestForTest))
}
}
@ -1838,20 +1838,27 @@ func TestReduceFilterRules(t *testing.T) {
name string
node *types.Node
peers types.Nodes
pol ACLPolicy
pol string
want []tailcfg.FilterRule
}{
{
name: "host1-can-reach-host2-no-rules",
pol: ACLPolicy{
ACLs: []ACL{
pol: `
{
"acls": [
{
Action: "accept",
Sources: []string{"100.64.0.1"},
Destinations: []string{"100.64.0.2:*"},
},
},
},
"action": "accept",
"proto": "",
"src": [
"100.64.0.1"
],
"dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
@ -1868,23 +1875,37 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1604-subnet-routers-are-preserved",
pol: ACLPolicy{
Groups: Groups{
"group:admins": {"user1"},
pol: `
{
"groups": {
"group:admins": [
"user1"
]
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"group:admins:*"},
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
},
{
Action: "accept",
Sources: []string{"group:admins"},
Destinations: []string{"10.33.0.0/16:*"},
},
},
},
"action": "accept",
"proto": "",
"src": [
"group:admins"
],
"dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
@ -1939,31 +1960,42 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1786-reducing-breaks-exit-nodes-the-client",
pol: ACLPolicy{
Hosts: Hosts{
// Exit node
"internal": netip.MustParsePrefix("100.64.0.100/32"),
pol: `
{
"groups": {
"group:team": [
"user3",
"user2",
"user1"
]
},
Groups: Groups{
"group:team": {"user3", "user2", "user1"},
"hosts": {
"internal": "100.64.0.100/32"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"internal:*",
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"autogroup:internet:*",
},
},
},
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
@ -1989,31 +2021,42 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: ACLPolicy{
Hosts: Hosts{
// Exit node
"internal": netip.MustParsePrefix("100.64.0.100/32"),
pol: `
{
"groups": {
"group:team": [
"user3",
"user2",
"user1"
]
},
Groups: Groups{
"group:team": {"user3", "user2", "user1"},
"hosts": {
"internal": "100.64.0.100/32"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"internal:*",
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"autogroup:internet:*",
},
},
},
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"autogroup:internet:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
@ -2050,32 +2093,42 @@ func TestReduceFilterRules(t *testing.T) {
},
{
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDest,
DstPorts: hsExitNodeDestForTest,
},
},
},
{
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: ACLPolicy{
Hosts: Hosts{
// Exit node
"internal": netip.MustParsePrefix("100.64.0.100/32"),
pol: `
{
"groups": {
"group:team": [
"user3",
"user2",
"user1"
]
},
Groups: Groups{
"group:team": {"user3", "user2", "user1"},
"hosts": {
"internal": "100.64.0.100/32"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"internal:*",
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*",
"8.0.0.0/7:*",
"11.0.0.0/8:*",
@ -2105,11 +2158,12 @@ func TestReduceFilterRules(t *testing.T) {
"194.0.0.0/7:*",
"196.0.0.0/6:*",
"200.0.0.0/5:*",
"208.0.0.0/4:*",
},
},
},
},
"208.0.0.0/4:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
@ -2186,32 +2240,43 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: ACLPolicy{
Hosts: Hosts{
// Exit node
"internal": netip.MustParsePrefix("100.64.0.100/32"),
pol: `
{
"groups": {
"group:team": [
"user3",
"user2",
"user1"
]
},
Groups: Groups{
"group:team": {"user3", "user2", "user1"},
"hosts": {
"internal": "100.64.0.100/32"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"internal:*",
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*",
"16.0.0.0/8:*",
},
},
},
},
"16.0.0.0/8:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
@ -2263,32 +2328,43 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: ACLPolicy{
Hosts: Hosts{
// Exit node
"internal": netip.MustParsePrefix("100.64.0.100/32"),
pol: `
{
"groups": {
"group:team": [
"user3",
"user2",
"user1"
]
},
Groups: Groups{
"group:team": {"user3", "user2", "user1"},
"hosts": {
"internal": "100.64.0.100/32"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"internal:*",
},
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"internal:*"
]
},
{
Action: "accept",
Sources: []string{"group:team"},
Destinations: []string{
"action": "accept",
"proto": "",
"src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*",
"16.0.0.0/16:*",
},
},
},
},
"16.0.0.0/16:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
@ -2340,25 +2416,32 @@ func TestReduceFilterRules(t *testing.T) {
},
{
name: "1817-reduce-breaks-32-mask",
pol: ACLPolicy{
Hosts: Hosts{
"vlan1": netip.MustParsePrefix("172.16.0.0/24"),
"dns1": netip.MustParsePrefix("172.16.0.21/32"),
pol: `
{
"groups": {
"group:access": [
"user1"
]
},
Groups: Groups{
"group:access": {"user1"},
"hosts": {
"dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
},
ACLs: []ACL{
"acls": [
{
Action: "accept",
Sources: []string{"group:access"},
Destinations: []string{
"action": "accept",
"proto": "",
"src": [
"group:access"
],
"dst": [
"tag:access-servers:*",
"dns1:*",
},
},
},
},
"dns1:*"
]
}
],
}
`,
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
@ -2399,7 +2482,11 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules(
pol, err := LoadACLPolicyFromBytes([]byte(tt.pol))
if err != nil {
t.Fatalf("parsing policy: %s", err)
}
got, _ := pol.CompileFilterRules(
append(tt.peers, tt.node),
)

View file

@ -0,0 +1,180 @@
package policyv2
import (
"errors"
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) CompileFilterRules(
users types.Users,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
return tailcfg.FilterAllowAll, nil
}
var rules []tailcfg.FilterRule
for _, acl := range pol.ACLs {
if acl.Action != "accept" {
return nil, ErrInvalidAction
}
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
return nil, fmt.Errorf("resolving source ips: %w", err)
}
// TODO(kradalby): integrate type into schema
// TODO(kradalby): figure out the _ is wildcard stuff
protocols, _, err := parseProtocol(acl.Protocol)
if err != nil {
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
}
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
for _, pref := range ips.Prefixes() {
for _, port := range dest.Ports {
pr := tailcfg.NetPortRange{
IP: pref.String(),
Ports: port,
}
destPorts = append(destPorts, pr)
}
}
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: ipSetToPrefixStringList(srcIPs),
DstPorts: destPorts,
IPProto: protocols,
})
}
return rules, nil
}
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
}
func (pol *Policy) CompileSSHPolicy(
users types.Users,
node types.Node,
nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
return nil, nil
}
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
dest.AddSet(ips)
}
destSet, err := dest.IPSet()
if err != nil {
return nil, err
}
if !node.InIPSet(destSet) {
continue
}
var action tailcfg.SSHAction
switch rule.Action {
case "accept":
action = sshAction(true, 0)
case "check":
action = sshAction(true, rule.CheckPeriod)
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
var principals []*tailcfg.SSHPrincipal
for _, src := range rule.Sources {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
rawSrc,
)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}
userMap := make(map[string]string, len(rule.Users))
for _, user := range rule.Users {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
package policyv2
import (
"fmt"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)
type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes
filter []tailcfg.FilterRule
// TODO(kradalby): Implement SSH policy
sshPolicy *tailcfg.SSHPolicy
}
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
}
err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() error {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return fmt.Errorf("compiling filter rules: %w", err)
}
pm.filter = filter
return nil
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}

View file

@ -0,0 +1,58 @@
package policyv2
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
Hostinfo: hostinfo,
}
}
func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}
tests := []struct {
name string
pol string
nodes types.Nodes
wantFilter []tailcfg.FilterRule
}{
{
name: "empty-policy",
pol: "{}",
nodes: types.Nodes{},
wantFilter: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err)
filter := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
}
// TODO(kradalby): Test SSH Policy
})
}
}

821
hscontrol/policyv2/types.go Normal file
View file

@ -0,0 +1,821 @@
package policyv2
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/tailscale/hujson"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var theInternetSet *netipx.IPSet
// theInternet returns the IPSet for the Internet.
// https://www.youtube.com/watch?v=iDbyYGrswtg
func theInternet() *netipx.IPSet {
if theInternetSet != nil {
return theInternetSet
}
var internetBuilder netipx.IPSetBuilder
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
internetBuilder.AddPrefix(tsaddr.AllIPv4())
// Delete Private network addresses
// https://datatracker.ietf.org/doc/html/rfc1918
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
// Delete Tailscale networks
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
// Delete "cant find DHCP networks"
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-loca
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
}
type Asterix int
func (a Asterix) Validate() error {
return nil
}
func (a Asterix) String() string {
return "*"
}
func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
ips.AddPrefix(tsaddr.AllIPv4())
ips.AddPrefix(tsaddr.AllIPv6())
return ips.IPSet()
}
// Username is a string that represents a username, it must contain an @.
type Username string
func (u Username) Validate() error {
if strings.Contains(string(u), "@") {
return nil
}
return fmt.Errorf("Username has to contain @, got: %q", u)
}
func (u *Username) String() string {
return string(*u)
}
func (u *Username) UnmarshalJSON(b []byte) error {
*u = Username(strings.Trim(string(b), `"`))
if err := u.Validate(); err != nil {
return err
}
return nil
}
func (u Username) CanBeTagOwner() bool {
return true
}
func (u Username) resolveUser(users types.Users) (*types.User, error) {
var potentialUsers types.Users
for _, user := range users {
if user.ProviderIdentifier == string(u) {
potentialUsers = append(potentialUsers, user)
break
}
if user.Email == string(u) {
potentialUsers = append(potentialUsers, user)
}
if user.Name == string(u) {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) > 1 {
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
} else if len(potentialUsers) == 0 {
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
}
user := potentialUsers[0]
return &user, nil
}
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
user, err := u.resolveUser(users)
if err != nil {
return nil, err
}
for _, node := range nodes {
if node.IsTagged() {
continue
}
if node.User.ID == user.ID {
node.AppendToIPSet(&ips)
}
}
return ips.IPSet()
}
// Group is a special string which is always prefixed with `group:`
type Group string
func (g Group) Validate() error {
if strings.HasPrefix(string(g), "group:") {
return nil
}
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
}
func (g *Group) UnmarshalJSON(b []byte) error {
*g = Group(strings.Trim(string(b), `"`))
if err := g.Validate(); err != nil {
return err
}
return nil
}
func (g Group) CanBeTagOwner() bool {
return true
}
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, user := range p.Groups[g] {
uips, err := user.Resolve(nil, users, nodes)
if err != nil {
return nil, err
}
ips.AddSet(uips)
}
return ips.IPSet()
}
// Tag is a special string which is always prefixed with `tag:`
type Tag string
func (t Tag) Validate() error {
if strings.HasPrefix(string(t), "tag:") {
return nil
}
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
}
func (t *Tag) UnmarshalJSON(b []byte) error {
*t = Tag(strings.Trim(string(b), `"`))
if err := t.Validate(); err != nil {
return err
}
return nil
}
func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, node := range nodes {
if node.HasTag(string(t)) {
node.AppendToIPSet(&ips)
}
}
return ips.IPSet()
}
// Host is a string that represents a hostname.
type Host string
func (h Host) Validate() error {
return nil
}
func (h *Host) UnmarshalJSON(b []byte) error {
*h = Host(strings.Trim(string(b), `"`))
if err := h.Validate(); err != nil {
return err
}
return nil
}
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
pref, ok := p.Hosts[h]
if !ok {
return nil, fmt.Errorf("unable to resolve host: %q", h)
}
err := pref.Validate()
if err != nil {
return nil, err
}
// If the IP is a single host, look for a node to ensure we add all the IPs of
// the node to the IPSet.
appendIfNodeHasIP(nodes, &ips, pref)
ips.AddPrefix(netip.Prefix(pref))
return ips.IPSet()
}
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
if netip.Prefix(pref).IsSingleIP() {
addr := netip.Prefix(pref).Addr()
for _, node := range nodes {
if node.HasIP(addr) {
node.AppendToIPSet(ips)
}
}
}
}
type Prefix netip.Prefix
func (p Prefix) Validate() error {
if !netip.Prefix(p).IsValid() {
return fmt.Errorf("Prefix %q is invalid", p)
}
return nil
}
func (p Prefix) String() string {
return netip.Prefix(p).String()
}
func (p *Prefix) parseString(addr string) error {
if !strings.Contains(addr, "/") {
addr, err := netip.ParseAddr(addr)
if err != nil {
return err
}
addrPref, err := addr.Prefix(addr.BitLen())
if err != nil {
return err
}
*p = Prefix(addrPref)
return nil
}
pref, err := netip.ParsePrefix(addr)
if err != nil {
return err
}
*p = Prefix(pref)
return nil
}
func (p *Prefix) UnmarshalJSON(b []byte) error {
err := p.parseString(strings.Trim(string(b), `"`))
if err != nil {
return err
}
if err := p.Validate(); err != nil {
return err
}
return nil
}
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
appendIfNodeHasIP(nodes, &ips, p)
ips.AddPrefix(netip.Prefix(p))
return ips.IPSet()
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
type AutoGroup string
const (
AutoGroupInternet = "autogroup:internet"
)
var autogroups = []string{AutoGroupInternet}
func (ag AutoGroup) Validate() error {
for _, valid := range autogroups {
if valid == string(ag) {
return nil
}
}
return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
}
func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
*ag = AutoGroup(strings.Trim(string(b), `"`))
if err := ag.Validate(); err != nil {
return err
}
return nil
}
func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
switch ag {
case AutoGroupInternet:
return theInternet(), nil
}
return nil, nil
}
type Alias interface {
Validate() error
UnmarshalJSON([]byte) error
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
}
type AliasWithPorts struct {
Alias
Ports []tailcfg.PortRange
}
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch vs := v.(type) {
case string:
var portsPart string
var err error
if strings.Contains(vs, ":") {
vs, portsPart, err = splitDestination(vs)
if err != nil {
return err
}
ports, err := parsePorts(portsPart)
if err != nil {
return err
}
ve.Ports = ports
}
ve.Alias = parseAlias(vs)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", vs)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
func parseAlias(vs string) Alias {
// case netip.Addr:
// ve.Alias = Addr(val)
// case netip.Prefix:
// ve.Alias = Prefix(val)
var pref Prefix
err := pref.parseString(vs)
if err == nil {
return &pref
}
switch {
case vs == "*":
return Asterix(0)
case strings.Contains(vs, "@"):
return ptr.To(Username(vs))
case strings.HasPrefix(vs, "group:"):
return ptr.To(Group(vs))
case strings.HasPrefix(vs, "tag:"):
return ptr.To(Tag(vs))
case strings.HasPrefix(vs, "autogroup:"):
return ptr.To(AutoGroup(vs))
}
if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") {
return ptr.To(Host(vs))
}
return nil
}
// AliasEnc is used to deserialize a Alias.
type AliasEnc struct{ Alias }
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch val := v.(type) {
case string:
ve.Alias = parseAlias(val)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", val)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", val)
}
return nil
}
type Aliases []Alias
func (a *Aliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, alias := range a {
aips, err := alias.Resolve(p, users, nodes)
if err != nil {
return nil, err
}
ips.AddSet(aips)
}
return ips.IPSet()
}
type Owner interface {
CanBeTagOwner() bool
UnmarshalJSON([]byte) error
}
// OwnerEnc is used to deserialize a Owner.
type OwnerEnc struct{ Owner }
func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch val := v.(type) {
case string:
switch {
case strings.Contains(val, "@"):
ve.Owner = ptr.To(Username(val))
case strings.HasPrefix(val, "group:"):
ve.Owner = ptr.To(Group(val))
}
default:
return fmt.Errorf("type %T not supported", val)
}
return nil
}
type Owners []Owner
func (o *Owners) UnmarshalJSON(b []byte) error {
var owners []OwnerEnc
err := json.Unmarshal(b, &owners)
if err != nil {
return err
}
*o = make([]Owner, len(owners))
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
type Usernames []Username
// Groups are a map of Group to a list of Username.
type Groups map[Group]Usernames
// Hosts are alias for IP addresses or subnets.
type Hosts map[Host]Prefix
// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
type TagOwners map[Tag]Owners
type AutoApprovers struct {
Routes map[string][]string `json:"routes"`
ExitNode []string `json:"exitNode"`
}
type ACL struct {
Action string `json:"action"`
Protocol string `json:"proto"`
Sources Aliases `json:"src"`
Destinations []AliasWithPorts `json:"dst"`
}
// Policy represents a Tailscale Network Policy.
// TODO(kradalby):
// Add validation method checking:
// All users exists
// All groups and users are valid tag TagOwners
// Everything referred to in ACLs exists in other
// entities.
type Policy struct {
// validated is set if the policy has been validated.
// It is not safe to use before it is validated, and
// callers using it should panic if not
validated bool `json:"-"`
Groups Groups `json:"groups"`
Hosts Hosts `json:"hosts"`
TagOwners TagOwners `json:"tagOwners"`
ACLs []ACL `json:"acls"`
AutoApprovers AutoApprovers `json:"autoApprovers"`
SSHs []SSH `json:"ssh"`
}
// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users []SSHUser `json:"users"`
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
}
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
// It can be a list of usernames, groups, tags or autogroups.
type SSHSrcAliases []Alias
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Group, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
// It can be a list of usernames, tags or autogroups.
type SSHDstAliases []Alias
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
type SSHUser string
func policyFromBytes(b []byte) (*Policy, error) {
var policy Policy
ast, err := hujson.Parse(b)
if err != nil {
return nil, fmt.Errorf("parsing HuJSON: %w", err)
}
ast.Standardize()
acl := ast.Pack()
err = json.Unmarshal(acl, &policy)
if err != nil {
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
}
return &policy, nil
}
const (
expectedTokenItems = 2
)
// TODO(kradalby): copy tests from parseDestination in policy
func splitDestination(dest string) (string, string, error) {
var tokens []string
// Check if there is a IPv4/6:Port combination, IPv6 has more than
// three ":".
tokens = strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
port := tokens[len(tokens)-1]
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
filteredMaybeIPv6Str := maybeIPv6Str
if strings.Contains(maybeIPv6Str, "/") {
networkParts := strings.Split(maybeIPv6Str, "/")
filteredMaybeIPv6Str = networkParts[0]
}
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
return "", "", fmt.Errorf(
"failed to split destination: %v",
tokens,
)
} else {
tokens = []string{maybeIPv6Str, port}
}
}
var alias string
// We can have here stuff like:
// git-server:*
// 192.168.1.0/24:22
// fd7a:115c:a1e0::2:22
// fd7a:115c:a1e0::2/128:22
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
return alias, tokens[len(tokens)-1], nil
}
// TODO(kradalby): write/copy tests from expandPorts in policy
func parsePorts(portsStr string) ([]tailcfg.PortRange, error) {
if portsStr == "*" {
return []tailcfg.PortRange{
tailcfg.PortRangeAny,
}, nil
}
var ports []tailcfg.PortRange
for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(port),
Last: uint16(port),
})
case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
default:
return nil, errors.New("invalid ports")
}
}
return ports, nil
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
}
// TODO(kradalby): What is this?
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}

View file

@ -0,0 +1,555 @@
package policyv2
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
func TestUnmarshalPolicy(t *testing.T) {
tests := []struct {
name string
input string
want *Policy
wantErr string
}{
{
name: "empty",
input: "{}",
want: &Policy{},
},
{
name: "groups",
input: `
{
"groups": {
"group:example": [
"derp@headscale.net",
],
},
}
`,
want: &Policy{
Groups: Groups{
Group("group:example"): []Username{Username("derp@headscale.net")},
},
},
},
{
name: "basic-types",
input: `
{
"groups": {
"group:example": [
"testuser@headscale.net",
],
"group:other": [
"otheruser@headscale.net",
],
},
"tagOwners": {
"tag:user": ["testuser@headscale.net"],
"tag:group": ["group:other"],
"tag:userandgroup": ["testuser@headscale.net", "group:other"],
},
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
"outside": "192.168.0.0/16",
},
"acls": [
// All
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"],
},
// Users
{
"action": "accept",
"proto": "tcp",
"src": ["testuser@headscale.net"],
"dst": ["otheruser@headscale.net:80"],
},
// Groups
{
"action": "accept",
"proto": "tcp",
"src": ["group:example"],
"dst": ["group:other:80"],
},
// Tailscale IP
{
"action": "accept",
"proto": "tcp",
"src": ["100.101.102.103"],
"dst": ["100.101.102.104:80"],
},
// Subnet
{
"action": "accept",
"proto": "udp",
"src": ["10.0.0.0/8"],
"dst": ["172.16.0.0/16:80"],
},
// Hosts
{
"action": "accept",
"proto": "tcp",
"src": ["subnet-1"],
"dst": ["host-1:80-88"],
},
// Tags
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["tag:user:80,443"],
},
// Autogroup
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["autogroup:internet:80"],
},
],
}
`,
want: &Policy{
Groups: Groups{
Group("group:example"): []Username{Username("testuser@headscale.net")},
Group("group:other"): []Username{Username("otheruser@headscale.net")},
},
TagOwners: TagOwners{
Tag("tag:user"): Owners{up("testuser@headscale.net")},
Tag("tag:group"): Owners{gp("group:other")},
Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")},
},
Hosts: Hosts{
"host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")),
"subnet-1": Prefix(netip.MustParsePrefix("100.100.101.100/24")),
"outside": Prefix(netip.MustParsePrefix("192.168.0.0/16")),
},
ACLs: []ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
// TODO(kradalby): Should this be host?
// It is:
// All traffic originating from Tailscale devices in your tailnet,
// any approved subnets and autogroup:shared.
// It does not allow traffic originating from
// non-tailscale devices (unless it is an approved route).
hp("*"),
},
Destinations: []AliasWithPorts{
{
// TODO(kradalby): Should this be host?
// It is:
// Includes any destination (no restrictions).
Alias: hp("*"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
ptr.To(Username("testuser@headscale.net")),
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Username("otheruser@headscale.net")),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
gp("group:example"),
},
Destinations: []AliasWithPorts{
{
Alias: gp("group:other"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
pp("100.101.102.103/32"),
},
Destinations: []AliasWithPorts{
{
Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "udp",
Sources: Aliases{
pp("10.0.0.0/8"),
},
Destinations: []AliasWithPorts{
{
Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
hp("subnet-1"),
},
Destinations: []AliasWithPorts{
{
Alias: hp("host-1"),
Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{
{First: 80, Last: 80},
{First: 443, Last: 443},
},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{
{First: 80, Last: 80},
},
},
},
},
},
},
},
{
name: "invalid-username",
input: `
{
"groups": {
"group:example": [
"valid@",
"invalid",
],
},
}
`,
wantErr: `Username has to contain @, got: "invalid"`,
},
{
name: "invalid-group",
input: `
{
"groups": {
"grou:example": [
"valid@",
],
},
}
`,
wantErr: `Group has to start with "group:", got: "grou:example"`,
},
{
name: "group-in-group",
input: `
{
"groups": {
"group:inner": [],
"group:example": [
"group:inner",
],
},
}
`,
wantErr: `Username has to contain @, got: "group:inner"`,
},
{
name: "invalid-prefix",
input: `
{
"hosts": {
"derp": "10.0",
},
}
`,
wantErr: `ParseAddr("10.0"): IPv4 address too short`,
},
{
name: "invalid-auto-group",
input: `
{
"acls": [
// Autogroup
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["autogroup:invalid:80"],
},
],
}
`,
wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`,
},
}
cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
return x == y
}))
cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{}))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := policyFromBytes([]byte(tt.input))
// TODO(kradalby): This error checking is broken,
// but so is my brain, #longflight
if err == nil {
if tt.wantErr == "" {
return
}
t.Fatalf("got success; wanted error %q", tt.wantErr)
}
if err.Error() != tt.wantErr {
t.Fatalf("got error %q; want %q", err, tt.wantErr)
// } else if err.Error() == tt.wantErr {
// return
}
if err != nil {
t.Fatalf("unexpected err: %q", err)
}
if diff := cmp.Diff(tt.want, &policy, cmps...); diff != "" {
t.Fatalf("unexpected policy (-want +got):\n%s", diff)
}
})
}
}
func gp(s string) *Group { return ptr.To(Group(s)) }
func up(s string) *Username { return ptr.To(Username(s)) }
func hp(s string) *Host { return ptr.To(Host(s)) }
func tp(s string) *Tag { return ptr.To(Tag(s)) }
func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) }
func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) }
func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) }
func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) }
func p(pref string) Prefix { return Prefix(netip.MustParsePrefix(pref)) }
func TestResolvePolicy(t *testing.T) {
tests := []struct {
name string
nodes types.Nodes
pol *Policy
toResolve Alias
want []netip.Prefix
}{
{
name: "prefix",
toResolve: pp("100.100.101.101/32"),
want: []netip.Prefix{mp("100.100.101.101/32")},
},
{
name: "host",
pol: &Policy{
Hosts: Hosts{
"testhost": p("100.100.101.102/32"),
},
},
toResolve: hp("testhost"),
want: []netip.Prefix{mp("100.100.101.102/32")},
},
{
name: "username",
toResolve: ptr.To(Username("testuser")),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notme",
},
IPv4: ap("100.100.101.1"),
},
// Not matching forced tags
{
User: types.User{
Name: "testuser",
},
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
},
// not matchin pak tag
{
User: types.User{
Name: "testuser",
},
AuthKey: &types.PreAuthKey{
Tags: []string{"alsotagged"},
},
IPv4: ap("100.100.101.3"),
},
{
User: types.User{
Name: "testuser",
},
IPv4: ap("100.100.101.103"),
},
{
User: types.User{
Name: "testuser",
},
IPv4: ap("100.100.101.104"),
},
},
want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")},
},
{
name: "group",
toResolve: ptr.To(Group("group:testgroup")),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notmetoo",
},
IPv4: ap("100.100.101.4"),
},
// Not matching forced tags
{
User: types.User{
Name: "groupuser",
},
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
},
// not matchin pak tag
{
User: types.User{
Name: "groupuser",
},
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
IPv4: ap("100.100.101.6"),
},
{
User: types.User{
Name: "groupuser",
},
IPv4: ap("100.100.101.203"),
},
{
User: types.User{
Name: "groupuser",
},
IPv4: ap("100.100.101.204"),
},
},
pol: &Policy{
Groups: Groups{
"group:testgroup": Usernames{"groupuser"},
"group:othergroup": Usernames{"notmetoo"},
},
},
want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")},
},
{
name: "tag",
toResolve: tp("tag:test"),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notmetoo",
},
IPv4: ap("100.100.101.9"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
},
// not matchin pak tag
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
IPv4: ap("100.100.101.11"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
// not matchin pak tag
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:test"},
},
IPv4: ap("100.100.101.239"),
},
},
// TODO(kradalby): tests handling TagOwners + hostinfo
pol: &Policy{},
want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol,
types.Users{},
tt.nodes)
if err != nil {
t.Fatalf("failed to resolve: %s", err)
}
prefs := ips.Prefixes()
if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" {
t.Fatalf("unexpected prefs (-want +got):\n%s", diff)
}
})
}
}

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"time"
@ -134,6 +135,60 @@ func (node *Node) IPs() []netip.Addr {
return ret
}
// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
}
if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
return true
}
if node.Hostinfo == nil {
return false
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
}
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) HasTag(tag string) bool {
if slices.Contains(node.ForcedTags, tag) {
return true
}
if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) {
return true
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
}
func (node *Node) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{}
for _, nodeAddress := range node.IPs() {

View file

@ -2,7 +2,9 @@ package types
import (
"cmp"
"fmt"
"strconv"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
@ -13,6 +15,19 @@ import (
type UserID uint64
type Users []User
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")
return sb.String()
}
// User is the way Headscale implements the concept of users in Tailscale
//
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users