Compare commits

...

4 commits

Author SHA1 Message Date
Kristoffer Dalby
f8ec54d816
only set username and email if valid
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:20:19 +01:00
Kristoffer Dalby
31d398c793
ensure provider id is found out of order
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:20:18 +01:00
Kristoffer Dalby
3dc452dee4
resolve user identifier to stable ID
currently, the policy approach node to user matching
with a quite naive approach looking at the username
provided in the policy and matched it with the username
on the nodes. This worked ok as long as usernames were
unique and did not change.

As usernames are no longer guarenteed to be unique in
an OIDC environment we cant rely on this.

This changes the mechanism that matches the user string
(now user token) with nodes:

- first find all potential users by looking up:
  - database ID
  - provider ID (OIDC)
  - username/email

If more than one user is matching, then the query is
rejected, and zero matching nodes are returned.

When a single user is found, the node is matched against
the User database ID, which are also present on the actual
node.

This means that from this commit, users can use the following
to identify users in the policy:
- provider identity (iss + sub)
- username
- email
- database id

There are more changes coming to this, so it is not recommended
to start using any of these new abilities, with the exception
of email, which will not change since it includes an @.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:20:12 +01:00
enoperm
5fbf3f8327
Websocket derp test fixes (#2247)
* integration testing: add and validate build-time options for tailscale head

* fixup! integration testing: add and validate build-time options for tailscale head

integration testing: comply with linter

* fixup! fixup! integration testing: add and validate build-time options for tailscale head

integration testing: tsic.New must never return nil

* fixup! fixup! fixup! integration testing: add and validate build-time options for tailscale head

* minor fixes
2024-11-22 11:57:01 +01:00
12 changed files with 558 additions and 145 deletions

View file

@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH=""
ENV VERSION_GIT_HASH=$VERSION_GIT_HASH ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
ARG TARGETARCH ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -ldflags="\ ARG BUILD_TAGS=""
RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
-X tailscale.com/version.longStamp=$VERSION_LONG \ -X tailscale.com/version.longStamp=$VERSION_LONG \
-X tailscale.com/version.shortStamp=$VERSION_SHORT \ -X tailscale.com/version.shortStamp=$VERSION_SHORT \
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \

View file

@ -1027,14 +1027,18 @@ func (h *Headscale) loadACLPolicy() error {
if err != nil { if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err) return fmt.Errorf("loading nodes from database to validate policy: %w", err)
} }
users, err := h.db.ListUsers()
if err != nil {
return fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes) _, err = pol.CompileFilterRules(users, nodes)
if err != nil { if err != nil {
return fmt.Errorf("verifying policy rules: %w", err) return fmt.Errorf("verifying policy rules: %w", err)
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil { if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err) return fmt.Errorf("verifying SSH rules: %w", err)
} }

View file

@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9) c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers) adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers) testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)

View file

@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
if approvedAlias == node.User.Username() { if approvedAlias == node.User.Username() {
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} else { } else {
users, err := ListUsers(tx)
if err != nil {
return fmt.Errorf("looking up users to expand route alias: %w", err)
}
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
if err != nil { if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
} }

View file

@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil { if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err) return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
} }
users, err := api.h.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes) _, err = pol.CompileFilterRules(users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying policy rules: %w", err) return nil, fmt.Errorf("verifying policy rules: %w", err)
} }
if len(nodes) > 0 { if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes) _, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err) return nil, fmt.Errorf("verifying SSH rules: %w", err)
} }

View file

@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse( func (m *Mapper) fullMapResponse(
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
users []types.User,
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
pol, pol,
node, node,
capVer, capVer,
users,
peers, peers,
peers, peers,
m.cfg, m.cfg,
@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
if err != nil { if err != nil {
return nil, err return nil, err
} }
users, err := m.db.ListUsers()
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version) resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
return nil, err return nil, err
} }
users, err := m.db.ListUsers()
if err != nil {
return nil, fmt.Errorf("listing users for map response: %w", err)
}
var removedIDs []tailcfg.NodeID var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed { for nodeID, nodeChanged := range changed {
@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
pol, pol,
node, node,
mapRequest.Version, mapRequest.Version,
users,
peers, peers,
changedNodes, changedNodes,
m.cfg, m.cfg,
@ -508,16 +520,17 @@ func appendPeerChanges(
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
node *types.Node, node *types.Node,
capVer tailcfg.CapabilityVersion, capVer tailcfg.CapabilityVersion,
users []types.User,
peers types.Nodes, peers types.Nodes,
changed types.Nodes, changed types.Nodes,
cfg *types.Config, cfg *types.Config,
) error { ) error {
packetFilter, err := pol.CompileFilterRules(append(peers, node)) packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
if err != nil { if err != nil {
return err return err
} }
sshPolicy, err := pol.CompileSSHPolicy(node, peers) sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
if err != nil { if err != nil {
return err return err
} }

View file

@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
mini := &types.Node{ mini := &types.Node{
ID: 0, ID: 0,
MachineKey: mustMK( MachineKey: mustMK(
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
Hostname: "mini", Hostname: "mini",
GivenName: "mini", GivenName: "mini",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
AuthKey: &types.PreAuthKey{}, AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
Hostname: "peer1", Hostname: "peer1",
GivenName: "peer1", GivenName: "peer1",
UserID: 0, UserID: user1.ID,
User: types.User{Name: "mini"}, User: user1,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
Hostname: "peer2", Hostname: "peer2",
GivenName: "peer2", GivenName: "peer2",
UserID: 1, UserID: user2.ID,
User: types.User{Name: "peer2"}, User: user2,
ForcedTags: []string{}, ForcedTags: []string{},
LastSeen: &lastSeen, LastSeen: &lastSeen,
Expiry: &expire, Expiry: &expire,
@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
got, err := mappy.fullMapResponse( got, err := mappy.fullMapResponse(
tt.node, tt.node,
tt.peers, tt.peers,
[]types.User{user1, user2},
tt.pol, tt.pol,
0, 0,
) )

View file

@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy, policy *ACLPolicy,
node *types.Node, node *types.Node,
peers types.Nodes, peers types.Nodes,
users []types.User,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all // If there is no policy defined, we default to allow all
if policy == nil { if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.CompileFilterRules(append(peers, node)) rules, err := policy.CompileFilterRules(users, append(peers, node))
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
sshPolicy, err := policy.CompileSSHPolicy(node, peers) sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a // CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients. // set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) CompileFilterRules( func (pol *ACLPolicy) CompileFilterRules(
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
if pol == nil { if pol == nil {
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
var srcIPs []string var srcIPs []string
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, nodes) srcs, err := pol.expandSource(src, users, nodes)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err) return nil, fmt.Errorf(
"parsing policy, acl index: %d->%d: %w",
index,
srcIndex,
err,
)
} }
srcIPs = append(srcIPs, srcs...) srcIPs = append(srcIPs, srcs...)
} }
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
expanded, err := pol.ExpandAlias( expanded, err := pol.ExpandAlias(
nodes, nodes,
users,
alias, alias,
) )
if err != nil { if err != nil {
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
func (pol *ACLPolicy) CompileSSHPolicy( func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node, node *types.Node,
users []types.User,
peers types.Nodes, peers types.Nodes,
) (*tailcfg.SSHPolicy, error) { ) (*tailcfg.SSHPolicy, error) {
if pol == nil { if pol == nil {
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
for index, sshACL := range pol.SSHs { for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations { for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, node), src) expanded, err := pol.ExpandAlias(append(peers, node), users, src)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
case "check": case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod) checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err) return nil, fmt.Errorf(
"parsing SSH policy, parsing check duration, index: %d: %w",
index,
err,
)
} else { } else {
action = *checkAction action = *checkAction
} }
default: default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err) return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
sshACL.Action,
index,
err,
)
} }
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
} else { } else {
expandedSrcs, err := pol.ExpandAlias( expandedSrcs, err := pol.ExpandAlias(
peers, peers,
users,
rawSrc, rawSrc,
) )
if err != nil { if err != nil {
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias. // with the given src alias.
func (pol *ACLPolicy) expandSource( func (pol *ACLPolicy) expandSource(
src string, src string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.ExpandAlias(nodes, src) ipSet, err := pol.ExpandAlias(nodes, users, src)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
// and transform these in IPAddresses. // and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias( func (pol *ACLPolicy) ExpandAlias(
nodes types.Nodes, nodes types.Nodes,
users []types.User,
alias string, alias string,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.expandIPsFromGroup(alias, nodes) return pol.expandIPsFromGroup(alias, users, nodes)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.expandIPsFromTag(alias, nodes) return pol.expandIPsFromTag(alias, users, nodes)
} }
if isAutoGroup(alias) { if isAutoGroup(alias) {
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
} }
// if alias is a user // if alias is a user
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
return ips, err return ips, err
} }
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(nodes, h.String()) return pol.ExpandAlias(nodes, users, h.String())
} }
// if alias is an IP // if alias is an IP
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup( func (pol *ACLPolicy) expandIPsFromGroup(
group string, group string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
users, err := pol.expandUsersFromGroup(group) userTokens, err := pol.expandUsersFromGroup(group)
if err != nil { if err != nil {
return &netipx.IPSet{}, err return &netipx.IPSet{}, err
} }
for _, user := range users { for _, user := range userTokens {
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
for _, node := range filteredNodes { for _, node := range filteredNodes {
node.AppendToIPSet(&build) node.AppendToIPSet(&build)
} }
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromTag( func (pol *ACLPolicy) expandIPsFromTag(
alias string, alias string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
// filter out nodes per tag owner // filter out nodes per tag owner
for _, user := range owners { for _, user := range owners {
nodes := filterNodesByUser(nodes, user) nodes := filterNodesByUser(nodes, users, user)
for _, node := range nodes { for _, node := range nodes {
if node.Hostinfo == nil { if node.Hostinfo == nil {
continue continue
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
func (pol *ACLPolicy) expandIPsFromUser( func (pol *ACLPolicy) expandIPsFromUser(
user string, user string,
users []types.User,
nodes types.Nodes, nodes types.Nodes,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder var build netipx.IPSetBuilder
filteredNodes := filterNodesByUser(nodes, user) filteredNodes := filterNodesByUser(nodes, users, user)
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
// shortcurcuit if we have no nodes to get ips from. // shortcurcuit if we have no nodes to get ips from.
@ -953,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode(
return validTags, invalidTags return validTags, invalidTags
} }
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { // filterNodesByUser returns a list of nodes that match the given userToken from a
// policy.
// Matching nodes are determined by first matching the user token to a user by checking:
// - If it is an ID that mactches the user database ID
// - It is the Provider Identifier from OIDC
// - It matches the username or email of a user
//
// If the token matches more than one user, zero nodes will returned.
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
var out types.Nodes var out types.Nodes
var potentialUsers []types.User
for _, user := range users {
if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken {
// If a user is matching with a known unique field,
// disgard all other users and only keep the current
// user.
potentialUsers = []types.User{user}
break
}
if user.Email == userToken {
potentialUsers = append(potentialUsers, user)
}
if user.Name == userToken {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) != 1 {
return nil
}
user := potentialUsers[0]
for _, node := range nodes { for _, node := range nodes {
if node.User.Username() == user { if node.User.ID == user.ID {
out = append(out, node) out = append(out, node)
} }
} }

View file

@ -1,9 +1,12 @@
package policy package policy
import ( import (
"database/sql"
"errors" "errors"
"math/rand/v2"
"net/netip" "net/netip"
"slices" "slices"
"sort"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -14,6 +17,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@ -375,15 +379,21 @@ func TestParsing(t *testing.T) {
return return
} }
rules, err := pol.CompileFilterRules(types.Nodes{ user := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser",
}
rules, err := pol.CompileFilterRules(
[]types.User{
user,
},
types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.100.100.100"), IPv4: iap("100.100.100.100"),
}, },
&types.Node{ &types.Node{
IPv4: iap("200.200.200.200"), IPv4: iap("200.200.200.200"),
User: types.User{ User: user,
Name: "testuser",
},
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}) })
@ -533,7 +543,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.CompileFilterRules(types.Nodes{}) rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -549,7 +559,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -568,7 +578,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -584,7 +594,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}) _, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }
@ -860,7 +870,20 @@ func Test_expandPorts(t *testing.T) {
} }
} }
func Test_listNodesInUser(t *testing.T) { func Test_filterNodesByUser(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "marc"},
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: sql.NullString{String: "http://oidc.org/1234", Valid: true}},
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
{Model: gorm.Model{ID: 6}, Name: "http://oidc.org/1234", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 7}, Name: "1"},
{Model: gorm.Model{ID: 8}, Name: "alex", Email: "alex@headscale.net"},
{Model: gorm.Model{ID: 9}, Name: "alex@headscale.net"},
{Model: gorm.Model{ID: 10}, Email: "http://oidc.org/1234"},
}
type args struct { type args struct {
nodes types.Nodes nodes types.Nodes
user string user string
@ -874,50 +897,258 @@ func Test_listNodesInUser(t *testing.T) {
name: "1 node in user", name: "1 node in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
user: "joe", user: "joe",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{User: types.User{Name: "joe"}}, &types.Node{User: users[1]},
}, },
}, },
{ {
name: "3 nodes, 2 in user", name: "3 nodes, 2 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
user: "marc", user: "marc",
}, },
want: types.Nodes{ want: types.Nodes{
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
}, },
}, },
{ {
name: "5 nodes, 0 in user", name: "5 nodes, 0 in user",
args: args{ args: args{
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}}, &types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: types.User{Name: "marc"}}, &types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: types.User{Name: "marc"}}, &types.Node{ID: 3, User: users[0]},
&types.Node{ID: 4, User: types.User{Name: "marc"}}, &types.Node{ID: 4, User: users[0]},
&types.Node{ID: 5, User: types.User{Name: "marc"}}, &types.Node{ID: 5, User: users[0]},
}, },
user: "mickael", user: "mickael",
}, },
want: nil, want: nil,
}, },
{
name: "match-by-provider-ident",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[2]},
},
},
{
name: "match-by-email",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[1]},
},
},
{
name: "multi-match-is-zero",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 3, User: users[3]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-email-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match email, then provider id
&types.Node{ID: 3, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-username-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match username, then provider id
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-duplicate-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-unique-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[0]},
},
},
{
name: "all-users-no-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 8, User: users[7]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[1]},
},
},
{
name: "email-as-username-duplicate",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[7]},
&types.Node{ID: 2, User: users[8]},
},
user: "alex@headscale.net",
},
want: nil,
},
{
name: "all-users-no-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working@headscale.net",
},
want: nil,
},
{
name: "all-users-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 3, User: users[2]},
},
},
{
name: "all-users-no-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
&types.Node{ID: 6, User: users[5]},
},
user: "http://oidc.org/4321",
},
want: nil,
},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user) for range 1000 {
ns := test.args.nodes
rand.Shuffle(len(ns), func(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
})
us := users
rand.Shuffle(len(us), func(i, j int) {
us[i], us[j] = us[j], us[i]
})
got := filterNodesByUser(ns, us, test.args.user)
sort.Slice(got, func(i, j int) bool {
return got[i].ID < got[j].ID
})
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
}
} }
}) })
} }
@ -940,6 +1171,12 @@ func Test_expandAlias(t *testing.T) {
return s return s
} }
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "joe"},
{Model: gorm.Model{ID: 2}, Name: "marc"},
{Model: gorm.Model{ID: 3}, Name: "mickael"},
}
type field struct { type field struct {
pol ACLPolicy pol ACLPolicy
} }
@ -989,19 +1226,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1022,19 +1259,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1185,7 +1422,7 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1194,7 +1431,7 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1203,11 +1440,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
}, },
}, },
}, },
@ -1260,21 +1497,21 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1295,12 +1532,12 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{ nodes: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"}, User: users[0],
ForcedTags: []string{"tag:hr-webserver"}, ForcedTags: []string{"tag:hr-webserver"},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
OS: "centos", OS: "centos",
Hostname: "foo", Hostname: "foo",
@ -1309,11 +1546,11 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"}, User: users[2],
}, },
}, },
}, },
@ -1350,12 +1587,12 @@ func Test_expandAlias(t *testing.T) {
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.3"), IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.4"), IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"}, User: users[0],
Hostinfo: &tailcfg.Hostinfo{}, Hostinfo: &tailcfg.Hostinfo{},
}, },
}, },
@ -1368,6 +1605,7 @@ func Test_expandAlias(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias( got, err := test.field.pol.ExpandAlias(
test.args.nodes, test.args.nodes,
users,
test.args.alias, test.args.alias,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
@ -1715,6 +1953,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.CompileFilterRules( got, err := tt.field.pol.CompileFilterRules(
[]types.User{},
tt.args.nodes, tt.args.nodes,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -1834,6 +2073,13 @@ func TestTheInternet(t *testing.T) {
} }
func TestReduceFilterRules(t *testing.T) { func TestReduceFilterRules(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "mickael"},
{Model: gorm.Model{ID: 2}, Name: "user1"},
{Model: gorm.Model{ID: 3}, Name: "user2"},
{Model: gorm.Model{ID: 4}, Name: "user100"},
}
tests := []struct { tests := []struct {
name string name string
node *types.Node node *types.Node
@ -1855,13 +2101,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"}, User: users[0],
}, },
}, },
want: []tailcfg.FilterRule{}, want: []tailcfg.FilterRule{},
@ -1888,7 +2134,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{ RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"), netip.MustParsePrefix("10.33.0.0/16"),
@ -1899,7 +2145,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -1967,19 +2213,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
peers: types.Nodes{ peers: types.Nodes{
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
// "internal" exit node // "internal" exit node
&types.Node{ &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2026,12 +2272,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2113,7 +2359,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(), RoutableIPs: tsaddr.ExitRoutes(),
}, },
@ -2122,12 +2368,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2215,7 +2461,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
}, },
@ -2224,12 +2470,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2292,7 +2538,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
}, },
@ -2301,12 +2547,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"), IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"}, User: users[2],
}, },
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2362,7 +2608,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"}, User: users[3],
Hostinfo: &tailcfg.Hostinfo{ Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
}, },
@ -2372,7 +2618,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{ &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"}, User: users[1],
}, },
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
@ -2400,6 +2646,7 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules( got, _ := tt.pol.CompileFilterRules(
users,
append(tt.peers, tt.node), append(tt.peers, tt.node),
) )
@ -3391,7 +3638,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers) got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
assert.NoError(t, err) assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3474,14 +3721,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
RequestTags: []string{"tag:test"}, RequestTags: []string{"tag:test"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 0, ID: 0,
Hostname: "testnodes", Hostname: "testnodes",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 0, UserID: 0,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3498,7 +3748,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3532,6 +3782,7 @@ func TestInvalidTagValidUser(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1", Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -3549,7 +3800,7 @@ func TestInvalidTagValidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3583,6 +3834,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: types.User{
Model: gorm.Model{ID: 1},
Name: "user1", Name: "user1",
}, },
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
@ -3608,7 +3860,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -3637,15 +3889,17 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "webserver", Hostname: "webserver",
RequestTags: []string{"tag:webapp"}, RequestTags: []string{"tag:webapp"},
} }
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{ node := &types.Node{
ID: 1, ID: 1,
Hostname: "webserver", Hostname: "webserver",
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo, Hostinfo: &hostInfo,
} }
@ -3660,9 +3914,7 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "user", Hostname: "user",
IPv4: iap("100.64.0.2"), IPv4: iap("100.64.0.2"),
UserID: 1, UserID: 1,
User: types.User{ User: user,
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey, RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo2, Hostinfo: &hostInfo2,
} }
@ -3678,7 +3930,7 @@ func TestValidTagInvalidUser(t *testing.T) {
}, },
} }
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}) got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{

View file

@ -2,6 +2,8 @@ package types
import ( import (
"cmp" "cmp"
"database/sql"
"net/mail"
"strconv" "strconv"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -34,7 +36,7 @@ type User struct {
// Unique identifier of the user from OIDC, // Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token // comes from `sub` claim in the OIDC token
// and is used to lookup the user. // and is used to lookup the user.
ProviderIdentifier string `gorm:"index"` ProviderIdentifier sql.NullString `gorm:"index"`
// Provider is the origin of the user account, // Provider is the origin of the user account,
// same as RegistrationMethod, without authkey. // same as RegistrationMethod, without authkey.
@ -51,7 +53,7 @@ type User struct {
// should be used throughout headscale, in information returned to the // should be used throughout headscale, in information returned to the
// user and the Policy engine. // user and the Policy engine.
func (u *User) Username() string { func (u *User) Username() string {
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) return cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
} }
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise // DisplayNameOrUsername returns the DisplayName if it exists, otherwise
@ -107,7 +109,7 @@ func (u *User) Proto() *v1.User {
CreatedAt: timestamppb.New(u.CreatedAt), CreatedAt: timestamppb.New(u.CreatedAt),
DisplayName: u.DisplayName, DisplayName: u.DisplayName,
Email: u.Email, Email: u.Email,
ProviderId: u.ProviderIdentifier, ProviderId: u.ProviderIdentifier.String,
Provider: u.Provider, Provider: u.Provider,
ProfilePicUrl: u.ProfilePicURL, ProfilePicUrl: u.ProfilePicURL,
} }
@ -129,10 +131,20 @@ type OIDCClaims struct {
// FromClaim overrides a User from OIDC claims. // FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID. // All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) { func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Sub err := util.CheckForFQDNRules(claims.Username)
u.DisplayName = claims.Name if err == nil {
u.Email = claims.Email
u.Name = claims.Username u.Name = claims.Username
}
if claims.EmailVerified {
_, err = mail.ParseAddress(claims.Email)
if err == nil {
u.Email = claims.Email
}
}
u.ProviderIdentifier.String = claims.Sub
u.DisplayName = claims.Name
u.ProfilePicURL = claims.ProfilePictureURL u.ProfilePicURL = claims.ProfilePictureURL
u.Provider = util.RegisterMethodOIDC u.Provider = util.RegisterMethodOIDC
} }

View file

@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
spec := map[string]ClientsSpec{ spec := map[string]ClientsSpec{
"user1": { "user1": {
Plain: 0, Plain: 0,
WebsocketDERP: len(MustTestVersions), WebsocketDERP: 2,
}, },
} }
@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
if clientCount.WebsocketDERP > 0 { if clientCount.WebsocketDERP > 0 {
// Containers that use DERP-over-WebSocket // Containers that use DERP-over-WebSocket
// Note that these clients *must* be built
// from source, which is currently
// only done for HEAD.
err = s.CreateTailscaleIsolatedNodesInUser( err = s.CreateTailscaleIsolatedNodesInUser(
hash, hash,
userName, userName,
"all", tsic.VersionHead,
clientCount.WebsocketDERP, clientCount.WebsocketDERP,
tsic.WithWebsocketDERP(true), tsic.WithWebsocketDERP(true),
) )

View file

@ -12,6 +12,7 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -44,6 +45,11 @@ var (
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey") errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
errTailscaleNotConnected = errors.New("tailscale not connected") errTailscaleNotConnected = errors.New("tailscale not connected")
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login") errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
errInvalidClientConfig = errors.New("verifiably invalid client config requested")
)
const (
VersionHead = "head"
) )
func errTailscaleStatus(hostname string, err error) error { func errTailscaleStatus(hostname string, err error) error {
@ -74,6 +80,13 @@ type TailscaleInContainer struct {
withExtraHosts []string withExtraHosts []string
workdir string workdir string
netfilter string netfilter string
// build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig
}
type TailscaleInContainerBuildConfig struct {
tags []string
} }
// Option represent optional settings that can be given to a // Option represent optional settings that can be given to a
@ -175,6 +188,22 @@ func WithNetfilter(state string) Option {
} }
} }
// WithBuildTag adds an additional value to the `-tags=` parameter
// of the Go compiler, allowing callers to customize the Tailscale client build.
// This option is only meaningful when invoked on **HEAD** versions of the client.
// Attempts to use it with any other version is a bug in the calling code.
func WithBuildTag(tag string) Option {
return func(tsic *TailscaleInContainer) {
if tsic.version != VersionHead {
panic(errInvalidClientConfig)
}
tsic.buildConfig.tags = append(
tsic.buildConfig.tags, tag,
)
}
}
// New returns a new TailscaleInContainer instance. // New returns a new TailscaleInContainer instance.
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
@ -219,6 +248,12 @@ func New(
} }
if tsic.withWebsocketDERP { if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
}
WithBuildTag("ts_debug_websockets")(tsic)
tailscaleOptions.Env = append( tailscaleOptions.Env = append(
tailscaleOptions.Env, tailscaleOptions.Env,
fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP), fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
@ -245,14 +280,36 @@ func New(
} }
var container *dockertest.Resource var container *dockertest.Resource
if version != VersionHead {
// build options are not meaningful with pre-existing images,
// let's not lead anyone astray by pretending otherwise.
defaultBuildConfig := TailscaleInContainerBuildConfig{}
hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
if hasBuildConfig {
return tsic, errInvalidClientConfig
}
}
switch version { switch version {
case "head": case VersionHead:
buildOptions := &dockertest.BuildOptions{ buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD", Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: dockerContextPath, ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{}, BuildArgs: []docker.BuildArg{},
} }
buildTags := strings.Join(tsic.buildConfig.tags, ",")
if len(buildTags) > 0 {
buildOptions.BuildArgs = append(
buildOptions.BuildArgs,
docker.BuildArg{
Name: "BUILD_TAGS",
Value: buildTags,
},
)
}
container, err = pool.BuildAndRunWithBuildOptions( container, err = pool.BuildAndRunWithBuildOptions(
buildOptions, buildOptions,
tailscaleOptions, tailscaleOptions,