mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
policyman before #2255
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
af969f602c
commit
02c76bda99
7 changed files with 430 additions and 70 deletions
|
@ -30,6 +30,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policyv2"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
|
@ -88,7 +89,8 @@ type Headscale struct {
|
||||||
DERPMap *tailcfg.DERPMap
|
DERPMap *tailcfg.DERPMap
|
||||||
DERPServer *derpServer.DERPServer
|
DERPServer *derpServer.DERPServer
|
||||||
|
|
||||||
ACLPolicy *policy.ACLPolicy
|
ACLPolicy *policy.ACLPolicy
|
||||||
|
PolicyManager *policyv2.PolicyManager
|
||||||
|
|
||||||
mapper *mapper.Mapper
|
mapper *mapper.Mapper
|
||||||
nodeNotifier *notifier.Notifier
|
nodeNotifier *notifier.Notifier
|
||||||
|
|
|
@ -3,6 +3,7 @@ package policyv2
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
|
@ -16,6 +17,7 @@ var (
|
||||||
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
func (pol *Policy) CompileFilterRules(
|
func (pol *Policy) CompileFilterRules(
|
||||||
|
users types.Users,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
if pol == nil {
|
if pol == nil {
|
||||||
|
@ -29,7 +31,7 @@ func (pol *Policy) CompileFilterRules(
|
||||||
return nil, ErrInvalidAction
|
return nil, ErrInvalidAction
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIPs, err := acl.Sources.Resolve(pol, nodes)
|
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("resolving source ips: %w", err)
|
return nil, fmt.Errorf("resolving source ips: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -43,7 +45,7 @@ func (pol *Policy) CompileFilterRules(
|
||||||
|
|
||||||
var destPorts []tailcfg.NetPortRange
|
var destPorts []tailcfg.NetPortRange
|
||||||
for _, dest := range acl.Destinations {
|
for _, dest := range acl.Destinations {
|
||||||
ips, err := dest.Alias.Resolve(pol, nodes)
|
ips, err := dest.Alias.Resolve(pol, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -69,6 +71,105 @@ func (pol *Policy) CompileFilterRules(
|
||||||
return rules, nil
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||||
|
return tailcfg.SSHAction{
|
||||||
|
Reject: !accept,
|
||||||
|
Accept: accept,
|
||||||
|
SessionDuration: duration,
|
||||||
|
AllowAgentForwarding: true,
|
||||||
|
AllowLocalPortForwarding: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pol *Policy) CompileSSHPolicy(
|
||||||
|
users types.Users,
|
||||||
|
node types.Node,
|
||||||
|
nodes types.Nodes,
|
||||||
|
) (*tailcfg.SSHPolicy, error) {
|
||||||
|
if pol == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []*tailcfg.SSHRule
|
||||||
|
|
||||||
|
for index, rule := range pol.SSHs {
|
||||||
|
var dest netipx.IPSetBuilder
|
||||||
|
for _, src := range rule.Destinations {
|
||||||
|
ips, err := src.Resolve(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dest.AddSet(ips)
|
||||||
|
}
|
||||||
|
|
||||||
|
destSet, err := dest.IPSet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !node.InIPSet(destSet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var action tailcfg.SSHAction
|
||||||
|
switch rule.Action {
|
||||||
|
case "accept":
|
||||||
|
action = sshAction(true, 0)
|
||||||
|
case "check":
|
||||||
|
action = sshAction(true, rule.CheckPeriod)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var principals []*tailcfg.SSHPrincipal
|
||||||
|
for _, src := range rule.Sources {
|
||||||
|
if isWildcard(rawSrc) {
|
||||||
|
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||||
|
Any: true,
|
||||||
|
})
|
||||||
|
} else if isGroup(rawSrc) {
|
||||||
|
users, err := pol.expandUsersFromGroup(rawSrc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||||
|
UserLogin: user,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
expandedSrcs, err := pol.ExpandAlias(
|
||||||
|
peers,
|
||||||
|
rawSrc,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
|
||||||
|
}
|
||||||
|
for _, expandedSrc := range expandedSrcs.Prefixes() {
|
||||||
|
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||||
|
NodeIP: expandedSrc.Addr().String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userMap := make(map[string]string, len(rule.Users))
|
||||||
|
for _, user := range rule.Users {
|
||||||
|
userMap[user] = "="
|
||||||
|
}
|
||||||
|
rules = append(rules, &tailcfg.SSHRule{
|
||||||
|
Principals: principals,
|
||||||
|
SSHUsers: userMap,
|
||||||
|
Action: &action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tailcfg.SSHPolicy{
|
||||||
|
Rules: rules,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
|
||||||
var out []string
|
var out []string
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
@ -17,6 +18,9 @@ import (
|
||||||
// Move it here, run it against both old and new CompileFilterRules
|
// Move it here, run it against both old and new CompileFilterRules
|
||||||
|
|
||||||
func TestParsing(t *testing.T) {
|
func TestParsing(t *testing.T) {
|
||||||
|
users := types.Users{
|
||||||
|
{Model: gorm.Model{ID: 1}, Name: "testuser@"},
|
||||||
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
format string
|
format string
|
||||||
|
@ -340,7 +344,7 @@ func TestParsing(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
pol, err := PolicyFromBytes([]byte(tt.acl))
|
pol, err := policyFromBytes([]byte(tt.acl))
|
||||||
if tt.wantErr && err == nil {
|
if tt.wantErr && err == nil {
|
||||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
|
||||||
|
@ -355,18 +359,18 @@ func TestParsing(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := pol.CompileFilterRules(types.Nodes{
|
rules, err := pol.CompileFilterRules(
|
||||||
&types.Node{
|
users,
|
||||||
IPv4: ap("100.100.100.100"),
|
types.Nodes{
|
||||||
},
|
&types.Node{
|
||||||
&types.Node{
|
IPv4: ap("100.100.100.100"),
|
||||||
IPv4: ap("200.200.200.200"),
|
|
||||||
User: types.User{
|
|
||||||
Name: "testuser@",
|
|
||||||
},
|
},
|
||||||
Hostinfo: &tailcfg.Hostinfo{},
|
&types.Node{
|
||||||
},
|
IPv4: ap("200.200.200.200"),
|
||||||
})
|
User: users[0],
|
||||||
|
Hostinfo: &tailcfg.Hostinfo{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -435,6 +439,14 @@ var hsExitNodeDestForTest = []tailcfg.NetPortRange{
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReduceFilterRules(t *testing.T) {
|
func TestReduceFilterRules(t *testing.T) {
|
||||||
|
users := types.Users{
|
||||||
|
types.User{Model: gorm.Model{ID: 1}, Name: "mickael"},
|
||||||
|
types.User{Model: gorm.Model{ID: 2}, Name: "user1@"},
|
||||||
|
types.User{Model: gorm.Model{ID: 3}, Name: "user2@"},
|
||||||
|
types.User{Model: gorm.Model{ID: 4}, Name: "user100@"},
|
||||||
|
types.User{Model: gorm.Model{ID: 5}, Name: "user3@"},
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
node *types.Node
|
node *types.Node
|
||||||
|
@ -463,13 +475,13 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||||
User: types.User{Name: "mickael"},
|
User: users[0],
|
||||||
},
|
},
|
||||||
peers: types.Nodes{
|
peers: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
||||||
User: types.User{Name: "mickael"},
|
User: users[0],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{},
|
want: []tailcfg.FilterRule{},
|
||||||
|
@ -510,7 +522,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{
|
RoutableIPs: []netip.Prefix{
|
||||||
netip.MustParsePrefix("10.33.0.0/16"),
|
netip.MustParsePrefix("10.33.0.0/16"),
|
||||||
|
@ -521,7 +533,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -600,19 +612,19 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
peers: types.Nodes{
|
peers: types.Nodes{
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user2@"},
|
User: users[2],
|
||||||
},
|
},
|
||||||
// "internal" exit node
|
// "internal" exit node
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tsaddr.ExitRoutes(),
|
RoutableIPs: tsaddr.ExitRoutes(),
|
||||||
},
|
},
|
||||||
|
@ -661,7 +673,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tsaddr.ExitRoutes(),
|
RoutableIPs: tsaddr.ExitRoutes(),
|
||||||
},
|
},
|
||||||
|
@ -670,12 +682,12 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user2@"},
|
User: users[2],
|
||||||
},
|
},
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -768,7 +780,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: tsaddr.ExitRoutes(),
|
RoutableIPs: tsaddr.ExitRoutes(),
|
||||||
},
|
},
|
||||||
|
@ -777,12 +789,12 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user2@"},
|
User: users[2],
|
||||||
},
|
},
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -809,9 +821,11 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
|
{IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
{IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
|
{IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
|
// This should not be included I believe, seems like
|
||||||
{IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
|
// this is a bug in the v1 code.
|
||||||
{IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
|
// {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny},
|
||||||
|
// {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny},
|
||||||
|
// {IP: "fd7a:115c:a1e0::100/128", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
{IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
|
{IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny},
|
||||||
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
|
{IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny},
|
||||||
|
@ -881,7 +895,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
||||||
},
|
},
|
||||||
|
@ -890,12 +904,12 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user2@"},
|
User: users[2],
|
||||||
},
|
},
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -969,7 +983,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
||||||
},
|
},
|
||||||
|
@ -978,12 +992,12 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.2"),
|
IPv4: ap("100.64.0.2"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||||
User: types.User{Name: "user2@"},
|
User: users[2],
|
||||||
},
|
},
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -1046,7 +1060,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
node: &types.Node{
|
node: &types.Node{
|
||||||
IPv4: ap("100.64.0.100"),
|
IPv4: ap("100.64.0.100"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||||
User: types.User{Name: "user100@"},
|
User: users[3],
|
||||||
Hostinfo: &tailcfg.Hostinfo{
|
Hostinfo: &tailcfg.Hostinfo{
|
||||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||||
},
|
},
|
||||||
|
@ -1056,7 +1070,7 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
&types.Node{
|
&types.Node{
|
||||||
IPv4: ap("100.64.0.1"),
|
IPv4: ap("100.64.0.1"),
|
||||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||||
User: types.User{Name: "user1@"},
|
User: users[1],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
|
@ -1090,22 +1104,19 @@ func TestReduceFilterRules(t *testing.T) {
|
||||||
filterV1, _ := polV1.CompileFilterRules(
|
filterV1, _ := polV1.CompileFilterRules(
|
||||||
append(tt.peers, tt.node),
|
append(tt.peers, tt.node),
|
||||||
)
|
)
|
||||||
polV2, err := PolicyFromBytes([]byte(tt.pol))
|
pm, err := NewPolicyManager([]byte(tt.pol), users, append(tt.peers, tt.node))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("parsing policy: %s", err)
|
t.Fatalf("parsing policy: %s", err)
|
||||||
}
|
}
|
||||||
filterV2, _ := polV2.CompileFilterRules(
|
|
||||||
append(tt.peers, tt.node),
|
|
||||||
)
|
|
||||||
|
|
||||||
if diff := cmp.Diff(filterV1, filterV2); diff != "" {
|
if diff := cmp.Diff(filterV1, pm.Filter()); diff != "" {
|
||||||
log.Trace().Interface("got", filterV2).Msg("result")
|
log.Trace().Interface("got", pm.Filter()).Msg("result")
|
||||||
t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff)
|
t.Errorf("TestReduceFilterRules() unexpected diff between v1 and v2 (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Move this from v1, or
|
// TODO(kradalby): Move this from v1, or
|
||||||
// rewrite.
|
// rewrite.
|
||||||
filterV2 = policy.ReduceFilterRules(tt.node, filterV2)
|
filterV2 := policy.ReduceFilterRules(tt.node, pm.Filter())
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, filterV2); diff != "" {
|
if diff := cmp.Diff(tt.want, filterV2); diff != "" {
|
||||||
log.Trace().Interface("got", filterV2).Msg("result")
|
log.Trace().Interface("got", filterV2).Msg("result")
|
||||||
|
|
80
hscontrol/policyv2/policy.go
Normal file
80
hscontrol/policyv2/policy.go
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
package policyv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PolicyManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
pol *Policy
|
||||||
|
users []types.User
|
||||||
|
nodes types.Nodes
|
||||||
|
|
||||||
|
filter []tailcfg.FilterRule
|
||||||
|
|
||||||
|
// TODO(kradalby): Implement SSH policy
|
||||||
|
sshPolicy *tailcfg.SSHPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||||
|
// It returns an error if the policy file is invalid.
|
||||||
|
// The policy manager will update the filter rules based on the users and nodes.
|
||||||
|
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||||
|
policy, err := policyFromBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := PolicyManager{
|
||||||
|
pol: policy,
|
||||||
|
users: users,
|
||||||
|
nodes: nodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pm.updateLocked()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter returns the current filter rules for the entire tailnet.
|
||||||
|
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
return pm.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||||
|
// It must be called with the lock held.
|
||||||
|
func (pm *PolicyManager) updateLocked() error {
|
||||||
|
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("compiling filter rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.filter = filter
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManager) SetUsers(users []types.User) error {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
pm.users = users
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
pm.nodes = nodes
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
58
hscontrol/policyv2/policy_test.go
Normal file
58
hscontrol/policyv2/policy_test.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package policyv2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||||
|
return &types.Node{
|
||||||
|
ID: 0,
|
||||||
|
Hostname: name,
|
||||||
|
IPv4: ap(ipv4),
|
||||||
|
IPv6: ap(ipv6),
|
||||||
|
User: user,
|
||||||
|
UserID: user.ID,
|
||||||
|
Hostinfo: hostinfo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyManager(t *testing.T) {
|
||||||
|
users := types.Users{
|
||||||
|
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
|
||||||
|
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pol string
|
||||||
|
nodes types.Nodes
|
||||||
|
wantFilter []tailcfg.FilterRule
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty-policy",
|
||||||
|
pol: "{}",
|
||||||
|
nodes: types.Nodes{},
|
||||||
|
wantFilter: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
filter := pm.Filter()
|
||||||
|
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
|
||||||
|
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Test SSH Policy
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
@ -64,7 +65,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a Asterix) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
ips.AddPrefix(tsaddr.AllIPv4())
|
ips.AddPrefix(tsaddr.AllIPv4())
|
||||||
|
@ -99,15 +100,47 @@ func (u Username) CanBeTagOwner() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u Username) Resolve(_ *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (u Username) resolveUser(users types.Users) (*types.User, error) {
|
||||||
|
var potentialUsers types.Users
|
||||||
|
for _, user := range users {
|
||||||
|
if user.ProviderIdentifier == string(u) {
|
||||||
|
potentialUsers = append(potentialUsers, user)
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if user.Email == string(u) {
|
||||||
|
potentialUsers = append(potentialUsers, user)
|
||||||
|
}
|
||||||
|
if user.Name == string(u) {
|
||||||
|
potentialUsers = append(potentialUsers, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(potentialUsers) > 1 {
|
||||||
|
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
|
||||||
|
} else if len(potentialUsers) == 0 {
|
||||||
|
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := potentialUsers[0]
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
|
user, err := u.resolveUser(users)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.IsTagged() {
|
if node.IsTagged() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.User.Username() == string(u) {
|
if node.User.ID == user.ID {
|
||||||
node.AppendToIPSet(&ips)
|
node.AppendToIPSet(&ips)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -137,11 +170,11 @@ func (g Group) CanBeTagOwner() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g Group) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
for _, user := range p.Groups[g] {
|
for _, user := range p.Groups[g] {
|
||||||
uips, err := user.Resolve(nil, nodes)
|
uips, err := user.Resolve(nil, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -170,7 +203,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tag) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
|
@ -197,7 +230,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
pref, ok := p.Hosts[h]
|
pref, ok := p.Hosts[h]
|
||||||
|
@ -208,11 +241,26 @@ func (h Host) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the IP is a single host, look for a node to ensure we add all the IPs of
|
||||||
|
// the node to the IPSet.
|
||||||
|
appendIfNodeHasIP(nodes, &ips, pref)
|
||||||
ips.AddPrefix(netip.Prefix(pref))
|
ips.AddPrefix(netip.Prefix(pref))
|
||||||
|
|
||||||
return ips.IPSet()
|
return ips.IPSet()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
|
||||||
|
if netip.Prefix(pref).IsSingleIP() {
|
||||||
|
addr := netip.Prefix(pref).Addr()
|
||||||
|
for _, node := range nodes {
|
||||||
|
if node.HasIP(addr) {
|
||||||
|
node.AppendToIPSet(ips)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Prefix netip.Prefix
|
type Prefix netip.Prefix
|
||||||
|
|
||||||
func (p Prefix) Validate() error {
|
func (p Prefix) Validate() error {
|
||||||
|
@ -261,9 +309,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p Prefix) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
|
appendIfNodeHasIP(nodes, &ips, p)
|
||||||
ips.AddPrefix(netip.Prefix(p))
|
ips.AddPrefix(netip.Prefix(p))
|
||||||
|
|
||||||
return ips.IPSet()
|
return ips.IPSet()
|
||||||
|
@ -296,7 +345,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
|
||||||
switch ag {
|
switch ag {
|
||||||
case AutoGroupInternet:
|
case AutoGroupInternet:
|
||||||
return theInternet(), nil
|
return theInternet(), nil
|
||||||
|
@ -308,7 +357,7 @@ func (ag AutoGroup) Resolve(_ *Policy, _ types.Nodes) (*netipx.IPSet, error) {
|
||||||
type Alias interface {
|
type Alias interface {
|
||||||
Validate() error
|
Validate() error
|
||||||
UnmarshalJSON([]byte) error
|
UnmarshalJSON([]byte) error
|
||||||
Resolve(*Policy, types.Nodes) (*netipx.IPSet, error)
|
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliasWithPorts struct {
|
type AliasWithPorts struct {
|
||||||
|
@ -428,11 +477,11 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a Aliases) Resolve(p *Policy, nodes types.Nodes) (*netipx.IPSet, error) {
|
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||||
var ips netipx.IPSetBuilder
|
var ips netipx.IPSetBuilder
|
||||||
|
|
||||||
for _, alias := range a {
|
for _, alias := range a {
|
||||||
aips, err := alias.Resolve(p, nodes)
|
aips, err := alias.Resolve(p, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -530,10 +579,67 @@ type Policy struct {
|
||||||
TagOwners TagOwners `json:"tagOwners"`
|
TagOwners TagOwners `json:"tagOwners"`
|
||||||
ACLs []ACL `json:"acls"`
|
ACLs []ACL `json:"acls"`
|
||||||
AutoApprovers AutoApprovers `json:"autoApprovers"`
|
AutoApprovers AutoApprovers `json:"autoApprovers"`
|
||||||
// SSHs []SSH `json:"ssh"`
|
SSHs []SSH `json:"ssh"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func PolicyFromBytes(b []byte) (*Policy, error) {
|
// SSH controls who can ssh into which machines.
|
||||||
|
type SSH struct {
|
||||||
|
Action string `json:"action"`
|
||||||
|
Sources SSHSrcAliases `json:"src"`
|
||||||
|
Destinations SSHDstAliases `json:"dst"`
|
||||||
|
Users []SSHUser `json:"users"`
|
||||||
|
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
|
||||||
|
// It can be a list of usernames, groups, tags or autogroups.
|
||||||
|
type SSHSrcAliases []Alias
|
||||||
|
|
||||||
|
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
||||||
|
var aliases []AliasEnc
|
||||||
|
err := json.Unmarshal(b, &aliases)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*a = make([]Alias, len(aliases))
|
||||||
|
for i, alias := range aliases {
|
||||||
|
switch alias.Alias.(type) {
|
||||||
|
case *Username, *Group, *Tag, *AutoGroup:
|
||||||
|
(*a)[i] = alias.Alias
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("type %T not supported", alias.Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
|
||||||
|
// It can be a list of usernames, tags or autogroups.
|
||||||
|
type SSHDstAliases []Alias
|
||||||
|
|
||||||
|
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
||||||
|
var aliases []AliasEnc
|
||||||
|
err := json.Unmarshal(b, &aliases)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*a = make([]Alias, len(aliases))
|
||||||
|
for i, alias := range aliases {
|
||||||
|
switch alias.Alias.(type) {
|
||||||
|
case *Username, *Tag, *AutoGroup:
|
||||||
|
(*a)[i] = alias.Alias
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("type %T not supported", alias.Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type SSHUser string
|
||||||
|
|
||||||
|
func policyFromBytes(b []byte) (*Policy, error) {
|
||||||
var policy Policy
|
var policy Policy
|
||||||
ast, err := hujson.Parse(b)
|
ast, err := hujson.Parse(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -173,7 +173,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: ptr.To(Username("otheruser@headscale.net")),
|
Alias: ptr.To(Username("otheruser@headscale.net")),
|
||||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -186,7 +186,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: gp("group:other"),
|
Alias: gp("group:other"),
|
||||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -199,7 +199,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: pp("100.101.102.104/32"),
|
Alias: pp("100.101.102.104/32"),
|
||||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -212,7 +212,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: pp("172.16.0.0/16"),
|
Alias: pp("172.16.0.0/16"),
|
||||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 80}},
|
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -225,7 +225,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
Destinations: []AliasWithPorts{
|
Destinations: []AliasWithPorts{
|
||||||
{
|
{
|
||||||
Alias: hp("host-1"),
|
Alias: hp("host-1"),
|
||||||
Ports: []tailcfg.PortRange{tailcfg.PortRange{First: 80, Last: 88}},
|
Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -239,8 +239,8 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
{
|
{
|
||||||
Alias: tp("tag:user"),
|
Alias: tp("tag:user"),
|
||||||
Ports: []tailcfg.PortRange{
|
Ports: []tailcfg.PortRange{
|
||||||
tailcfg.PortRange{First: 80, Last: 80},
|
{First: 80, Last: 80},
|
||||||
tailcfg.PortRange{First: 443, Last: 443},
|
{First: 443, Last: 443},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -255,7 +255,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
{
|
{
|
||||||
Alias: agp("autogroup:internet"),
|
Alias: agp("autogroup:internet"),
|
||||||
Ports: []tailcfg.PortRange{
|
Ports: []tailcfg.PortRange{
|
||||||
tailcfg.PortRange{First: 80, Last: 80},
|
{First: 80, Last: 80},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -341,7 +341,7 @@ func TestUnmarshalPolicy(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
policy, err := PolicyFromBytes([]byte(tt.input))
|
policy, err := policyFromBytes([]byte(tt.input))
|
||||||
// TODO(kradalby): This error checking is broken,
|
// TODO(kradalby): This error checking is broken,
|
||||||
// but so is my brain, #longflight
|
// but so is my brain, #longflight
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -538,7 +538,9 @@ func TestResolvePolicy(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ips, err := tt.toResolve.Resolve(tt.pol, tt.nodes)
|
ips, err := tt.toResolve.Resolve(tt.pol,
|
||||||
|
types.Users{},
|
||||||
|
tt.nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to resolve: %s", err)
|
t.Fatalf("failed to resolve: %s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue