mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
Merge 45f5b15a9f
into e2d5ee0927
This commit is contained in:
commit
2ea807835a
8 changed files with 291 additions and 129 deletions
|
@ -1026,14 +1026,18 @@ func (h *Headscale) loadACLPolicy() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||
}
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||
}
|
||||
|
||||
_, err = pol.CompileFilterRules(nodes)
|
||||
_, err = pol.CompileFilterRules(users, nodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verifying policy rules: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
|
||||
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verifying SSH rules: %w", err)
|
||||
}
|
||||
|
|
|
@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(testPeers), check.Equals, 9)
|
||||
|
||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
|
||||
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||
|
|
|
@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
|
|||
if approvedAlias == node.User.Username() {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
users, err := ListUsers(tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("looking up users to expand route alias: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
|
||||
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
||||
}
|
||||
|
|
|
@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||
}
|
||||
users, err := api.h.db.ListUsers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||
}
|
||||
|
||||
_, err = pol.CompileFilterRules(nodes)
|
||||
_, err = pol.CompileFilterRules(users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verifying policy rules: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) > 0 {
|
||||
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
|
||||
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||
}
|
||||
|
|
|
@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|||
func (m *Mapper) fullMapResponse(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
users []types.User,
|
||||
pol *policy.ACLPolicy,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
|
@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
|
|||
pol,
|
||||
node,
|
||||
capVer,
|
||||
users,
|
||||
peers,
|
||||
peers,
|
||||
m.cfg,
|
||||
|
@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := m.db.ListUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
|
||||
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
users, err := m.db.ListUsers()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listing users for map response: %w", err)
|
||||
}
|
||||
|
||||
var removedIDs []tailcfg.NodeID
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
|
@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
|
|||
pol,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
users,
|
||||
peers,
|
||||
changedNodes,
|
||||
m.cfg,
|
||||
|
@ -508,16 +520,17 @@ func appendPeerChanges(
|
|||
pol *policy.ACLPolicy,
|
||||
node *types.Node,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
users []types.User,
|
||||
peers types.Nodes,
|
||||
changed types.Nodes,
|
||||
cfg *types.Config,
|
||||
) error {
|
||||
packetFilter, err := pol.CompileFilterRules(append(peers, node))
|
||||
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
|
||||
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
|
||||
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
||||
|
||||
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
|
||||
|
||||
mini := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: mustMK(
|
||||
|
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
IPv4: iap("100.64.0.1"),
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
UserID: user1.ID,
|
||||
User: user1,
|
||||
ForcedTags: []string{},
|
||||
AuthKey: &types.PreAuthKey{},
|
||||
LastSeen: &lastSeen,
|
||||
|
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
IPv4: iap("100.64.0.2"),
|
||||
Hostname: "peer1",
|
||||
GivenName: "peer1",
|
||||
UserID: 0,
|
||||
User: types.User{Name: "mini"},
|
||||
UserID: user1.ID,
|
||||
User: user1,
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
|
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
IPv4: iap("100.64.0.3"),
|
||||
Hostname: "peer2",
|
||||
GivenName: "peer2",
|
||||
UserID: 1,
|
||||
User: types.User{Name: "peer2"},
|
||||
UserID: user2.ID,
|
||||
User: user2,
|
||||
ForcedTags: []string{},
|
||||
LastSeen: &lastSeen,
|
||||
Expiry: &expire,
|
||||
|
@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
|
|||
got, err := mappy.fullMapResponse(
|
||||
tt.node,
|
||||
tt.peers,
|
||||
[]types.User{user1, user2},
|
||||
tt.pol,
|
||||
0,
|
||||
)
|
||||
|
|
|
@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
|
|||
policy *ACLPolicy,
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
users []types.User,
|
||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||
// If there is no policy defined, we default to allow all
|
||||
if policy == nil {
|
||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||
}
|
||||
|
||||
rules, err := policy.CompileFilterRules(append(peers, node))
|
||||
rules, err := policy.CompileFilterRules(users, append(peers, node))
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
|
||||
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
||||
|
||||
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
|
||||
sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
|
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
|
|||
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *ACLPolicy) CompileFilterRules(
|
||||
users []types.User,
|
||||
nodes types.Nodes,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
if pol == nil {
|
||||
|
@ -176,7 +178,7 @@ func (pol *ACLPolicy) CompileFilterRules(
|
|||
|
||||
var srcIPs []string
|
||||
for srcIndex, src := range acl.Sources {
|
||||
srcs, err := pol.expandSource(src, nodes)
|
||||
srcs, err := pol.expandSource(src, users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
|
||||
}
|
||||
|
@ -197,6 +199,7 @@ func (pol *ACLPolicy) CompileFilterRules(
|
|||
|
||||
expanded, err := pol.ExpandAlias(
|
||||
nodes,
|
||||
users,
|
||||
alias,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -281,6 +284,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
|||
|
||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
node *types.Node,
|
||||
users []types.User,
|
||||
peers types.Nodes,
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil {
|
||||
|
@ -312,7 +316,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
|||
for index, sshACL := range pol.SSHs {
|
||||
var dest netipx.IPSetBuilder
|
||||
for _, src := range sshACL.Destinations {
|
||||
expanded, err := pol.ExpandAlias(append(peers, node), src)
|
||||
expanded, err := pol.ExpandAlias(append(peers, node), users, src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -363,6 +367,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
|||
} else {
|
||||
expandedSrcs, err := pol.ExpandAlias(
|
||||
peers,
|
||||
users,
|
||||
rawSrc,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -512,9 +517,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
|||
// with the given src alias.
|
||||
func (pol *ACLPolicy) expandSource(
|
||||
src string,
|
||||
users []types.User,
|
||||
nodes types.Nodes,
|
||||
) ([]string, error) {
|
||||
ipSet, err := pol.ExpandAlias(nodes, src)
|
||||
ipSet, err := pol.ExpandAlias(nodes, users, src)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
@ -538,6 +544,7 @@ func (pol *ACLPolicy) expandSource(
|
|||
// and transform these in IPAddresses.
|
||||
func (pol *ACLPolicy) ExpandAlias(
|
||||
nodes types.Nodes,
|
||||
users []types.User,
|
||||
alias string,
|
||||
) (*netipx.IPSet, error) {
|
||||
if isWildcard(alias) {
|
||||
|
@ -552,12 +559,12 @@ func (pol *ACLPolicy) ExpandAlias(
|
|||
|
||||
// if alias is a group
|
||||
if isGroup(alias) {
|
||||
return pol.expandIPsFromGroup(alias, nodes)
|
||||
return pol.expandIPsFromGroup(alias, users, nodes)
|
||||
}
|
||||
|
||||
// if alias is a tag
|
||||
if isTag(alias) {
|
||||
return pol.expandIPsFromTag(alias, nodes)
|
||||
return pol.expandIPsFromTag(alias, users, nodes)
|
||||
}
|
||||
|
||||
if isAutoGroup(alias) {
|
||||
|
@ -565,7 +572,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
|||
}
|
||||
|
||||
// if alias is a user
|
||||
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil {
|
||||
if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
|
@ -574,7 +581,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
|||
if h, ok := pol.Hosts[alias]; ok {
|
||||
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||
|
||||
return pol.ExpandAlias(nodes, h.String())
|
||||
return pol.ExpandAlias(nodes, users, h.String())
|
||||
}
|
||||
|
||||
// if alias is an IP
|
||||
|
@ -751,16 +758,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
|
|||
|
||||
func (pol *ACLPolicy) expandIPsFromGroup(
|
||||
group string,
|
||||
users []types.User,
|
||||
nodes types.Nodes,
|
||||
) (*netipx.IPSet, error) {
|
||||
var build netipx.IPSetBuilder
|
||||
|
||||
users, err := pol.expandUsersFromGroup(group)
|
||||
userTokens, err := pol.expandUsersFromGroup(group)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, err
|
||||
}
|
||||
for _, user := range users {
|
||||
filteredNodes := filterNodesByUser(nodes, user)
|
||||
for _, user := range userTokens {
|
||||
filteredNodes := filterNodesByUser(nodes, users, user)
|
||||
for _, node := range filteredNodes {
|
||||
node.AppendToIPSet(&build)
|
||||
}
|
||||
|
@ -771,6 +779,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
|
|||
|
||||
func (pol *ACLPolicy) expandIPsFromTag(
|
||||
alias string,
|
||||
users []types.User,
|
||||
nodes types.Nodes,
|
||||
) (*netipx.IPSet, error) {
|
||||
var build netipx.IPSetBuilder
|
||||
|
@ -803,7 +812,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
|
||||
// filter out nodes per tag owner
|
||||
for _, user := range owners {
|
||||
nodes := filterNodesByUser(nodes, user)
|
||||
nodes := filterNodesByUser(nodes, users, user)
|
||||
for _, node := range nodes {
|
||||
if node.Hostinfo == nil {
|
||||
continue
|
||||
|
@ -820,11 +829,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
|||
|
||||
func (pol *ACLPolicy) expandIPsFromUser(
|
||||
user string,
|
||||
users []types.User,
|
||||
nodes types.Nodes,
|
||||
) (*netipx.IPSet, error) {
|
||||
var build netipx.IPSetBuilder
|
||||
|
||||
filteredNodes := filterNodesByUser(nodes, user)
|
||||
filteredNodes := filterNodesByUser(nodes, users, user)
|
||||
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
|
||||
|
||||
// shortcurcuit if we have no nodes to get ips from.
|
||||
|
@ -953,10 +963,43 @@ func (pol *ACLPolicy) TagsOfNode(
|
|||
return validTags, invalidTags
|
||||
}
|
||||
|
||||
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
|
||||
// filterNodesByUser returns a list of nodes that match the given userToken from a
|
||||
// policy.
|
||||
// Matching nodes are determined by first matching the user token to a user by checking:
|
||||
// - If it is an ID that mactches the user database ID
|
||||
// - It is the Provider Identifier from OIDC
|
||||
// - It matches the username or email of a user
|
||||
//
|
||||
// If the token matches more than one user, zero nodes will returned.
|
||||
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
|
||||
var out types.Nodes
|
||||
|
||||
var potentialUsers []types.User
|
||||
for _, user := range users {
|
||||
if id, err := strconv.ParseUint(userToken, 10, 64); err == nil && user.ID == uint(id) {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
break
|
||||
}
|
||||
if user.ProviderIdentifier == userToken {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
break
|
||||
}
|
||||
if user.Email == userToken {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
}
|
||||
if user.Name == userToken {
|
||||
potentialUsers = append(potentialUsers, user)
|
||||
}
|
||||
}
|
||||
|
||||
if len(potentialUsers) != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
user := potentialUsers[0]
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.User.Username() == user {
|
||||
if node.User.ID == user.ID {
|
||||
out = append(out, node)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
@ -375,18 +376,24 @@ func TestParsing(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
rules, err := pol.CompileFilterRules(types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.100.100.100"),
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "testuser",
|
||||
}
|
||||
rules, err := pol.CompileFilterRules(
|
||||
[]types.User{
|
||||
user,
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("200.200.200.200"),
|
||||
User: types.User{
|
||||
Name: "testuser",
|
||||
types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.100.100.100"),
|
||||
},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
&types.Node{
|
||||
IPv4: iap("200.200.200.200"),
|
||||
User: user,
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -533,7 +540,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
|||
c.Assert(pol.ACLs, check.HasLen, 6)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, err := pol.CompileFilterRules(types.Nodes{})
|
||||
rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(rules, check.IsNil)
|
||||
}
|
||||
|
@ -549,7 +556,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
|
|||
},
|
||||
},
|
||||
}
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
|
||||
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
||||
}
|
||||
|
||||
|
@ -568,7 +575,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
|
|||
},
|
||||
},
|
||||
}
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
|
||||
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
||||
}
|
||||
|
||||
|
@ -584,7 +591,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
|
|||
},
|
||||
}
|
||||
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
|
||||
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
|
||||
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
||||
}
|
||||
|
||||
|
@ -861,6 +868,13 @@ func Test_expandPorts(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_listNodesInUser(t *testing.T) {
|
||||
users := []types.User{
|
||||
{Model: gorm.Model{ID: 1}, Name: "marc"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
|
||||
{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
|
||||
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
nodes types.Nodes
|
||||
user string
|
||||
|
@ -874,47 +888,101 @@ func Test_listNodesInUser(t *testing.T) {
|
|||
name: "1 node in user",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{User: types.User{Name: "joe"}},
|
||||
&types.Node{User: users[1]},
|
||||
},
|
||||
user: "joe",
|
||||
},
|
||||
want: types.Nodes{
|
||||
&types.Node{User: types.User{Name: "joe"}},
|
||||
&types.Node{User: users[1]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "3 nodes, 2 in user",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: types.User{Name: "joe"}},
|
||||
&types.Node{ID: 2, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 3, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[0]},
|
||||
&types.Node{ID: 3, User: users[0]},
|
||||
},
|
||||
user: "marc",
|
||||
},
|
||||
want: types.Nodes{
|
||||
&types.Node{ID: 2, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 3, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 2, User: users[0]},
|
||||
&types.Node{ID: 3, User: users[0]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "5 nodes, 0 in user",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: types.User{Name: "joe"}},
|
||||
&types.Node{ID: 2, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 3, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 4, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 5, User: types.User{Name: "marc"}},
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[0]},
|
||||
&types.Node{ID: 3, User: users[0]},
|
||||
&types.Node{ID: 4, User: users[0]},
|
||||
&types.Node{ID: 5, User: users[0]},
|
||||
},
|
||||
user: "mickael",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "match-by-userid",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[0]},
|
||||
&types.Node{ID: 3, User: users[0]},
|
||||
&types.Node{ID: 4, User: users[0]},
|
||||
&types.Node{ID: 5, User: users[0]},
|
||||
},
|
||||
user: "2",
|
||||
},
|
||||
want: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "match-by-provider-ident",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[2]},
|
||||
},
|
||||
user: "http://oidc.org/1234",
|
||||
},
|
||||
want: types.Nodes{
|
||||
&types.Node{ID: 2, User: users[2]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "match-by-email",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[2]},
|
||||
},
|
||||
user: "joe@headscale.net",
|
||||
},
|
||||
want: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi-match-is-zero",
|
||||
args: args{
|
||||
nodes: types.Nodes{
|
||||
&types.Node{ID: 1, User: users[1]},
|
||||
&types.Node{ID: 2, User: users[2]},
|
||||
&types.Node{ID: 3, User: users[3]},
|
||||
},
|
||||
user: "mikael@headscale.net",
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := filterNodesByUser(test.args.nodes, test.args.user)
|
||||
got := filterNodesByUser(test.args.nodes, users, test.args.user)
|
||||
|
||||
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
|
||||
|
@ -940,6 +1008,12 @@ func Test_expandAlias(t *testing.T) {
|
|||
return s
|
||||
}
|
||||
|
||||
users := []types.User{
|
||||
{Model: gorm.Model{ID: 1}, Name: "joe"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "marc"},
|
||||
{Model: gorm.Model{ID: 3}, Name: "mickael"},
|
||||
}
|
||||
|
||||
type field struct {
|
||||
pol ACLPolicy
|
||||
}
|
||||
|
@ -989,19 +1063,19 @@ func Test_expandAlias(t *testing.T) {
|
|||
nodes: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[2],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1022,19 +1096,19 @@ func Test_expandAlias(t *testing.T) {
|
|||
nodes: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[2],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1185,7 +1259,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
nodes: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
|
@ -1194,7 +1268,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
|
@ -1203,11 +1277,11 @@ func Test_expandAlias(t *testing.T) {
|
|||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1260,21 +1334,21 @@ func Test_expandAlias(t *testing.T) {
|
|||
nodes: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
ForcedTags: []string{"tag:hr-webserver"},
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
ForcedTags: []string{"tag:hr-webserver"},
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[2],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1295,12 +1369,12 @@ func Test_expandAlias(t *testing.T) {
|
|||
nodes: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
ForcedTags: []string{"tag:hr-webserver"},
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "foo",
|
||||
|
@ -1309,11 +1383,11 @@ func Test_expandAlias(t *testing.T) {
|
|||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[2],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1350,12 +1424,12 @@ func Test_expandAlias(t *testing.T) {
|
|||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.3"),
|
||||
User: types.User{Name: "marc"},
|
||||
User: users[1],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.4"),
|
||||
User: types.User{Name: "joe"},
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
},
|
||||
|
@ -1368,6 +1442,7 @@ func Test_expandAlias(t *testing.T) {
|
|||
t.Run(test.name, func(t *testing.T) {
|
||||
got, err := test.field.pol.ExpandAlias(
|
||||
test.args.nodes,
|
||||
users,
|
||||
test.args.alias,
|
||||
)
|
||||
if (err != nil) != test.wantErr {
|
||||
|
@ -1715,6 +1790,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.field.pol.CompileFilterRules(
|
||||
[]types.User{},
|
||||
tt.args.nodes,
|
||||
)
|
||||
if (err != nil) != tt.wantErr {
|
||||
|
@ -1834,6 +1910,13 @@ func TestTheInternet(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestReduceFilterRules(t *testing.T) {
|
||||
users := []types.User{
|
||||
{Model: gorm.Model{ID: 1}, Name: "mickael"},
|
||||
{Model: gorm.Model{ID: 2}, Name: "user1"},
|
||||
{Model: gorm.Model{ID: 3}, Name: "user2"},
|
||||
{Model: gorm.Model{ID: 4}, Name: "user100"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
|
@ -1855,13 +1938,13 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[0],
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
||||
User: types.User{Name: "mickael"},
|
||||
User: users[0],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{},
|
||||
|
@ -1888,7 +1971,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
|
@ -1899,7 +1982,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -1967,19 +2050,19 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2"},
|
||||
User: users[2],
|
||||
},
|
||||
// "internal" exit node
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.100"),
|
||||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
@ -2026,12 +2109,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -2113,7 +2196,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.100"),
|
||||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
@ -2122,12 +2205,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -2215,7 +2298,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.100"),
|
||||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
||||
},
|
||||
|
@ -2224,12 +2307,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -2292,7 +2375,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.100"),
|
||||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
||||
},
|
||||
|
@ -2301,12 +2384,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.2"),
|
||||
IPv6: iap("fd7a:115c:a1e0::2"),
|
||||
User: types.User{Name: "user2"},
|
||||
User: users[2],
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -2362,7 +2445,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: iap("100.64.0.100"),
|
||||
IPv6: iap("fd7a:115c:a1e0::100"),
|
||||
User: types.User{Name: "user100"},
|
||||
User: users[3],
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
},
|
||||
|
@ -2372,7 +2455,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: iap("100.64.0.1"),
|
||||
IPv6: iap("fd7a:115c:a1e0::1"),
|
||||
User: types.User{Name: "user1"},
|
||||
User: users[1],
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
@ -2400,6 +2483,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, _ := tt.pol.CompileFilterRules(
|
||||
users,
|
||||
append(tt.peers, tt.node),
|
||||
)
|
||||
|
||||
|
@ -3391,7 +3475,7 @@ func TestSSHRules(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
|
||||
got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
|
@ -3474,14 +3558,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
|||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
}
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
Hostname: "testnodes",
|
||||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 0,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
ID: 0,
|
||||
Hostname: "testnodes",
|
||||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 0,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
@ -3498,7 +3585,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
|
@ -3532,7 +3619,8 @@ func TestInvalidTagValidUser(t *testing.T) {
|
|||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &hostInfo,
|
||||
|
@ -3549,7 +3637,7 @@ func TestInvalidTagValidUser(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
|
@ -3583,7 +3671,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
|||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
},
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &hostInfo,
|
||||
|
@ -3608,7 +3697,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
|||
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
|
@ -3637,15 +3726,17 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
Hostname: "webserver",
|
||||
RequestTags: []string{"tag:webapp"},
|
||||
}
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "user1",
|
||||
}
|
||||
|
||||
node := &types.Node{
|
||||
ID: 1,
|
||||
Hostname: "webserver",
|
||||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
ID: 1,
|
||||
Hostname: "webserver",
|
||||
IPv4: iap("100.64.0.1"),
|
||||
UserID: 1,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &hostInfo,
|
||||
}
|
||||
|
@ -3656,13 +3747,11 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes2 := &types.Node{
|
||||
ID: 2,
|
||||
Hostname: "user",
|
||||
IPv4: iap("100.64.0.2"),
|
||||
UserID: 1,
|
||||
User: types.User{
|
||||
Name: "user1",
|
||||
},
|
||||
ID: 2,
|
||||
Hostname: "user",
|
||||
IPv4: iap("100.64.0.2"),
|
||||
UserID: 1,
|
||||
User: user,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &hostInfo2,
|
||||
}
|
||||
|
@ -3678,7 +3767,7 @@ func TestValidTagInvalidUser(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
|
||||
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
|
||||
assert.NoError(t, err)
|
||||
|
||||
want := []tailcfg.FilterRule{
|
||||
|
|
Loading…
Reference in a new issue