This commit is contained in:
Kristoffer Dalby 2024-11-18 14:58:49 +00:00 committed by GitHub
commit bcb1dceb5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 495 additions and 145 deletions

View file

@ -1027,14 +1027,18 @@ func (h *Headscale) loadACLPolicy() error {
if err != nil { if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err) return fmt.Errorf("loading nodes from database to validate policy: %w", err)
} }
users, err := h.db.ListUsers()
if err != nil {
return fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes) _, err = pol.CompileFilterRules(users, nodes)
if err != nil { if err != nil {
return fmt.Errorf("verifying policy rules: %w", err) return fmt.Errorf("verifying policy rules: %w", err)
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil { if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err) return fmt.Errorf("verifying SSH rules: %w", err)
} }

View file

@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9) c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)

View file

@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
if approvedAlias == node.User.Username() { if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} else { } else {
users, err := ListUsers(tx)
if err != nil {
return fmt.Errorf("looking up users to expand route alias: %w", err)
}
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
if err != nil { if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
} }

View file

@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil { if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
} }
users, err := api.h.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes) _, err = pol.CompileFilterRules(users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying policy rules: %w", err) return nil, fmt.Errorf("verifying policy rules: %w", err)
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err) return nil, fmt.Errorf("verifying SSH rules: %w", err)
} }

View file

@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse( func (m *Mapper) fullMapResponse(
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
users []types.User,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
pol, pol,
node, node,
capVer, capVer,
users,
peers, peers,
peers, peers,
m.cfg, m.cfg,
@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
if err != nil { if err != nil {
return nil, err return nil, err
} }
users, err := m.db.ListUsers()
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
return nil, err return nil, err
} }
users, err := m.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("listing users for map response: %w", err)
}
var removedIDs []tailcfg.NodeID var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed { for nodeID, nodeChanged := range changed {
@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
pol, pol,
node, node,
mapRequest.Version, mapRequest.Version,
users,
peers, peers,
changedNodes, changedNodes,
m.cfg, m.cfg,
@ -508,16 +520,17 @@ func appendPeerChanges(
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
users []types.User,
peers types.Nodes, peers types.Nodes,
changed types.Nodes, changed types.Nodes,
cfg *types.Config, cfg *types.Config,
) error { ) error {
packetFilter, err := pol.CompileFilterRules(append(peers, node)) packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
if err != nil { if err != nil {
return err return err
} }
sshPolicy, err := pol.CompileSSHPolicy(node, peers) sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
if err != nil { if err != nil {
return err return err
} }

View file

@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
mini := &types.Node{ mini := &types.Node{
ID: 0, ID: 0,
MachineKey: mustMK( MachineKey: mustMK(
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
Hostname: "mini", Hostname: "mini",
GivenName: "mini", GivenName: "mini",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
AuthKey: &types.PreAuthKey{}, AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
Hostname: "peer1", Hostname: "peer1",
GivenName: "peer1", GivenName: "peer1",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
Hostname: "peer2", Hostname: "peer2",
GivenName: "peer2", GivenName: "peer2",
UserID: 1, UserID: user2.ID,
User: types.User{Name: "peer2"}, User: user2,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
got, err := mappy.fullMapResponse( got, err := mappy.fullMapResponse(
tt.node, tt.node,
tt.peers, tt.peers,
[]types.User{user1, user2},
tt.pol, tt.pol,
0, 0,
) )

View file

@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy, policy *ACLPolicy,
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
users []types.User,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all // If there is no policy defined, we default to allow all
if policy == nil { if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.CompileFilterRules(append(peers, node)) rules, err := policy.CompileFilterRules(users, append(peers, node))
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
sshPolicy, err := policy.CompileSSHPolicy(node, peers) sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) CompileFilterRules( func (pol *ACLPolicy) CompileFilterRules(
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil { if pol == nil {
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
var srcIPs []string var srcIPs []string
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, nodes) srcs, err := pol.expandSource(src, users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) return nil, fmt.Errorf(
"parsing policy, acl index: %d->%d: %w",
index,
srcIndex,
err,
)
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
expanded, err := pol.ExpandAlias( expanded, err := pol.ExpandAlias(
nodes, nodes,
users,
alias, alias,
) )
if err != nil { if err != nil {
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
func (pol *ACLPolicy) CompileSSHPolicy( func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node, node *types.Node,
users []types.User,
peers types.Nodes, peers types.Nodes,
) (*tailcfg.SSHPolicy, error) { ) (*tailcfg.SSHPolicy, error) {
if pol == nil { if pol == nil {
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
for index, sshACL := range pol.SSHs { for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations { for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, node), src) expanded, err := pol.ExpandAlias(append(peers, node), users, src)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
case "check": case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod) checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) return nil, fmt.Errorf(
"parsing SSH policy, parsing check duration, index: %d: %w",
index,
err,
)
} else { } else {
action = *checkAction action = *checkAction
} }
default: default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
sshACL.Action,
index,
err,
)
} }
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
} else { } else {
expandedSrcs, err := pol.ExpandAlias( expandedSrcs, err := pol.ExpandAlias(
peers, peers,
users,
rawSrc, rawSrc,
) )
if err != nil { if err != nil {
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias. // with the given src alias.
func (pol *ACLPolicy) expandSource( func (pol *ACLPolicy) expandSource(
src string, src string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.ExpandAlias(nodes, src) ipSet, err := pol.ExpandAlias(nodes, users, src)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
// and transform these in IPAddresses. // and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias( func (pol *ACLPolicy) ExpandAlias(
nodes types.Nodes, nodes types.Nodes,
users []types.User,
alias string, alias string,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.expandIPsFromGroup(alias, nodes) return pol.expandIPsFromGroup(alias, users, nodes)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.expandIPsFromTag(alias, nodes) return pol.expandIPsFromTag(alias, users, nodes)
} }
if isAutoGroup(alias) { if isAutoGroup(alias) {
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
} }
// if alias is a user // if alias is a user
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
return ips, err return ips, err
} }
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(nodes, h.String()) return pol.ExpandAlias(nodes, users, h.String())
} }
// if alias is an IP // if alias is an IP
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromGroup(
group string, group string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
users, err := pol.expandUsersFromGroup(group) userTokens, err := pol.expandUsersFromGroup(group)
if err != nil { if err != nil {
return &netipx.IPSet{}, err return &netipx.IPSet{}, err
} }
for _, user := range users { for _, user := range userTokens {
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
for _, node := range filteredNodes { for _, node := range filteredNodes {
node.AppendToIPSet(&build) node.AppendToIPSet(&build)
} }
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromTag(
alias string, alias string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
// filter out nodes per tag owner // filter out nodes per tag owner
for _, user := range owners { for _, user := range owners {
nodes := filterNodesByUser(nodes, user) nodes := filterNodesByUser(nodes, users, user)
for _, node := range nodes { for _, node := range nodes {
if node.Hostinfo == nil { if node.Hostinfo == nil {
continue continue
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
func (pol *ACLPolicy) expandIPsFromUser( func (pol *ACLPolicy) expandIPsFromUser(
user string, user string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
// shortcurcuit if we have no nodes to get ips from. // shortcurcuit if we have no nodes to get ips from.
@ -953,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode(
return validTags, invalidTags return validTags, invalidTags
} }
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { // filterNodesByUser returns a list of nodes that match the given userToken from a
// policy.
// Matching nodes are determined by first matching the user token to a user by checking:
// - If it is an ID that mactches the user database ID
// - It is the Provider Identifier from OIDC
// - It matches the username or email of a user
//
// If the token matches more than one user, zero nodes will returned.
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
var out types.Nodes var out types.Nodes
var potentialUsers []types.User
for _, user := range users {
if user.ProviderIdentifier == userToken {
// If a user is matching with a known unique field,
// disgard all other users and only keep the current
// user.
potentialUsers = []types.User{user}
break
}
if user.Email == userToken {
potentialUsers = append(potentialUsers, user)
}
if user.Name == userToken {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) != 1 {
return nil
}
user := potentialUsers[0]
for _, node := range nodes { for _, node := range nodes {
if node.User.Username() == user { if node.User.ID == user.ID {
out = append(out, node) out = append(out, node)
} }
} }

View file

@ -2,8 +2,10 @@ package policy
import ( import (
"errors" "errors"
"math/rand/v2"
"net/netip" "net/netip"
"slices" "slices"
"sort"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -14,6 +16,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -375,15 +378,21 @@ func TestParsing(t *testing.T) {
return return
} }
rules, err := pol.CompileFilterRules(types.Nodes{ user := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser",
}
rules, err := pol.CompileFilterRules(
[]types.User{
user,
},
types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.100.100.100"), IPv4: iap("100.100.100.100"),
}, },
&types.Node{ &types.Node{
IPv4: iap("200.200.200.200"), IPv4: iap("200.200.200.200"),
User: types.User{ User: user,
Name: "testuser",
},
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}) })
@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.CompileFilterRules(types.Nodes{}) rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }
@ -860,7 +869,20 @@ func Test_expandPorts(t *testing.T) {
} }
} }
func Test_listNodesInUser(t *testing.T) { func Test_filterNodesByUser(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "marc"},
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
{Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 7}, Name: "1"},
{Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"},
{Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"},
{Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"},
}
type args struct { type args struct {
nodes types.Nodes nodes types.Nodes
user string user string
@ -874,50 +896,258 @@ func Test_listNodesInUser(t *testing.T) {
name: "1 node in user", name: "1 node in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
user: "joe", user: "joe",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
}, },
{ {
name: "3 nodes, 2 in user", name: "3 nodes, 2 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
user: "marc", user: "marc",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
}, },
{ {
name: "5 nodes, 0 in user", name: "5 nodes, 0 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
&types.Node{ID: 4, User: types.User{Name: "marc"}}, &types.Node{ID: 4, User: users[0]},
&types.Node{ID: 5, User: types.User{Name: "marc"}}, &types.Node{ID: 5, User: users[0]},
}, },
user: "mickael", user: "mickael",
}, },
want: nil, want: nil,
}, },
{
name: "match-by-provider-ident",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[2]},
},
},
{
name: "match-by-email",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[1]},
},
},
{
name: "multi-match-is-zero",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 3, User: users[3]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-email-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match email, then provider id
&types.Node{ID: 3, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-username-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match username, then provider id
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-duplicate-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-unique-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[0]},
},
},
{
name: "all-users-no-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[1]},
},
},
{
name: "email-as-username-duplicate",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[7]},
&types.Node{ID: 2, User: users[8]},
},
user: "alex@headscale.net",
},
want: nil,
},
{
name: "all-users-no-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working@headscale.net",
},
want: nil,
},
{
name: "all-users-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 3, User: users[2]},
},
},
{
name: "all-users-no-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/4321",
},
want: nil,
},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user) for range 1000 {
ns := test.args.nodes
rand.Shuffle(len(ns), func(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
})
us := users
rand.Shuffle(len(us), func(i, j int) {
us[i], us[j] = us[j], us[i]
})
got := filterNodesByUser(ns, us, test.args.user)
sort.Slice(got, func(i, j int) bool {
return got[i].ID < got[j].ID
})
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
}
} }
}) })
} }
@ -940,6 +1170,12 @@ func Test_expandAlias(t *testing.T) {
return s return s
} }
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "joe"},
{Model: gorm.Model{ID: 2}, Name: "marc"},
{Model: gorm.Model{ID: 3}, Name: "mickael"},
}
type field struct { type field struct {
pol ACLPolicy pol ACLPolicy
} }
@ -989,19 +1225,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1022,19 +1258,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1185,7 +1421,7 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1194,7 +1430,7 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1203,11 +1439,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
}, },
}, },
}, },
@ -1260,21 +1496,21 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1295,12 +1531,12 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1309,11 +1545,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1350,12 +1586,12 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}, },
@ -1368,6 +1604,7 @@ func Test_expandAlias(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias( got, err := test.field.pol.ExpandAlias(
test.args.nodes, test.args.nodes,
users,
test.args.alias, test.args.alias,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
@ -1715,6 +1952,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.CompileFilterRules( got, err := tt.field.pol.CompileFilterRules(
[]types.User{},
tt.args.nodes, tt.args.nodes,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -1834,6 +2072,13 @@ func TestTheInternet(t *testing.T) {
} }
func TestReduceFilterRules(t *testing.T) { func TestReduceFilterRules(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "mickael"},
{Model: gorm.Model{ID: 2}, Name: "user1"},
{Model: gorm.Model{ID: 3}, Name: "user2"},
{Model: gorm.Model{ID: 4}, Name: "user100"},
}
tests := []struct { tests := []struct {
name string name string
node *types.Node node *types.Node
@ -1855,13 +2100,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
}, },
want: []tailcfg.FilterRule{}, want: []tailcfg.FilterRule{},
@ -1888,7 +2133,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{ RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"), netip.MustParsePrefix("10.33.0.0/16"),
@ -1899,7 +2144,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1967,19 +2212,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
// "internal" exit node // "internal" exit node
&types.Node{ &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2026,12 +2271,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2113,7 +2358,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2122,12 +2367,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2215,7 +2460,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
}, },
@ -2224,12 +2469,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2292,7 +2537,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
}, },
@ -2301,12 +2546,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2362,7 +2607,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
}, },
@ -2372,7 +2617,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2400,6 +2645,7 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules( got, _ := tt.pol.CompileFilterRules(
users,
append(tt.peers, tt.node), append(tt.peers, tt.node),
) )
@ -3391,7 +3637,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
assert.NoError(t, err) assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3474,14 +3720,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
RequestTags: []string{"tag:test"}, RequestTags: []string{"tag:test"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
Hostname: "testnodes", Hostname: "testnodes",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 0, UserID: 0,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3498,7 +3747,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3532,6 +3781,7 @@ func TestInvalidTagValidUser(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1", Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -3549,7 +3799,7 @@ func TestInvalidTagValidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3583,6 +3833,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1", Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -3608,7 +3859,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3637,15 +3888,17 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "webserver", Hostname: "webserver",
RequestTags: []string{"tag:webapp"}, RequestTags: []string{"tag:webapp"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 1, ID: 1,
Hostname: "webserver", Hostname: "webserver",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3660,9 +3913,7 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "user", Hostname: "user",
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
@ -3678,7 +3929,7 @@ func TestValidTagInvalidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{

View file

@ -2,6 +2,8 @@ package types
import ( import (
"cmp" "cmp"
"database/sql"
"net/mail"
"strconv" "strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -22,19 +24,19 @@ type User struct {
// Username for the user, is used if email is empty // Username for the user, is used if email is empty
// Should not be used, please use Username(). // Should not be used, please use Username().
Name string `gorm:"unique"` Name sql.NullString `gorm:"unique"`
// Typically the full name of the user // Typically the full name of the user
DisplayName string DisplayName string
// Email of the user // Email of the user
// Should not be used, please use Username(). // Should not be used, please use Username().
Email string Email sql.NullString
// Unique identifier of the user from OIDC, // Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token // comes from `sub` claim in the OIDC token
// and is used to lookup the user. // and is used to lookup the user.
ProviderIdentifier string `gorm:"index"` ProviderIdentifier sql.NullString `gorm:"index"`
// Provider is the origin of the user account, // Provider is the origin of the user account,
// same as RegistrationMethod, without authkey. // same as RegistrationMethod, without authkey.
@ -51,7 +53,7 @@ type User struct {
// should be used throughout headscale, in information returned to the // should be used throughout headscale, in information returned to the
// user and the Policy engine. // user and the Policy engine.
func (u *User) Username() string { func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) return cmp.Or(u.Email.String, u.Name.String, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
} }
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise // DisplayNameOrUsername returns the DisplayName if it exists, otherwise
@ -103,11 +105,11 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
func (u *User) Proto() *v1.User { func (u *User) Proto() *v1.User {
return &v1.User{ return &v1.User{
Id: strconv.FormatUint(uint64(u.ID), util.Base10), Id: strconv.FormatUint(uint64(u.ID), util.Base10),
Name: u.Name, Name: u.Name.String,
CreatedAt: timestamppb.New(u.CreatedAt), CreatedAt: timestamppb.New(u.CreatedAt),
DisplayName: u.DisplayName, DisplayName: u.DisplayName,
Email: u.Email, Email: u.Email.String,
ProviderId: u.ProviderIdentifier, ProviderId: u.ProviderIdentifier.String,
Provider: u.Provider, Provider: u.Provider,
ProfilePicUrl: u.ProfilePicURL, ProfilePicUrl: u.ProfilePicURL,
} }
@ -129,10 +131,20 @@ type OIDCClaims struct {
// FromClaim overrides a User from OIDC claims. // FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID. // All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) { func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Sub err := util.CheckForFQDNRules(claims.Username)
if err == nil {
u.Name.String = claims.Username
}
if claims.EmailVerified {
_, err = mail.ParseAddress(claims.Email)
if err == nil {
u.Email.String = claims.Email
}
}
u.ProviderIdentifier.String = claims.Sub
u.DisplayName = claims.Name u.DisplayName = claims.Name
u.Email = claims.Email
u.Name = claims.Username
u.ProfilePicURL = claims.ProfilePictureURL u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC u.Provider = util.RegisterMethodOIDC
} }