mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
policy manager
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
9375836569
commit
2e68455331
9 changed files with 455 additions and 70 deletions
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policyv2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
|
@ -89,6 +90,7 @@ type Headscale struct {
|
|||
DERPServer *derpServer.DERPServer
|
||||
|
||||
ACLPolicy *policy.ACLPolicy
|
||||
PolicyManager *policyv2.PolicyManager
|
||||
|
||||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
|
|
|
@ -3,6 +3,7 @@ package policyv2
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
|
@ -16,6 +17,7 @@ var (
|
|||
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) CompileFilterRules(
|
||||
users types.Users,
|
||||
nodes types.Nodes,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
if pol == nil {
|
||||
|
@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules(
|
|||
return nil, ErrInvalidAction
|
||||
}
|
||||
|
||||
srcIPs, err := acl.Sources.Resolve(pol, nodes)
|
||||
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving source ips: %w", err)
|
||||
}
|
||||
|
@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules(
|
|||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Alias.Resolve(pol, nodes)
|
||||
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules(
|
|||
return rules, nil
|
||||
}
|
||||
|
||||
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
return tailcfg.SSHAction{
|
||||
Reject: !accept,
|
||||
Accept: accept,
|
||||
SessionDuration: duration,
|
||||
AllowAgentForwarding: true,
|
||||
AllowLocalPortForwarding: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (pol *Policy) CompileSSHPolicy(
|
||||
users types.Users,
|
||||
node types.Node,
|
||||
nodes types.Nodes,
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
var dest netipx.IPSetBuilder
|
||||
for _, src := range rule.Destinations {
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
||||
destSet, err := dest.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !node.InIPSet(destSet) {
|
||||
continue
|
||||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
switch rule.Action {
|
||||
case "accept":
|
||||
action = sshAction(true, 0)
|
||||
case "check":
|
||||
action = sshAction(true, rule.CheckPeriod)
|
||||
default:
|
||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||
}
|
||||
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
for _, src := range rule.Sources {
|
||||
if isWildcard(rawSrc) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
Any: true,
|
||||
})
|
||||
} else if isGroup(rawSrc) {
|
||||
users, err := pol.expandUsersFromGroup(rawSrc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
UserLogin: user,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
expandedSrcs, err := pol.ExpandAlias(
|
||||
peers,
|
||||
rawSrc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
|
||||
}
|
||||
for _, expandedSrc := range expandedSrcs.Prefixes() {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: expandedSrc.Addr().String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userMap := make(map[string]string, len(rule.Users))
|
||||
for _, user := range rule.Users {
|
||||
userMap[user] = "="
|
||||
}
|
||||
rules = append(rules, &tailcfg.SSHRule{
|
||||
Principals: principals,
|
||||
SSHUsers: userMap,
|
||||
Action: &action,
|
||||
})
|
||||
}
|
||||
|
||||
return &tailcfg.SSHPolicy{
|
||||
Rules: rules,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||
var out []string
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
@ -17,6 +18,9 @@ import (
|
|||
// Move it here, run it against both old and new CompileFilterRules
|
||||
|
||||
func TestParsing(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser@"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
|
@ -340,7 +344,7 @@ func TestParsing(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol, err := PolicyFromBytes([]byte(tt.acl))
|
||||
pol, err := policyFromBytes([]byte(tt.acl))
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
|
@ -355,15 +359,15 @@ func TestParsing(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
rules, err := pol.CompileFilterRules(types.Nodes{
|
||||
rules, err := pol.CompileFilterRules(
|
||||
users,
|
||||
types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.100.100.100"),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("200.200.200.200"),
|
||||
User: types.User{
|
||||
Name: "testuser@",
|
||||
},
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{
|
|||
}
|
||||
|
||||
func TestReduceFilterRules(t *testing.T) {
|
||||
users := types.Users{
|
||||
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
|
||||
types.User{Model: gorm.Model{ID: 2}, Name: "user1@"},
|
||||
types.User{Model: gorm.Model{ID: 3}, Name: "user2@"},
|
||||
types.User{Model: gorm.Model{ID: 4}, Name: "user100@"},
|
||||
types.User{Model: gorm.Model{ID: 5}, Name: "user3@"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
|
@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[0],
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{},
|
||||
|
@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
|
@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2@"},
|
||||
User: users[2],
|
||||
},
|
||||
// "internal" exit node
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2@"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2@"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
|
||||
// This should not be included I believe, seems like
|
||||
// this is a bug in the v1 code.
|
||||
// {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
|
||||
// {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
|
||||
// {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
|
||||
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
|
||||
|
@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
||||
},
|
||||
|
@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2@"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
||||
},
|
||||
|
@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2@"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100@"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
},
|
||||
|
@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1@"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
filterV1, _ := polV1.CompileFilterRules(
|
||||
append(tt.peers, tt.node),
|
||||
)
|
||||
polV2, err := PolicyFromBytes([]byte(tt.pol))
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node))
|
||||
if err != nil {
|
||||
t.Fatalf("parsing policy: %s", err)
|
||||
}
|
||||
filterV2, _ := polV2.CompileFilterRules(
|
||||
append(tt.peers, tt.node),
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(filterV1, filterV2); diff != "" {
|
||||
log.Trace().Interface("got", filterV2).Msg("result")
|
||||
if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" {
|
||||
log.Trace().Interface("got", pm.Filter()).Msg("result")
|
||||
t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Move this from v1, or
|
||||
// rewrite.
|
||||
filterV2 = policy.ReduceFilterRules(tt.node, filterV2)
|
||||
filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter())
|
||||
|
||||
if diff := cmp.Diff(tt.want, filterV2); diff != "" {
|
||||
log.Trace().Interface("got", filterV2).Msg("result")
|
||||
|
|
80
hscontrol/policyv2/policy.go
Normal file
80
hscontrol/policyv2/policy.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package policyv2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *Policy
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filter []tailcfg.FilterRule
|
||||
|
||||
// TODO(kradalby): Implement SSH policy
|
||||
sshPolicy *tailcfg.SSHPolicy
|
||||
}
|
||||
|
||||
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||
// It returns an error if the policy file is invalid.
|
||||
// The policy manager will update the filter rules based on the users and nodes.
|
||||
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
policy, err := policyFromBytes(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm := PolicyManager{
|
||||
pol: policy,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
// Filter returns the current filter rules for the entire tailnet.
|
||||
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManager) updateLocked() error {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetUsers(users []types.User) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
58
hscontrol/policyv2/policy_test.go
Normal file
58
hscontrol/policyv2/policy_test.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package policyv2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
IPv4: ap(ipv4),
|
||||
IPv6: ap(ipv6),
|
||||
User: user,
|
||||
UserID: user.ID,
|
||||
Hostinfo: hostinfo,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyManager(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol string
|
||||
nodes types.Nodes
|
||||
wantFilter []tailcfg.FilterRule
|
||||
}{
|
||||
{
|
||||
name: "empty-policy",
|
||||
pol: "{}",
|
||||
nodes: types.Nodes{},
|
||||
wantFilter: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||
require.NoError(t, err)
|
||||
|
||||
filter := pm.Filter()
|
||||
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Test SSH Policy
|
||||
})
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
|
@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
ips.AddPrefix(tsaddr.AllIPv4())
|
||||
|
@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (u Username) resolveUser(users types.Users) (*types.User, error) {
|
||||
var potentialUsers types.Users
|
||||
for _, user := range users {
|
||||
if user.ProviderIdentifier == string(u) {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
|
||||
break
|
||||
}
|
||||
if user.Email == string(u) {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
}
|
||||
if user.Name == string(u) {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
}
|
||||
}
|
||||
|
||||
if len(potentialUsers) > 1 {
|
||||
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
|
||||
} else if len(potentialUsers) == 0 {
|
||||
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
|
||||
}
|
||||
|
||||
user := potentialUsers[0]
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
user, err := u.resolveUser(users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.IsTagged() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.User.Username() == string(u) {
|
||||
if node.User.ID == user.ID {
|
||||
node.AppendToIPSet(&ips)
|
||||
}
|
||||
}
|
||||
|
@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
for _, user := range p.Groups[g] {
|
||||
uips, err := user.Resolve(nil, nodes)
|
||||
uips, err := user.Resolve(nil, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
for _, node := range nodes {
|
||||
|
@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
pref, ok := p.Hosts[h]
|
||||
|
@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the IP is a single host, look for a node to ensure we add all the IPs of
|
||||
// the node to the IPSet.
|
||||
appendIfNodeHasIP(nodes, &ips, pref)
|
||||
ips.AddPrefix(netip.Prefix(pref))
|
||||
|
||||
return ips.IPSet()
|
||||
}
|
||||
|
||||
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
|
||||
if netip.Prefix(pref).IsSingleIP() {
|
||||
addr := netip.Prefix(pref).Addr()
|
||||
for _, node := range nodes {
|
||||
if node.HasIP(addr) {
|
||||
node.AppendToIPSet(ips)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Prefix netip.Prefix
|
||||
|
||||
func (p Prefix) Validate() error {
|
||||
|
@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
||||
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
appendIfNodeHasIP(nodes, &ips, p)
|
||||
ips.AddPrefix(netip.Prefix(p))
|
||||
|
||||
return ips.IPSet()
|
||||
|
@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
||||
func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
|
||||
switch ag {
|
||||
case AutoGroupInternet:
|
||||
return theInternet(), nil
|
||||
|
@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
|||
type Alias interface {
|
||||
Validate() error
|
||||
UnmarshalJSON([]byte) error
|
||||
Resolve(*Policy, types.Nodes) (*netipx.IPSet, error)
|
||||
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
|
||||
}
|
||||
|
||||
type AliasWithPorts struct {
|
||||
|
@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
for _, alias := range a {
|
||||
aips, err := alias.Resolve(p, nodes)
|
||||
aips, err := alias.Resolve(p, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -530,10 +579,67 @@ type Policy struct {
|
|||
TagOwners TagOwners `json:"tagOwners"`
|
||||
ACLs []ACL `json:"acls"`
|
||||
AutoApprovers AutoApprovers `json:"autoApprovers"`
|
||||
// SSHs []SSH `json:"ssh"`
|
||||
SSHs []SSH `json:"ssh"`
|
||||
}
|
||||
|
||||
func PolicyFromBytes(b []byte) (*Policy, error) {
|
||||
// SSH controls who can ssh into which machines.
|
||||
type SSH struct {
|
||||
Action string `json:"action"`
|
||||
Sources SSHSrcAliases `json:"src"`
|
||||
Destinations SSHDstAliases `json:"dst"`
|
||||
Users []SSHUser `json:"users"`
|
||||
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
|
||||
}
|
||||
|
||||
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
|
||||
// It can be a list of usernames, groups, tags or autogroups.
|
||||
type SSHSrcAliases []Alias
|
||||
|
||||
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
err := json.Unmarshal(b, &aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*a = make([]Alias, len(aliases))
|
||||
for i, alias := range aliases {
|
||||
switch alias.Alias.(type) {
|
||||
case *Username, *Group, *Tag, *AutoGroup:
|
||||
(*a)[i] = alias.Alias
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", alias.Alias)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
|
||||
// It can be a list of usernames, tags or autogroups.
|
||||
type SSHDstAliases []Alias
|
||||
|
||||
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
err := json.Unmarshal(b, &aliases)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*a = make([]Alias, len(aliases))
|
||||
for i, alias := range aliases {
|
||||
switch alias.Alias.(type) {
|
||||
case *Username, *Tag, *AutoGroup:
|
||||
(*a)[i] = alias.Alias
|
||||
default:
|
||||
return fmt.Errorf("type %T not supported", alias.Alias)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SSHUser string
|
||||
|
||||
func policyFromBytes(b []byte) (*Policy, error) {
|
||||
var policy Policy
|
||||
ast, err := hujson.Parse(b)
|
||||
if err != nil {
|
||||
|
|
|
@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: ptr.To(Username("otheruser@headscale.net")),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
||||
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: gp("group:other"),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
||||
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: pp("100.101.102.104/32"),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
||||
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: pp("172.16.0.0/16"),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
||||
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
Destinations: []AliasWithPorts{
|
||||
{
|
||||
Alias: hp("host-1"),
|
||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}},
|
||||
Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
{
|
||||
Alias: tp("tag:user"),
|
||||
Ports: []tailcfg.PortRange{
|
||||
tailcfg.PortRange{First: 80, Last: 80},
|
||||
tailcfg.PortRange{First: 443, Last: 443},
|
||||
{First: 80, Last: 80},
|
||||
{First: 443, Last: 443},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
{
|
||||
Alias: agp("autogroup:internet"),
|
||||
Ports: []tailcfg.PortRange{
|
||||
tailcfg.PortRange{First: 80, Last: 80},
|
||||
{First: 80, Last: 80},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
policy, err := PolicyFromBytes([]byte(tt.input))
|
||||
policy, err := policyFromBytes([]byte(tt.input))
|
||||
// TODO(kradalby): This error checking is broken,
|
||||
// but so is my brain, #longflight
|
||||
if err == nil {
|
||||
|
@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes)
|
||||
ips, err := tt.toResolve.Resolve(tt.pol,
|
||||
types.Users{},
|
||||
tt.nodes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve: %s", err)
|
||||
}
|
||||
|
|
|
@ -130,6 +130,16 @@ func (node *Node) IPs() []netip.Addr {
|
|||
return ret
|
||||
}
|
||||
|
||||
// HasIP reports if a node has a given IP address.
|
||||
func (node *Node) HasIP(i netip.Addr) bool {
|
||||
for _, ip := range node.IPs() {
|
||||
if ip.Compare(i) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTagged reports if a device is tagged
|
||||
// and therefore should not be treated as a
|
||||
// user owned device.
|
||||
|
|
|
@ -2,7 +2,9 @@ package types
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
|
@ -13,6 +15,19 @@ import (
|
|||
|
||||
type UserID uint64
|
||||
|
||||
type Users []User
|
||||
|
||||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// User is the way Headscale implements the concept of users in Tailscale
|
||||
//
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
|
|
Loading…
Reference in a new issue