This commit is contained in:
Kristoffer Dalby 2024-11-22 17:01:29 +00:00 committed by GitHub
commit 5860b9d652
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 527 additions and 141 deletions

View file

@ -1029,14 +1029,18 @@ func (h *Headscale) loadACLPolicy() error {
if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
users, err := h.db.ListUsers()
if err != nil {
return fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes)
_, err = pol.CompileFilterRules(users, nodes)
if err != nil {
return fmt.Errorf("verifying policy rules: %w", err)
}
if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err)
}

View file

@ -256,10 +256,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)

View file

@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
users, err := ListUsers(tx)
if err != nil {
return fmt.Errorf("looking up users to expand route alias: %w", err)
}
// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
}

View file

@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
users, err := api.h.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes)
_, err = pol.CompileFilterRules(users, nodes)
if err != nil {
return nil, fmt.Errorf("verifying policy rules: %w", err)
}
if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err)
}

View file

@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse(
node *types.Node,
peers types.Nodes,
users []types.User,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
pol,
node,
capVer,
users,
peers,
peers,
m.cfg,
@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
if err != nil {
return nil, err
}
users, err := m.db.ListUsers()
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
if err != nil {
return nil, err
}
@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
return nil, err
}
users, err := m.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("listing users for map response: %w", err)
}
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
pol,
node,
mapRequest.Version,
users,
peers,
changedNodes,
m.cfg,
@ -508,16 +520,17 @@ func appendPeerChanges(
pol *policy.ACLPolicy,
node *types.Node,
capVer tailcfg.CapabilityVersion,
users []types.User,
peers types.Nodes,
changed types.Nodes,
cfg *types.Config,
) error {
packetFilter, err := pol.CompileFilterRules(append(peers, node))
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
if err != nil {
return err
}
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
if err != nil {
return err
}

View file

@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
mini := &types.Node{
ID: 0,
MachineKey: mustMK(
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: 0,
User: types.User{Name: "mini"},
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"),
Hostname: "peer1",
GivenName: "peer1",
UserID: 0,
User: types.User{Name: "mini"},
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.3"),
Hostname: "peer2",
GivenName: "peer2",
UserID: 1,
User: types.User{Name: "peer2"},
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
[]types.User{user1, user2},
tt.pol,
0,
)

View file

@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy,
node *types.Node,
peers types.Nodes,
users []types.User,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all
if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}
rules, err := policy.CompileFilterRules(append(peers, node))
rules, err := policy.CompileFilterRules(users, append(peers, node))
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) CompileFilterRules(
users []types.User,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
var srcIPs []string
for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, nodes)
srcs, err := pol.expandSource(src, users, nodes)
if err != nil {
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
return nil, fmt.Errorf(
"parsing policy, acl index: %d->%d: %w",
index,
srcIndex,
err,
)
}
srcIPs = append(srcIPs, srcs...)
}
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
expanded, err := pol.ExpandAlias(
nodes,
users,
alias,
)
if err != nil {
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node,
users []types.User,
peers types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, node), src)
expanded, err := pol.ExpandAlias(append(peers, node), users, src)
if err != nil {
return nil, err
}
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err)
return nil, fmt.Errorf(
"parsing SSH policy, parsing check duration, index: %d: %w",
index,
err,
)
} else {
action = *checkAction
}
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
sshACL.Action,
index,
err,
)
}
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
users,
rawSrc,
)
if err != nil {
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias.
func (pol *ACLPolicy) expandSource(
src string,
users []types.User,
nodes types.Nodes,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(nodes, src)
ipSet, err := pol.ExpandAlias(nodes, users, src)
if err != nil {
return []string{}, err
}
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
// and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias(
nodes types.Nodes,
users []types.User,
alias string,
) (*netipx.IPSet, error) {
if isWildcard(alias) {
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group
if isGroup(alias) {
return pol.expandIPsFromGroup(alias, nodes)
return pol.expandIPsFromGroup(alias, users, nodes)
}
// if alias is a tag
if isTag(alias) {
return pol.expandIPsFromTag(alias, nodes)
return pol.expandIPsFromTag(alias, users, nodes)
}
if isAutoGroup(alias) {
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
}
// if alias is a user
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil {
if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
return ips, err
}
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(nodes, h.String())
return pol.ExpandAlias(nodes, users, h.String())
}
// if alias is an IP
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup(
group string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
users, err := pol.expandUsersFromGroup(group)
userTokens, err := pol.expandUsersFromGroup(group)
if err != nil {
return &netipx.IPSet{}, err
}
for _, user := range users {
filteredNodes := filterNodesByUser(nodes, user)
for _, user := range userTokens {
filteredNodes := filterNodesByUser(nodes, users, user)
for _, node := range filteredNodes {
node.AppendToIPSet(&build)
}
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromTag(
alias string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
// filter out nodes per tag owner
for _, user := range owners {
nodes := filterNodesByUser(nodes, user)
nodes := filterNodesByUser(nodes, users, user)
for _, node := range nodes {
if node.Hostinfo == nil {
continue
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
func (pol *ACLPolicy) expandIPsFromUser(
user string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
filteredNodes := filterNodesByUser(nodes, user)
filteredNodes := filterNodesByUser(nodes, users, user)
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
// shortcurcuit if we have no nodes to get ips from.
@ -953,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode(
return validTags, invalidTags
}
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
// filterNodesByUser returns a list of nodes that match the given userToken from a
// policy.
// Matching nodes are determined by first matching the user token to a user by checking:
// - If it is an ID that mactches the user database ID
// - It is the Provider Identifier from OIDC
// - It matches the username or email of a user
//
// If the token matches more than one user, zero nodes will returned.
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
var out types.Nodes
var potentialUsers []types.User
for _, user := range users {
if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken {
// If a user is matching with a known unique field,
// disgard all other users and only keep the current
// user.
potentialUsers = []types.User{user}
break
}
if user.Email == userToken {
potentialUsers = append(potentialUsers, user)
}
if user.Name == userToken {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) != 1 {
return nil
}
user := potentialUsers[0]
for _, node := range nodes {
if node.User.Username() == user {
if node.User.ID == user.ID {
out = append(out, node)
}
}

View file

@ -1,9 +1,12 @@
package policy
import (
"database/sql"
"errors"
"math/rand/v2"
"net/netip"
"slices"
"sort"
"testing"
"github.com/google/go-cmp/cmp"
@ -14,6 +17,7 @@ import (
"github.com/stretchr/testify/require"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@ -375,15 +379,21 @@ func TestParsing(t *testing.T) {
return
}
rules, err := pol.CompileFilterRules(types.Nodes{
user := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser",
}
rules, err := pol.CompileFilterRules(
[]types.User{
user,
},
types.Nodes{
&types.Node{
IPv4: iap("100.100.100.100"),
},
&types.Node{
IPv4: iap("200.200.200.200"),
User: types.User{
Name: "testuser",
},
User: user,
Hostinfo: &tailcfg.Hostinfo{},
},
})
@ -533,7 +543,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil)
rules, err := pol.CompileFilterRules(types.Nodes{})
rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
@ -549,7 +559,12 @@ func (s *Suite) TestInvalidAction(c *check.C) {
},
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(
pol,
&types.Node{},
types.Nodes{},
[]types.User{},
)
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
}
@ -568,7 +583,12 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
},
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(
pol,
&types.Node{},
types.Nodes{},
[]types.User{},
)
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
}
@ -584,7 +604,12 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(
pol,
&types.Node{},
types.Nodes{},
[]types.User{},
)
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
}
@ -860,7 +885,25 @@ func Test_expandPorts(t *testing.T) {
}
}
func Test_listNodesInUser(t *testing.T) {
func Test_filterNodesByUser(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "marc"},
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
{
Model: gorm.Model{ID: 3},
Name: "mikael",
Email: "mikael@headscale.net",
ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true},
},
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
{Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 7}, Name: "1"},
{Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"},
{Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"},
{Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"},
}
type args struct {
nodes types.Nodes
user string
@ -874,50 +917,258 @@ func Test_listNodesInUser(t *testing.T) {
name: "1 node in user",
args: args{
nodes: types.Nodes{
&types.Node{User: types.User{Name: "joe"}},
&types.Node{User: users[1]},
},
user: "joe",
},
want: types.Nodes{
&types.Node{User: types.User{Name: "joe"}},
&types.Node{User: users[1]},
},
},
{
name: "3 nodes, 2 in user",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}},
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
},
},
{
name: "5 nodes, 0 in user",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}},
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 4, User: types.User{Name: "marc"}},
&types.Node{ID: 5, User: types.User{Name: "marc"}},
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
&types.Node{ID: 4, User: users[0]},
&types.Node{ID: 5, User: users[0]},
},
user: "mickael",
},
want: nil,
},
{
name: "match-by-provider-ident",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[2]},
},
},
{
name: "match-by-email",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[1]},
},
},
{
name: "multi-match-is-zero",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 3, User: users[3]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-email-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match email, then provider id
&types.Node{ID: 3, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-username-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match username, then provider id
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-duplicate-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-unique-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[0]},
},
},
{
name: "all-users-no-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[1]},
},
},
{
name: "email-as-username-duplicate",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[7]},
&types.Node{ID: 2, User: users[8]},
},
user: "alex@headscale.net",
},
want: nil,
},
{
name: "all-users-no-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working@headscale.net",
},
want: nil,
},
{
name: "all-users-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 3, User: users[2]},
},
},
{
name: "all-users-no-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/4321",
},
want: nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user)
for range 1000 {
ns := test.args.nodes
rand.Shuffle(len(ns), func(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
})
us := users
rand.Shuffle(len(us), func(i, j int) {
us[i], us[j] = us[j], us[i]
})
got := filterNodesByUser(ns, us, test.args.user)
sort.Slice(got, func(i, j int) bool {
return got[i].ID < got[j].ID
})
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
}
}
})
}
@ -940,6 +1191,12 @@ func Test_expandAlias(t *testing.T) {
return s
}
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "joe"},
{Model: gorm.Model{ID: 2}, Name: "marc"},
{Model: gorm.Model{ID: 3}, Name: "mickael"},
}
type field struct {
pol ACLPolicy
}
@ -989,19 +1246,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1022,19 +1279,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1185,7 +1442,7 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1194,7 +1451,7 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1203,11 +1460,11 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"},
User: users[0],
},
},
},
@ -1260,21 +1517,21 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1295,12 +1552,12 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1309,11 +1566,11 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1350,12 +1607,12 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
},
@ -1368,6 +1625,7 @@ func Test_expandAlias(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias(
test.args.nodes,
users,
test.args.alias,
)
if (err != nil) != test.wantErr {
@ -1715,6 +1973,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.CompileFilterRules(
[]types.User{},
tt.args.nodes,
)
if (err != nil) != tt.wantErr {
@ -1842,6 +2101,13 @@ func TestTheInternet(t *testing.T) {
}
func TestReduceFilterRules(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "mickael"},
{Model: gorm.Model{ID: 2}, Name: "user1"},
{Model: gorm.Model{ID: 3}, Name: "user2"},
{Model: gorm.Model{ID: 4}, Name: "user100"},
}
tests := []struct {
name string
node *types.Node
@ -1863,13 +2129,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"},
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"},
User: users[0],
},
},
want: []tailcfg.FilterRule{},
@ -1896,7 +2162,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@ -1907,7 +2173,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -1975,19 +2241,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -2034,12 +2300,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2131,7 +2397,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -2140,12 +2406,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2243,7 +2509,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/16"),
@ -2255,12 +2521,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2333,7 +2599,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("8.0.0.0/8"),
@ -2345,12 +2611,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2416,7 +2682,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
@ -2426,7 +2692,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2454,6 +2720,7 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules(
users,
append(tt.peers, tt.node),
)
@ -3461,7 +3728,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
require.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3544,14 +3811,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
RequestTags: []string{"tag:test"},
}
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{
ID: 0,
Hostname: "testnodes",
IPv4: iap("100.64.0.1"),
UserID: 0,
User: types.User{
Name: "user1",
},
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
}
@ -3568,7 +3838,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
require.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3602,6 +3872,7 @@ func TestInvalidTagValidUser(t *testing.T) {
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
@ -3619,7 +3890,12 @@ func TestInvalidTagValidUser(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(
pol,
node,
types.Nodes{},
[]types.User{node.User},
)
require.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3653,6 +3929,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
@ -3678,7 +3955,12 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(
pol,
node,
types.Nodes{},
[]types.User{node.User},
)
require.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3707,15 +3989,17 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{
ID: 1,
Hostname: "webserver",
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Name: "user1",
},
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
}
@ -3730,9 +4014,7 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "user",
IPv4: iap("100.64.0.2"),
UserID: 1,
User: types.User{
Name: "user1",
},
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo2,
}
@ -3748,7 +4030,12 @@ func TestValidTagInvalidUser(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
got, _, err := GenerateFilterAndSSHRulesForTests(
pol,
node,
types.Nodes{nodes2},
[]types.User{user},
)
require.NoError(t, err)
want := []tailcfg.FilterRule{

View file

@ -2,6 +2,8 @@ package types
import (
"cmp"
"database/sql"
"net/mail"
"strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -34,7 +36,7 @@ type User struct {
// Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token
// and is used to lookup the user.
ProviderIdentifier string `gorm:"index"`
ProviderIdentifier sql.NullString `gorm:"index"`
// Provider is the origin of the user account,
// same as RegistrationMethod, without authkey.
@ -51,7 +53,7 @@ type User struct {
// should be used throughout headscale, in information returned to the
// user and the Policy engine.
func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
}
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
@ -107,7 +109,7 @@ func (u *User) Proto() *v1.User {
CreatedAt: timestamppb.New(u.CreatedAt),
DisplayName: u.DisplayName,
Email: u.Email,
ProviderId: u.ProviderIdentifier,
ProviderId: u.ProviderIdentifier.String,
Provider: u.Provider,
ProfilePicUrl: u.ProfilePicURL,
}
@ -129,10 +131,20 @@ type OIDCClaims struct {
// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Sub
u.DisplayName = claims.Name
u.Email = claims.Email
err := util.CheckForFQDNRules(claims.Username)
if err == nil {
u.Name = claims.Username
}
if claims.EmailVerified {
_, err = mail.ParseAddress(claims.Email)
if err == nil {
u.Email = claims.Email
}
}
u.ProviderIdentifier.String = claims.Sub
u.DisplayName = claims.Name
u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC
}