policy manager

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2024-10-25 11:52:12 -05:00
parent 53cbdfc277
commit 907449bd99
No known key found for this signature in database
9 changed files with 455 additions and 70 deletions

View file

@ -30,6 +30,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policyv2"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog" zerolog "github.com/philip-bui/grpc-zerolog"
@ -89,6 +90,7 @@ type Headscale struct {
DERPServer *derpServer.DERPServer DERPServer *derpServer.DERPServer
ACLPolicy *policy.ACLPolicy ACLPolicy *policy.ACLPolicy
PolicyManager *policyv2.PolicyManager
mapper *mapper.Mapper mapper *mapper.Mapper
nodeNotifier *notifier.Notifier nodeNotifier *notifier.Notifier

View file

@ -3,6 +3,7 @@ package policyv2
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx" "go4.org/netipx"
@ -16,6 +17,7 @@ var (
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) CompileFilterRules( func (pol *Policy) CompileFilterRules(
users types.Users,
nodes types.Nodes, nodes types.Nodes,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil { if pol == nil {
@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules(
return nil, ErrInvalidAction return nil, ErrInvalidAction
} }
srcIPs, err := acl.Sources.Resolve(pol, nodes) srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("resolving source ips: %w", err) return nil, fmt.Errorf("resolving source ips: %w", err)
} }
@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules(
var destPorts []tailcfg.NetPortRange var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations { for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, nodes) ips, err := dest.Alias.Resolve(pol, users, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules(
return rules, nil return rules, nil
} }
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
}
func (pol *Policy) CompileSSHPolicy(
users types.Users,
node types.Node,
nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
return nil, nil
}
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
dest.AddSet(ips)
}
destSet, err := dest.IPSet()
if err != nil {
return nil, err
}
if !node.InIPSet(destSet) {
continue
}
var action tailcfg.SSHAction
switch rule.Action {
case "accept":
action = sshAction(true, 0)
case "check":
action = sshAction(true, rule.CheckPeriod)
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
var principals []*tailcfg.SSHPrincipal
for _, src := range rule.Sources {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
rawSrc,
)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}
userMap := make(map[string]string, len(rule.Users))
for _, user := range rule.Users {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string { func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string var out []string

View file

@ -8,6 +8,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -17,6 +18,9 @@ import (
// Move it here, run it against both old and new CompileFilterRules // Move it here, run it against both old and new CompileFilterRules
func TestParsing(t *testing.T) { func TestParsing(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser@"},
}
tests := []struct { tests := []struct {
name string name string
format string format string
@ -340,7 +344,7 @@ func TestParsing(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
pol, err := PolicyFromBytes([]byte(tt.acl)) pol, err := policyFromBytes([]byte(tt.acl))
if tt.wantErr && err == nil { if tt.wantErr && err == nil {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
@ -355,15 +359,15 @@ func TestParsing(t *testing.T) {
return return
} }
rules, err := pol.CompileFilterRules(types.Nodes{ rules, err := pol.CompileFilterRules(
users,
types.Nodes{
&types.Node{ &types.Node{
IPv4: ap("100.100.100.100"), IPv4: ap("100.100.100.100"),
}, },
&types.Node{ &types.Node{
IPv4: ap("200.200.200.200"), IPv4: ap("200.200.200.200"),
User: types.User{ User: users[0],
Name: "testuser@",
},
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}) })
@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{
} }
func TestReduceFilterRules(t *testing.T) { func TestReduceFilterRules(t *testing.T) {
users := types.Users{
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
types.User{Model: gorm.Model{ID: 2}, Name: "user1@"},
types.User{Model: gorm.Model{ID: 3}, Name: "user2@"},
types.User{Model: gorm.Model{ID: 4}, Name: "user100@"},
types.User{Model: gorm.Model{ID: 5}, Name: "user3@"},
}
tests := []struct { tests := []struct {
name string name string
node *types.Node node *types.Node
@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
}, },
want: []tailcfg.FilterRule{}, want: []tailcfg.FilterRule{},
@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{ RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"), netip.MustParsePrefix("10.33.0.0/16"),
@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"}, User: users[2],
}, },
// "internal" exit node // "internal" exit node
&types.Node{ &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) {
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
{IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, // This should not be included I believe, seems like
{IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, // this is a bug in the v1 code.
{IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny}, // {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
// {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
// {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
}, },
@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
}, },
@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.2"), IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"), IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2@"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: ap("100.64.0.100"), IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"), IPv6: ap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100@"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
}, },
@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: ap("100.64.0.1"), IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"), IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1@"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) {
filterV1, _ := polV1.CompileFilterRules( filterV1, _ := polV1.CompileFilterRules(
append(tt.peers, tt.node), append(tt.peers, tt.node),
) )
polV2, err := PolicyFromBytes([]byte(tt.pol)) pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node))
if err != nil { if err != nil {
t.Fatalf("parsing policy: %s", err) t.Fatalf("parsing policy: %s", err)
} }
filterV2, _ := polV2.CompileFilterRules(
append(tt.peers, tt.node),
)
if diff := cmp.Diff(filterV1, filterV2); diff != "" { if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" {
log.Trace().Interface("got", filterV2).Msg("result") log.Trace().Interface("got", pm.Filter()).Msg("result")
t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff) t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff)
} }
// TODO(kradalby): Move this from v1, or // TODO(kradalby): Move this from v1, or
// rewrite. // rewrite.
filterV2 = policy.ReduceFilterRules(tt.node, filterV2) filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter())
if diff := cmp.Diff(tt.want, filterV2); diff != "" { if diff := cmp.Diff(tt.want, filterV2); diff != "" {
log.Trace().Interface("got", filterV2).Msg("result") log.Trace().Interface("got", filterV2).Msg("result")

View file

@ -0,0 +1,80 @@
package policyv2
import (
"fmt"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)
type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes
filter []tailcfg.FilterRule
// TODO(kradalby): Implement SSH policy
sshPolicy *tailcfg.SSHPolicy
}
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
}
err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() error {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return fmt.Errorf("compiling filter rules: %w", err)
}
pm.filter = filter
return nil
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}

View file

@ -0,0 +1,58 @@
package policyv2
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
Hostinfo: hostinfo,
}
}
func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}
tests := []struct {
name string
pol string
nodes types.Nodes
wantFilter []tailcfg.FilterRule
}{
{
name: "empty-policy",
pol: "{}",
nodes: types.Nodes{},
wantFilter: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err)
filter := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
}
// TODO(kradalby): Test SSH Policy
})
}
}

View file

@ -8,6 +8,7 @@ import (
"net/netip" "net/netip"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
ips.AddPrefix(tsaddr.AllIPv4()) ips.AddPrefix(tsaddr.AllIPv4())
@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool {
return true return true
} }
func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (u Username) resolveUser(users types.Users) (*types.User, error) {
var potentialUsers types.Users
for _, user := range users {
if user.ProviderIdentifier == string(u) {
potentialUsers = append(potentialUsers, user)
break
}
if user.Email == string(u) {
potentialUsers = append(potentialUsers, user)
}
if user.Name == string(u) {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) > 1 {
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
} else if len(potentialUsers) == 0 {
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
}
user := potentialUsers[0]
return &user, nil
}
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
user, err := u.resolveUser(users)
if err != nil {
return nil, err
}
for _, node := range nodes { for _, node := range nodes {
if node.IsTagged() { if node.IsTagged() {
continue continue
} }
if node.User.Username() == string(u) { if node.User.ID == user.ID {
node.AppendToIPSet(&ips) node.AppendToIPSet(&ips)
} }
} }
@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool {
return true return true
} }
func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, user := range p.Groups[g] { for _, user := range p.Groups[g] {
uips, err := user.Resolve(nil, nodes) uips, err := user.Resolve(nil, users, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, node := range nodes { for _, node := range nodes {
@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
pref, ok := p.Hosts[h] pref, ok := p.Hosts[h]
@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// If the IP is a single host, look for a node to ensure we add all the IPs of
// the node to the IPSet.
appendIfNodeHasIP(nodes, &ips, pref)
ips.AddPrefix(netip.Prefix(pref)) ips.AddPrefix(netip.Prefix(pref))
return ips.IPSet() return ips.IPSet()
} }
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
if netip.Prefix(pref).IsSingleIP() {
addr := netip.Prefix(pref).Addr()
for _, node := range nodes {
if node.HasIP(addr) {
node.AppendToIPSet(ips)
}
}
}
}
type Prefix netip.Prefix type Prefix netip.Prefix
func (p Prefix) Validate() error { func (p Prefix) Validate() error {
@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
appendIfNodeHasIP(nodes, &ips, p)
ips.AddPrefix(netip.Prefix(p)) ips.AddPrefix(netip.Prefix(p))
return ips.IPSet() return ips.IPSet()
@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) { func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
switch ag { switch ag {
case AutoGroupInternet: case AutoGroupInternet:
return theInternet(), nil return theInternet(), nil
@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
type Alias interface { type Alias interface {
Validate() error Validate() error
UnmarshalJSON([]byte) error UnmarshalJSON([]byte) error
Resolve(*Policy, types.Nodes) (*netipx.IPSet, error) Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
} }
type AliasWithPorts struct { type AliasWithPorts struct {
@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) { func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, alias := range a { for _, alias := range a {
aips, err := alias.Resolve(p, nodes) aips, err := alias.Resolve(p, users, nodes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -530,10 +579,67 @@ type Policy struct {
TagOwners TagOwners `json:"tagOwners"` TagOwners TagOwners `json:"tagOwners"`
ACLs []ACL `json:"acls"` ACLs []ACL `json:"acls"`
AutoApprovers AutoApprovers `json:"autoApprovers"` AutoApprovers AutoApprovers `json:"autoApprovers"`
// SSHs []SSH `json:"ssh"` SSHs []SSH `json:"ssh"`
} }
func PolicyFromBytes(b []byte) (*Policy, error) { // SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users []SSHUser `json:"users"`
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
}
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
// It can be a list of usernames, groups, tags or autogroups.
type SSHSrcAliases []Alias
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Group, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
// It can be a list of usernames, tags or autogroups.
type SSHDstAliases []Alias
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
type SSHUser string
func policyFromBytes(b []byte) (*Policy, error) {
var policy Policy var policy Policy
ast, err := hujson.Parse(b) ast, err := hujson.Parse(b)
if err != nil { if err != nil {

View file

@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: ptr.To(Username("otheruser@headscale.net")), Alias: ptr.To(Username("otheruser@headscale.net")),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
}, },
}, },
}, },
@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: gp("group:other"), Alias: gp("group:other"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
}, },
}, },
}, },
@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: pp("100.101.102.104/32"), Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
}, },
}, },
}, },
@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: pp("172.16.0.0/16"), Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}}, Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
}, },
}, },
}, },
@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) {
Destinations: []AliasWithPorts{ Destinations: []AliasWithPorts{
{ {
Alias: hp("host-1"), Alias: hp("host-1"),
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}}, Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
}, },
}, },
}, },
@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) {
{ {
Alias: tp("tag:user"), Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{ Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80}, {First: 80, Last: 80},
tailcfg.PortRange{First: 443, Last: 443}, {First: 443, Last: 443},
}, },
}, },
}, },
@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) {
{ {
Alias: agp("autogroup:internet"), Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{ Ports: []tailcfg.PortRange{
tailcfg.PortRange{First: 80, Last: 80}, {First: 80, Last: 80},
}, },
}, },
}, },
@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
policy, err := PolicyFromBytes([]byte(tt.input)) policy, err := policyFromBytes([]byte(tt.input))
// TODO(kradalby): This error checking is broken, // TODO(kradalby): This error checking is broken,
// but so is my brain, #longflight // but so is my brain, #longflight
if err == nil { if err == nil {
@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes) ips, err := tt.toResolve.Resolve(tt.pol,
types.Users{},
tt.nodes)
if err != nil { if err != nil {
t.Fatalf("failed to resolve: %s", err) t.Fatalf("failed to resolve: %s", err)
} }

View file

@ -135,6 +135,16 @@ func (node *Node) IPs() []netip.Addr {
return ret return ret
} }
// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
}
// IsTagged reports if a device is tagged // IsTagged reports if a device is tagged
// and therefore should not be treated as a // and therefore should not be treated as a
// user owned device. // user owned device.

View file

@ -2,7 +2,9 @@ package types
import ( import (
"cmp" "cmp"
"fmt"
"strconv" "strconv"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -13,6 +15,19 @@ import (
type UserID uint64 type UserID uint64
type Users []User
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")
return sb.String()
}
// User is the way Headscale implements the concept of users in Tailscale // User is the way Headscale implements the concept of users in Tailscale
// //
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users // At the end of the day, users in Tailscale are some kind of 'bubbles' or users