remove "stripEmailDomain" argument

This commit makes a wrapper function round the normalisation requiring
"stripEmailDomain" which has to be passed in almost all functions of
headscale by loading it from Viper instead.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-06-12 15:29:34 +02:00 committed by Kristoffer Dalby
parent 161243c787
commit 717abe89c1
16 changed files with 127 additions and 220 deletions

View file

@ -169,7 +169,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
database, err := db.NewHeadscaleDatabase( database, err := db.NewHeadscaleDatabase(
cfg.DBtype, cfg.DBtype,
dbString, dbString,
cfg.OIDC.StripEmaildomain,
app.dbDebug, app.dbDebug,
app.stateUpdateChan, app.stateUpdateChan,
cfg.IPPrefixes, cfg.IPPrefixes,

View file

@ -53,7 +53,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
}, },
} }
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -107,7 +107,7 @@ func TestInvalidTagValidUser(t *testing.T) {
}, },
} }
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -169,7 +169,7 @@ func TestPortGroup(t *testing.T) {
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err) assert.NoError(t, err)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -224,7 +224,7 @@ func TestPortUser(t *testing.T) {
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
assert.NoError(t, err) assert.NoError(t, err)
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -285,7 +285,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{
@ -361,7 +361,7 @@ func TestValidTagInvalidUser(t *testing.T) {
}, },
} }
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false) got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2})
assert.NoError(t, err) assert.NoError(t, err)
want := []tailcfg.FilterRule{ want := []tailcfg.FilterRule{

View file

@ -43,14 +43,13 @@ type HSDatabase struct {
ipPrefixes []netip.Prefix ipPrefixes []netip.Prefix
baseDomain string baseDomain string
stripEmailDomain bool
} }
// TODO(kradalby): assemble this struct from toptions or something typed // TODO(kradalby): assemble this struct from toptions or something typed
// rather than arguments. // rather than arguments.
func NewHeadscaleDatabase( func NewHeadscaleDatabase(
dbType, connectionAddr string, dbType, connectionAddr string,
stripEmailDomain, debug bool, debug bool,
notifyStateChan chan<- struct{}, notifyStateChan chan<- struct{},
ipPrefixes []netip.Prefix, ipPrefixes []netip.Prefix,
baseDomain string, baseDomain string,
@ -66,7 +65,6 @@ func NewHeadscaleDatabase(
ipPrefixes: ipPrefixes, ipPrefixes: ipPrefixes,
baseDomain: baseDomain, baseDomain: baseDomain,
stripEmailDomain: stripEmailDomain,
} }
log.Debug().Msgf("database %#v", dbConn) log.Debug().Msgf("database %#v", dbConn)
@ -202,9 +200,8 @@ func NewHeadscaleDatabase(
for item, machine := range machines { for item, machine := range machines {
if machine.GivenName == "" { if machine.GivenName == "" {
normalizedHostname, err := util.NormalizeToFQDNRules( normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
machine.Hostname, machine.Hostname,
stripEmailDomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().

View file

@ -632,9 +632,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
} }
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
normalizedHostname, err := util.NormalizeToFQDNRules( normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
suppliedName, suppliedName,
hsdb.stripEmailDomain,
) )
if err != nil { if err != nil {
return "", err return "", err

View file

@ -293,10 +293,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
testPeers, err := db.ListPeers(testMachine) testPeers, err := db.ListPeers(testMachine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false) adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false) testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules) peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
@ -482,9 +482,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}{ }{
{ {
name: "simple machine name generation", name: "simple machine name generation",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "testmachine", suppliedName: "testmachine",
randomSuffix: false, randomSuffix: false,
@ -494,9 +492,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with 53 chars", name: "machine name with 53 chars",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
randomSuffix: false, randomSuffix: false,
@ -506,9 +502,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with 63 chars", name: "machine name with 63 chars",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: false, randomSuffix: false,
@ -518,9 +512,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with 64 chars", name: "machine name with 64 chars",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
randomSuffix: false, randomSuffix: false,
@ -530,9 +522,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with 73 chars", name: "machine name with 73 chars",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
randomSuffix: false, randomSuffix: false,
@ -542,9 +532,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with random suffix", name: "machine name with random suffix",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "test", suppliedName: "test",
randomSuffix: true, randomSuffix: true,
@ -554,9 +542,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
}, },
{ {
name: "machine name with 63 chars with random suffix", name: "machine name with 63 chars with random suffix",
db: &HSDatabase{ db: &HSDatabase{},
stripEmailDomain: true,
},
args: args{ args: args{
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
randomSuffix: true, randomSuffix: true,

View file

@ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
approvedRoutes = append(approvedRoutes, advertisedRoute) approvedRoutes = append(approvedRoutes, advertisedRoute)
} else { } else {
// TODO(kradalby): figure out how to get this to depend on less stuff // TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias)
if err != nil { if err != nil {
log.Err(err). log.Err(err).
Str("alias", approvedAlias). Str("alias", approvedAlias).

View file

@ -60,7 +60,6 @@ func (s *Suite) ResetDB(c *check.C) {
"sqlite3", "sqlite3",
tmpDir+"/headscale_test.db", tmpDir+"/headscale_test.db",
false, false,
false,
sink, sink,
[]netip.Prefix{ []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),

View file

@ -340,7 +340,6 @@ func (api headscaleV1APIServer) ListMachines(
m := machine.Proto() m := machine.Proto()
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
machine, machine,
api.h.cfg.OIDC.StripEmaildomain,
) )
m.InvalidTags = invalidTags m.InvalidTags = invalidTags
m.ValidTags = validTags m.ValidTags = validTags

View file

@ -41,7 +41,6 @@ type Mapper struct {
dnsCfg *tailcfg.DNSConfig dnsCfg *tailcfg.DNSConfig
logtail bool logtail bool
randomClientPort bool randomClientPort bool
stripEmailDomain bool
} }
func NewMapper( func NewMapper(
@ -53,7 +52,6 @@ func NewMapper(
dnsCfg *tailcfg.DNSConfig, dnsCfg *tailcfg.DNSConfig,
logtail bool, logtail bool,
randomClientPort bool, randomClientPort bool,
stripEmailDomain bool,
) *Mapper { ) *Mapper {
return &Mapper{ return &Mapper{
db: db, db: db,
@ -66,7 +64,6 @@ func NewMapper(
dnsCfg: dnsCfg, dnsCfg: dnsCfg,
logtail: logtail, logtail: logtail,
randomClientPort: randomClientPort, randomClientPort: randomClientPort,
stripEmailDomain: stripEmailDomain,
} }
} }
@ -87,14 +84,13 @@ func fullMapResponse(
machine *types.Machine, machine *types.Machine,
peers types.Machines, peers types.Machines,
stripEmailDomain bool,
baseDomain string, baseDomain string,
dnsCfg *tailcfg.DNSConfig, dnsCfg *tailcfg.DNSConfig,
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
logtail bool, logtail bool,
randomClientPort bool, randomClientPort bool,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain, stripEmailDomain) tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -103,7 +99,6 @@ func fullMapResponse(
pol, pol,
machine, machine,
peers, peers,
stripEmailDomain,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -129,7 +124,7 @@ func fullMapResponse(
peers, peers,
) )
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain, stripEmailDomain) tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -296,7 +291,6 @@ func (m Mapper) CreateMapResponse(
pol, pol,
machine, machine,
peers, peers,
m.stripEmailDomain,
m.baseDomain, m.baseDomain,
m.dnsCfg, m.dnsCfg,
m.derpMap, m.derpMap,

View file

@ -320,7 +320,6 @@ func Test_fullMapResponse(t *testing.T) {
machine *types.Machine machine *types.Machine
peers types.Machines peers types.Machines
stripEmailDomain bool
baseDomain string baseDomain string
dnsConfig *tailcfg.DNSConfig dnsConfig *tailcfg.DNSConfig
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
@ -335,7 +334,6 @@ func Test_fullMapResponse(t *testing.T) {
// pol: &policy.ACLPolicy{}, // pol: &policy.ACLPolicy{},
// dnsConfig: &tailcfg.DNSConfig{}, // dnsConfig: &tailcfg.DNSConfig{},
// baseDomain: "", // baseDomain: "",
// stripEmailDomain: false,
// want: nil, // want: nil,
// wantErr: true, // wantErr: true,
// }, // },
@ -344,7 +342,6 @@ func Test_fullMapResponse(t *testing.T) {
pol: &policy.ACLPolicy{}, pol: &policy.ACLPolicy{},
machine: mini, machine: mini,
peers: []types.Machine{}, peers: []types.Machine{},
stripEmailDomain: false,
baseDomain: "", baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{}, dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{}, derpMap: &tailcfg.DERPMap{},
@ -375,7 +372,6 @@ func Test_fullMapResponse(t *testing.T) {
peers: []types.Machine{ peers: []types.Machine{
peer1, peer1,
}, },
stripEmailDomain: false,
baseDomain: "", baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{}, dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{}, derpMap: &tailcfg.DERPMap{},
@ -417,7 +413,6 @@ func Test_fullMapResponse(t *testing.T) {
peer1, peer1,
peer2, peer2,
}, },
stripEmailDomain: false,
baseDomain: "", baseDomain: "",
dnsConfig: &tailcfg.DNSConfig{}, dnsConfig: &tailcfg.DNSConfig{},
derpMap: &tailcfg.DERPMap{}, derpMap: &tailcfg.DERPMap{},
@ -458,7 +453,6 @@ func Test_fullMapResponse(t *testing.T) {
tt.pol, tt.pol,
tt.machine, tt.machine,
tt.peers, tt.peers,
tt.stripEmailDomain,
tt.baseDomain, tt.baseDomain,
tt.dnsConfig, tt.dnsConfig,
tt.derpMap, tt.derpMap,

View file

@ -18,7 +18,6 @@ func tailNodes(
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
baseDomain string, baseDomain string,
stripEmailDomain bool,
) ([]*tailcfg.Node, error) { ) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(machines)) nodes := make([]*tailcfg.Node, len(machines))
@ -28,7 +27,6 @@ func tailNodes(
pol, pol,
dnsConfig, dnsConfig,
baseDomain, baseDomain,
stripEmailDomain,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -47,7 +45,6 @@ func tailNode(
pol *policy.ACLPolicy, pol *policy.ACLPolicy,
dnsConfig *tailcfg.DNSConfig, dnsConfig *tailcfg.DNSConfig,
baseDomain string, baseDomain string,
stripEmailDomain bool,
) (*tailcfg.Node, error) { ) (*tailcfg.Node, error) {
nodeKey, err := machine.NodePublicKey() nodeKey, err := machine.NodePublicKey()
if err != nil { if err != nil {
@ -107,7 +104,7 @@ func tailNode(
online := machine.IsOnline() online := machine.IsOnline()
tags, _ := pol.GetTagsOfMachine(machine, stripEmailDomain) tags, _ := pol.GetTagsOfMachine(machine)
tags = lo.Uniq(append(tags, machine.ForcedTags...)) tags = lo.Uniq(append(tags, machine.ForcedTags...))
node := tailcfg.Node{ node := tailcfg.Node{

View file

@ -49,7 +49,6 @@ func TestTailNode(t *testing.T) {
pol *policy.ACLPolicy pol *policy.ACLPolicy
dnsConfig *tailcfg.DNSConfig dnsConfig *tailcfg.DNSConfig
baseDomain string baseDomain string
stripEmailDomain bool
want *tailcfg.Node want *tailcfg.Node
wantErr bool wantErr bool
}{ }{
@ -59,7 +58,6 @@ func TestTailNode(t *testing.T) {
pol: &policy.ACLPolicy{}, pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{}, dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "", baseDomain: "",
stripEmailDomain: false,
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
@ -111,7 +109,6 @@ func TestTailNode(t *testing.T) {
pol: &policy.ACLPolicy{}, pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{}, dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "", baseDomain: "",
stripEmailDomain: false,
want: &tailcfg.Node{ want: &tailcfg.Node{
ID: 0, ID: 0,
StableID: "0", StableID: "0",
@ -172,7 +169,6 @@ func TestTailNode(t *testing.T) {
tt.pol, tt.pol,
tt.dnsConfig, tt.dnsConfig,
tt.baseDomain, tt.baseDomain,
tt.stripEmailDomain,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {

View file

@ -121,14 +121,13 @@ func GenerateFilterRules(
policy *ACLPolicy, policy *ACLPolicy,
machine *types.Machine, machine *types.Machine,
peers types.Machines, peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all // If there is no policy defined, we default to allow all
if policy == nil { if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
} }
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain) rules, err := policy.generateFilterRules(machine, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -136,7 +135,7 @@ func GenerateFilterRules(
log.Trace().Interface("ACL", rules).Msg("ACL rules generated") log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
var sshPolicy *tailcfg.SSHPolicy var sshPolicy *tailcfg.SSHPolicy
sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain) sshRules, err := policy.generateSSHRules(machine, peers)
if err != nil { if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
} }
@ -154,7 +153,6 @@ func GenerateFilterRules(
func (pol *ACLPolicy) generateFilterRules( func (pol *ACLPolicy) generateFilterRules(
machine *types.Machine, machine *types.Machine,
peers types.Machines, peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{} rules := []tailcfg.FilterRule{}
machines := append(peers, *machine) machines := append(peers, *machine)
@ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules(
srcIPs := []string{} srcIPs := []string{}
for srcIndex, src := range acl.Sources { for srcIndex, src := range acl.Sources {
srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain) srcs, err := pol.getIPsFromSource(src, machines)
if err != nil { if err != nil {
log.Error(). log.Error().
Interface("src", src). Interface("src", src).
@ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules(
dest, dest,
machines, machines,
needsWildcard, needsWildcard,
stripEmailDomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
@ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules(
func (pol *ACLPolicy) generateSSHRules( func (pol *ACLPolicy) generateSSHRules(
machine *types.Machine, machine *types.Machine,
peers types.Machines, peers types.Machines,
stripEmailDomain bool,
) ([]*tailcfg.SSHRule, error) { ) ([]*tailcfg.SSHRule, error) {
rules := []*tailcfg.SSHRule{} rules := []*tailcfg.SSHRule{}
@ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules(
for index, sshACL := range pol.SSHs { for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations { for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain) expanded, err := pol.ExpandAlias(append(peers, *machine), src)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules(
Any: true, Any: true,
}) })
} else if isGroup(rawSrc) { } else if isGroup(rawSrc) {
users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain) users, err := pol.getUsersInGroup(rawSrc)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex) Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
@ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules(
expandedSrcs, err := pol.ExpandAlias( expandedSrcs, err := pol.ExpandAlias(
peers, peers,
rawSrc, rawSrc,
stripEmailDomain,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
@ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
func (pol *ACLPolicy) getIPsFromSource( func (pol *ACLPolicy) getIPsFromSource(
src string, src string,
machines types.Machines, machines types.Machines,
stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) ipSet, err := pol.ExpandAlias(machines, src)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
dest string, dest string,
machines types.Machines, machines types.Machines,
needsWildcard bool, needsWildcard bool,
stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) { ) ([]tailcfg.NetPortRange, error) {
var tokens []string var tokens []string
@ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
expanded, err := pol.ExpandAlias( expanded, err := pol.ExpandAlias(
machines, machines,
alias, alias,
stripEmaildomain,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) {
func (pol *ACLPolicy) ExpandAlias( func (pol *ACLPolicy) ExpandAlias(
machines types.Machines, machines types.Machines,
alias string, alias string,
stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
if isWildcard(alias) { if isWildcard(alias) {
return util.ParseIPSet("*", nil) return util.ParseIPSet("*", nil)
@ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group // if alias is a group
if isGroup(alias) { if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain) return pol.getIPsFromGroup(alias, machines)
} }
// if alias is a tag // if alias is a tag
if isTag(alias) { if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain) return pol.getIPsFromTag(alias, machines)
} }
// if alias is a user // if alias is a user
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil { if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
return ips, err return ips, err
} }
@ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok { if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(machines, h.String(), stripEmailDomain) return pol.ExpandAlias(machines, h.String())
} }
// if alias is an IP // if alias is an IP
@ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy, aclPolicy *ACLPolicy,
nodes types.Machines, nodes types.Machines,
user string, user string,
stripEmailDomain bool,
) types.Machines { ) types.Machines {
out := types.Machines{} out := types.Machines{}
tags := []string{} tags := []string{}
for tag := range aclPolicy.TagOwners { for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) owners, _ := getTagOwners(aclPolicy, user)
ns := append(owners, user) ns := append(owners, user)
if util.StringOrPrefixListContains(ns, user) { if util.StringOrPrefixListContains(ns, user) {
tags = append(tags, tag) tags = append(tags, tag)
@ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines {
func getTagOwners( func getTagOwners(
pol *ACLPolicy, pol *ACLPolicy,
tag string, tag string,
stripEmailDomain bool,
) ([]string, error) { ) ([]string, error) {
var owners []string var owners []string
ows, ok := pol.TagOwners[tag] ows, ok := pol.TagOwners[tag]
@ -687,7 +676,7 @@ func getTagOwners(
} }
for _, owner := range ows { for _, owner := range ows {
if isGroup(owner) { if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain) gs, err := pol.getUsersInGroup(owner)
if err != nil { if err != nil {
return []string{}, err return []string{}, err
} }
@ -704,7 +693,6 @@ func getTagOwners(
// after some validation. // after some validation.
func (pol *ACLPolicy) getUsersInGroup( func (pol *ACLPolicy) getUsersInGroup(
group string, group string,
stripEmailDomain bool,
) ([]string, error) { ) ([]string, error) {
users := []string{} users := []string{}
log.Trace().Caller().Interface("pol", pol).Msg("test") log.Trace().Caller().Interface("pol", pol).Msg("test")
@ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup(
ErrInvalidGroup, ErrInvalidGroup,
) )
} }
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
if err != nil { if err != nil {
return []string{}, fmt.Errorf( return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w", "failed to normalize group %q, err: %w",
@ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup(
func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) getIPsFromGroup(
group string, group string,
machines types.Machines, machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group, stripEmailDomain) users, err := pol.getUsersInGroup(group)
if err != nil { if err != nil {
return &netipx.IPSet{}, err return &netipx.IPSet{}, err
} }
@ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup(
func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsFromTag(
alias string, alias string,
machines types.Machines, machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
@ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag(
} }
// find tag owners // find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain) owners, err := getTagOwners(pol, alias)
if err != nil { if err != nil {
if errors.Is(err, ErrInvalidTag) { if errors.Is(err, ErrInvalidTag) {
ipSet, _ := build.IPSet() ipSet, _ := build.IPSet()
@ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag(
func (pol *ACLPolicy) getIPsForUser( func (pol *ACLPolicy) getIPsForUser(
user string, user string,
machines types.Machines, machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) { ) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{} build := netipx.IPSetBuilder{}
filteredMachines := filterMachinesByUser(machines, user) filteredMachines := filterMachinesByUser(machines, user)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain) filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user)
// shortcurcuit if we have no machines to get ips from. // shortcurcuit if we have no machines to get ips from.
if len(filteredMachines) == 0 { if len(filteredMachines) == 0 {
@ -885,7 +870,6 @@ func isTag(str string) bool {
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
func (pol *ACLPolicy) GetTagsOfMachine( func (pol *ACLPolicy) GetTagsOfMachine(
machine types.Machine, machine types.Machine,
stripEmailDomain bool,
) ([]string, []string) { ) ([]string, []string) {
validTags := make([]string, 0) validTags := make([]string, 0)
invalidTags := make([]string, 0) invalidTags := make([]string, 0)
@ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
validTagMap := make(map[string]bool) validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool)
for _, tag := range machine.HostInfo.RequestTags { for _, tag := range machine.HostInfo.RequestTags {
owners, err := getTagOwners(pol, tag, stripEmailDomain) owners, err := getTagOwners(pol, tag)
if errors.Is(err, ErrInvalidTag) { if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true invalidTagMap[tag] = true

View file

@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go4.org/netipx" "go4.org/netipx"
"gopkg.in/check.v1" "gopkg.in/check.v1"
@ -199,7 +200,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil) c.Assert(rules, check.IsNil)
} }
@ -230,7 +231,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
pol, err := LoadACLPolicyFromBytes(acl, "hujson") pol, err := LoadACLPolicyFromBytes(acl, "hujson")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
} }
@ -246,7 +247,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
} }
@ -265,7 +266,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
}, },
}, },
} }
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
} }
@ -281,7 +282,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
}, },
} }
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
} }
@ -310,7 +311,7 @@ func (s *Suite) TestPortRange(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -366,7 +367,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -401,7 +402,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -428,7 +429,7 @@ acls:
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -459,7 +460,7 @@ acls:
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(pol, check.NotNil) c.Assert(pol, check.NotNil)
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil) c.Assert(rules, check.NotNil)
@ -484,7 +485,7 @@ func Test_expandGroup(t *testing.T) {
} }
type args struct { type args struct {
group string group string
stripEmailDomain bool stripEmail bool
} }
tests := []struct { tests := []struct {
name string name string
@ -505,7 +506,6 @@ func Test_expandGroup(t *testing.T) {
}, },
args: args{ args: args{
group: "group:test", group: "group:test",
stripEmailDomain: true,
}, },
want: []string{"user1", "user2", "user3"}, want: []string{"user1", "user2", "user3"},
wantErr: false, wantErr: false,
@ -522,13 +522,12 @@ func Test_expandGroup(t *testing.T) {
}, },
args: args{ args: args{
group: "group:undefined", group: "group:undefined",
stripEmailDomain: true,
}, },
want: []string{}, want: []string{},
wantErr: true, wantErr: true,
}, },
{ {
name: "Expand emails in group", name: "Expand emails in group strip domains",
field: field{ field: field{
pol: ACLPolicy{ pol: ACLPolicy{
Groups: Groups{ Groups: Groups{
@ -541,7 +540,7 @@ func Test_expandGroup(t *testing.T) {
}, },
args: args{ args: args{
group: "group:admin", group: "group:admin",
stripEmailDomain: true, stripEmail: true,
}, },
want: []string{"joe.bar", "john.doe"}, want: []string{"joe.bar", "john.doe"},
wantErr: false, wantErr: false,
@ -560,7 +559,6 @@ func Test_expandGroup(t *testing.T) {
}, },
args: args{ args: args{
group: "group:admin", group: "group:admin",
stripEmailDomain: false,
}, },
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"}, want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
wantErr: false, wantErr: false,
@ -568,17 +566,20 @@ func Test_expandGroup(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
viper.Set("oidc.strip_email_domain", test.args.stripEmail)
got, err := test.field.pol.getUsersInGroup( got, err := test.field.pol.getUsersInGroup(
test.args.group, test.args.group,
test.args.stripEmailDomain,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr)
return return
} }
if !reflect.DeepEqual(got, test.want) {
t.Errorf("expandGroup() = %v, want %v", got, test.want) if diff := cmp.Diff(test.want, got); diff != "" {
t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff)
} }
}) })
} }
@ -588,7 +589,6 @@ func Test_expandTagOwners(t *testing.T) {
type args struct { type args struct {
aclPolicy *ACLPolicy aclPolicy *ACLPolicy
tag string tag string
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -603,7 +603,6 @@ func Test_expandTagOwners(t *testing.T) {
TagOwners: TagOwners{"tag:test": []string{"user1"}}, TagOwners: TagOwners{"tag:test": []string{"user1"}},
}, },
tag: "tag:test", tag: "tag:test",
stripEmailDomain: true,
}, },
want: []string{"user1"}, want: []string{"user1"},
wantErr: false, wantErr: false,
@ -616,7 +615,6 @@ func Test_expandTagOwners(t *testing.T) {
TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
}, },
tag: "tag:test", tag: "tag:test",
stripEmailDomain: true,
}, },
want: []string{"user1", "user2"}, want: []string{"user1", "user2"},
wantErr: false, wantErr: false,
@ -629,7 +627,6 @@ func Test_expandTagOwners(t *testing.T) {
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}},
}, },
tag: "tag:test", tag: "tag:test",
stripEmailDomain: true,
}, },
want: []string{"user1", "user2", "user3"}, want: []string{"user1", "user2", "user3"},
wantErr: false, wantErr: false,
@ -641,7 +638,6 @@ func Test_expandTagOwners(t *testing.T) {
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}}, TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}},
}, },
tag: "tag:test", tag: "tag:test",
stripEmailDomain: true,
}, },
want: []string{}, want: []string{},
wantErr: true, wantErr: true,
@ -654,7 +650,6 @@ func Test_expandTagOwners(t *testing.T) {
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}},
}, },
tag: "tag:test", tag: "tag:test",
stripEmailDomain: true,
}, },
want: []string{}, want: []string{},
wantErr: true, wantErr: true,
@ -665,7 +660,6 @@ func Test_expandTagOwners(t *testing.T) {
got, err := getTagOwners( got, err := getTagOwners(
test.args.aclPolicy, test.args.aclPolicy,
test.args.tag, test.args.tag,
test.args.stripEmailDomain,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr)
@ -864,7 +858,6 @@ func Test_expandAlias(t *testing.T) {
machines types.Machines machines types.Machines
aclPolicy ACLPolicy aclPolicy ACLPolicy
alias string alias string
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -888,7 +881,6 @@ func Test_expandAlias(t *testing.T) {
}, },
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{}, []string{ want: set([]string{}, []string{
"0.0.0.0/0", "0.0.0.0/0",
@ -931,7 +923,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"100.64.0.1", "100.64.0.2", "100.64.0.3", "100.64.0.1", "100.64.0.2", "100.64.0.3",
@ -973,7 +964,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{}, []string{}), want: set([]string{}, []string{}),
wantErr: true, wantErr: true,
@ -986,7 +976,6 @@ func Test_expandAlias(t *testing.T) {
args: args{ args: args{
alias: "10.0.0.3", alias: "10.0.0.3",
machines: types.Machines{}, machines: types.Machines{},
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"10.0.0.3", "10.0.0.3",
@ -1001,7 +990,6 @@ func Test_expandAlias(t *testing.T) {
args: args{ args: args{
alias: "10.0.0.1", alias: "10.0.0.1",
machines: types.Machines{}, machines: types.Machines{},
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"10.0.0.1", "10.0.0.1",
@ -1023,7 +1011,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"10.0.0.1", "10.0.0.1",
@ -1046,7 +1033,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
@ -1069,7 +1055,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1",
@ -1088,7 +1073,6 @@ func Test_expandAlias(t *testing.T) {
args: args{ args: args{
alias: "testy", alias: "testy",
machines: types.Machines{}, machines: types.Machines{},
stripEmailDomain: true,
}, },
want: set([]string{}, []string{"10.0.0.132/32"}), want: set([]string{}, []string{"10.0.0.132/32"}),
wantErr: false, wantErr: false,
@ -1105,7 +1089,6 @@ func Test_expandAlias(t *testing.T) {
args: args{ args: args{
alias: "homeNetwork", alias: "homeNetwork",
machines: types.Machines{}, machines: types.Machines{},
stripEmailDomain: true,
}, },
want: set([]string{}, []string{"192.168.1.0/24"}), want: set([]string{}, []string{"192.168.1.0/24"}),
wantErr: false, wantErr: false,
@ -1119,7 +1102,6 @@ func Test_expandAlias(t *testing.T) {
alias: "10.0.0.0/16", alias: "10.0.0.0/16",
machines: types.Machines{}, machines: types.Machines{},
aclPolicy: ACLPolicy{}, aclPolicy: ACLPolicy{},
stripEmailDomain: true,
}, },
want: set([]string{}, []string{"10.0.0.0/16"}), want: set([]string{}, []string{"10.0.0.0/16"}),
wantErr: false, wantErr: false,
@ -1169,7 +1151,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "joe"}, User: types.User{Name: "joe"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{ want: set([]string{
"100.64.0.1", "100.64.0.2", "100.64.0.1", "100.64.0.2",
@ -1214,7 +1195,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{}, []string{}), want: set([]string{}, []string{}),
wantErr: true, wantErr: true,
@ -1254,7 +1234,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false, wantErr: false,
@ -1302,7 +1281,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
wantErr: false, wantErr: false,
@ -1352,7 +1330,6 @@ func Test_expandAlias(t *testing.T) {
User: types.User{Name: "joe"}, User: types.User{Name: "joe"},
}, },
}, },
stripEmailDomain: true,
}, },
want: set([]string{"100.64.0.4"}, []string{}), want: set([]string{"100.64.0.4"}, []string{}),
wantErr: false, wantErr: false,
@ -1363,7 +1340,6 @@ func Test_expandAlias(t *testing.T) {
got, err := test.field.pol.ExpandAlias( got, err := test.field.pol.ExpandAlias(
test.args.machines, test.args.machines,
test.args.alias, test.args.alias,
test.args.stripEmailDomain,
) )
if (err != nil) != test.wantErr { if (err != nil) != test.wantErr {
t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr)
@ -1382,7 +1358,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
aclPolicy *ACLPolicy aclPolicy *ACLPolicy
nodes types.Machines nodes types.Machines
user string user string
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -1427,7 +1402,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
user: "joe", user: "joe",
stripEmailDomain: true,
}, },
want: types.Machines{ want: types.Machines{
{ {
@ -1478,7 +1452,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
user: "joe", user: "joe",
stripEmailDomain: true,
}, },
want: types.Machines{ want: types.Machines{
{ {
@ -1520,7 +1493,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
user: "joe", user: "joe",
stripEmailDomain: true,
}, },
want: types.Machines{ want: types.Machines{
{ {
@ -1566,7 +1538,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
user: "joe", user: "joe",
stripEmailDomain: true,
}, },
want: types.Machines{ want: types.Machines{
{ {
@ -1606,7 +1577,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.aclPolicy, test.args.aclPolicy,
test.args.nodes, test.args.nodes,
test.args.user, test.args.user,
test.args.stripEmailDomain,
) )
if !reflect.DeepEqual(got, test.want) { if !reflect.DeepEqual(got, test.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want) t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)
@ -1622,7 +1592,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
type args struct { type args struct {
machine types.Machine machine types.Machine
peers types.Machines peers types.Machines
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -1654,7 +1623,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
args: args{ args: args{
machine: types.Machine{}, machine: types.Machine{},
peers: types.Machines{}, peers: types.Machines{},
stripEmailDomain: true,
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
@ -1709,7 +1677,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
User: types.User{Name: "mickael"}, User: types.User{Name: "mickael"},
}, },
}, },
stripEmailDomain: true,
}, },
want: []tailcfg.FilterRule{ want: []tailcfg.FilterRule{
{ {
@ -1743,7 +1710,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
got, err := tt.field.pol.generateFilterRules( got, err := tt.field.pol.generateFilterRules(
&tt.args.machine, &tt.args.machine,
tt.args.peers, tt.args.peers,
tt.args.stripEmailDomain,
) )
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
@ -1763,7 +1729,6 @@ func Test_getTags(t *testing.T) {
type args struct { type args struct {
aclPolicy *ACLPolicy aclPolicy *ACLPolicy
machine types.Machine machine types.Machine
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -1787,7 +1752,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:valid"}, RequestTags: []string{"tag:valid"},
}, },
}, },
stripEmailDomain: false,
}, },
wantValid: []string{"tag:valid"}, wantValid: []string{"tag:valid"},
wantInvalid: nil, wantInvalid: nil,
@ -1808,7 +1772,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:valid", "tag:invalid"}, RequestTags: []string{"tag:valid", "tag:invalid"},
}, },
}, },
stripEmailDomain: false,
}, },
wantValid: []string{"tag:valid"}, wantValid: []string{"tag:valid"},
wantInvalid: []string{"tag:invalid"}, wantInvalid: []string{"tag:invalid"},
@ -1833,7 +1796,6 @@ func Test_getTags(t *testing.T) {
}, },
}, },
}, },
stripEmailDomain: false,
}, },
wantValid: []string{"tag:valid"}, wantValid: []string{"tag:valid"},
wantInvalid: []string{"tag:invalid"}, wantInvalid: []string{"tag:invalid"},
@ -1854,7 +1816,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:invalid", "very-invalid"}, RequestTags: []string{"tag:invalid", "very-invalid"},
}, },
}, },
stripEmailDomain: false,
}, },
wantValid: nil, wantValid: nil,
wantInvalid: []string{"tag:invalid", "very-invalid"}, wantInvalid: []string{"tag:invalid", "very-invalid"},
@ -1871,7 +1832,6 @@ func Test_getTags(t *testing.T) {
RequestTags: []string{"tag:invalid", "very-invalid"}, RequestTags: []string{"tag:invalid", "very-invalid"},
}, },
}, },
stripEmailDomain: false,
}, },
wantValid: nil, wantValid: nil,
wantInvalid: []string{"tag:invalid", "very-invalid"}, wantInvalid: []string{"tag:invalid", "very-invalid"},
@ -1881,7 +1841,6 @@ func Test_getTags(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
test.args.machine, test.args.machine,
test.args.stripEmailDomain,
) )
for _, valid := range gotValid { for _, valid := range gotValid {
if !util.StringOrPrefixListContains(test.wantValid, valid) { if !util.StringOrPrefixListContains(test.wantValid, valid) {
@ -2589,7 +2548,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers, false) got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers)
assert.NoError(t, err) assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {

View file

@ -40,7 +40,6 @@ func (h *Headscale) handlePoll(
h.cfg.DNSConfig, h.cfg.DNSConfig,
h.cfg.LogTail.Enabled, h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort, h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
) )
machine.Hostname = mapRequest.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
@ -265,7 +264,6 @@ func (h *Headscale) pollNetMapStream(
h.cfg.DNSConfig, h.cfg.DNSConfig,
h.cfg.LogTail.Enabled, h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort, h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
) )
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
@ -656,7 +654,6 @@ func (h *Headscale) scheduledPollWorker(
h.cfg.DNSConfig, h.cfg.DNSConfig,
h.cfg.LogTail.Enabled, h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort, h.cfg.RandomizeClientPort,
h.cfg.OIDC.StripEmaildomain,
) )
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)

View file

@ -7,6 +7,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/spf13/viper"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
) )
@ -24,6 +25,12 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidUserName = errors.New("invalid user name") var ErrInvalidUserName = errors.New("invalid user name")
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
strip := viper.GetBool("oidc.strip_email_domain")
return NormalizeToFQDNRules(name, strip)
}
// NormalizeToFQDNRules will replace forbidden chars in user // NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123. // it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {