diff --git a/hscontrol/app.go b/hscontrol/app.go index 90628136..c654d4a6 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -169,7 +169,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { database, err := db.NewHeadscaleDatabase( cfg.DBtype, dbString, - cfg.OIDC.StripEmaildomain, app.dbDebug, app.stateUpdateChan, cfg.IPPrefixes, diff --git a/hscontrol/db/acls_test.go b/hscontrol/db/acls_test.go index fc822a79..ffbf66e7 100644 --- a/hscontrol/db/acls_test.go +++ b/hscontrol/db/acls_test.go @@ -53,7 +53,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) { }, } - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -107,7 +107,7 @@ func TestInvalidTagValidUser(t *testing.T) { }, } - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -169,7 +169,7 @@ func TestPortGroup(t *testing.T) { pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") assert.NoError(t, err) - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -224,7 +224,7 @@ func TestPortUser(t *testing.T) { pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") assert.NoError(t, err) - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -285,7 +285,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) { // c.Assert(rules[0].DstPorts, check.HasLen, 1) // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}) assert.NoError(t, err) want := []tailcfg.FilterRule{ @@ -361,7 +361,7 @@ func TestValidTagInvalidUser(t *testing.T) { }, } - got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false) + got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}) assert.NoError(t, err) want := []tailcfg.FilterRule{ diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index fb1089dc..5cff7868 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -41,16 +41,15 @@ type HSDatabase struct { ipAllocationMutex sync.Mutex - ipPrefixes []netip.Prefix - baseDomain string - stripEmailDomain bool + ipPrefixes []netip.Prefix + baseDomain string } // TODO(kradalby): assemble this struct from toptions or something typed // rather than arguments. func NewHeadscaleDatabase( dbType, connectionAddr string, - stripEmailDomain, debug bool, + debug bool, notifyStateChan chan<- struct{}, ipPrefixes []netip.Prefix, baseDomain string, @@ -64,9 +63,8 @@ func NewHeadscaleDatabase( db: dbConn, notifyStateChan: notifyStateChan, - ipPrefixes: ipPrefixes, - baseDomain: baseDomain, - stripEmailDomain: stripEmailDomain, + ipPrefixes: ipPrefixes, + baseDomain: baseDomain, } log.Debug().Msgf("database %#v", dbConn) @@ -202,9 +200,8 @@ func NewHeadscaleDatabase( for item, machine := range machines { if machine.GivenName == "" { - normalizedHostname, err := util.NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( machine.Hostname, - stripEmailDomain, ) if err != nil { log.Error(). diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 1764ce6c..f2139abb 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -632,9 +632,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string } func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { - normalizedHostname, err := util.NormalizeToFQDNRules( + normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( suppliedName, - hsdb.stripEmailDomain, ) if err != nil { return "", err diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index 92021e7d..f9522b53 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -293,10 +293,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { testPeers, err := db.ListPeers(testMachine) c.Assert(err, check.IsNil) - adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false) + adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers) c.Assert(err, check.IsNil) - testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false) + testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers) c.Assert(err, check.IsNil) peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules) @@ -482,9 +482,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }{ { name: "simple machine name generation", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "testmachine", randomSuffix: false, @@ -494,9 +492,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", randomSuffix: false, @@ -506,9 +502,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -518,9 +512,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", randomSuffix: false, @@ -530,9 +522,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -542,9 +532,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "test", randomSuffix: true, @@ -554,9 +542,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - db: &HSDatabase{ - stripEmailDomain: true, - }, + db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: true, diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index bdb3f4c5..74e0afe4 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( approvedRoutes = append(approvedRoutes, advertisedRoute) } else { // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain) + approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias) if err != nil { log.Err(err). Str("alias", approvedAlias). diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 739950db..495a9363 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -60,7 +60,6 @@ func (s *Suite) ResetDB(c *check.C) { "sqlite3", tmpDir+"/headscale_test.db", false, - false, sink, []netip.Prefix{ netip.MustParsePrefix("10.27.0.0/23"), diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 8adf871c..866c3cb2 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -340,7 +340,6 @@ func (api headscaleV1APIServer) ListMachines( m := machine.Proto() validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine( machine, - api.h.cfg.OIDC.StripEmaildomain, ) m.InvalidTags = invalidTags m.ValidTags = validTags diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 6f9498ea..427aaa1e 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -41,7 +41,6 @@ type Mapper struct { dnsCfg *tailcfg.DNSConfig logtail bool randomClientPort bool - stripEmailDomain bool } func NewMapper( @@ -53,7 +52,6 @@ func NewMapper( dnsCfg *tailcfg.DNSConfig, logtail bool, randomClientPort bool, - stripEmailDomain bool, ) *Mapper { return &Mapper{ db: db, @@ -66,7 +64,6 @@ func NewMapper( dnsCfg: dnsCfg, logtail: logtail, randomClientPort: randomClientPort, - stripEmailDomain: stripEmailDomain, } } @@ -87,14 +84,13 @@ func fullMapResponse( machine *types.Machine, peers types.Machines, - stripEmailDomain bool, baseDomain string, dnsCfg *tailcfg.DNSConfig, derpMap *tailcfg.DERPMap, logtail bool, randomClientPort bool, ) (*tailcfg.MapResponse, error) { - tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain, stripEmailDomain) + tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain) if err != nil { return nil, err } @@ -103,7 +99,6 @@ func fullMapResponse( pol, machine, peers, - stripEmailDomain, ) if err != nil { return nil, err @@ -129,7 +124,7 @@ func fullMapResponse( peers, ) - tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain, stripEmailDomain) + tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain) if err != nil { return nil, err } @@ -296,7 +291,6 @@ func (m Mapper) CreateMapResponse( pol, machine, peers, - m.stripEmailDomain, m.baseDomain, m.dnsCfg, m.derpMap, diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 919a22b2..6636e327 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -320,7 +320,6 @@ func Test_fullMapResponse(t *testing.T) { machine *types.Machine peers types.Machines - stripEmailDomain bool baseDomain string dnsConfig *tailcfg.DNSConfig derpMap *tailcfg.DERPMap @@ -335,7 +334,6 @@ func Test_fullMapResponse(t *testing.T) { // pol: &policy.ACLPolicy{}, // dnsConfig: &tailcfg.DNSConfig{}, // baseDomain: "", - // stripEmailDomain: false, // want: nil, // wantErr: true, // }, @@ -344,7 +342,6 @@ func Test_fullMapResponse(t *testing.T) { pol: &policy.ACLPolicy{}, machine: mini, peers: []types.Machine{}, - stripEmailDomain: false, baseDomain: "", dnsConfig: &tailcfg.DNSConfig{}, derpMap: &tailcfg.DERPMap{}, @@ -375,7 +372,6 @@ func Test_fullMapResponse(t *testing.T) { peers: []types.Machine{ peer1, }, - stripEmailDomain: false, baseDomain: "", dnsConfig: &tailcfg.DNSConfig{}, derpMap: &tailcfg.DERPMap{}, @@ -417,7 +413,6 @@ func Test_fullMapResponse(t *testing.T) { peer1, peer2, }, - stripEmailDomain: false, baseDomain: "", dnsConfig: &tailcfg.DNSConfig{}, derpMap: &tailcfg.DERPMap{}, @@ -458,7 +453,6 @@ func Test_fullMapResponse(t *testing.T) { tt.pol, tt.machine, tt.peers, - tt.stripEmailDomain, tt.baseDomain, tt.dnsConfig, tt.derpMap, diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index a6118648..afc9423d 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -18,7 +18,6 @@ func tailNodes( pol *policy.ACLPolicy, dnsConfig *tailcfg.DNSConfig, baseDomain string, - stripEmailDomain bool, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) @@ -28,7 +27,6 @@ func tailNodes( pol, dnsConfig, baseDomain, - stripEmailDomain, ) if err != nil { return nil, err @@ -47,7 +45,6 @@ func tailNode( pol *policy.ACLPolicy, dnsConfig *tailcfg.DNSConfig, baseDomain string, - stripEmailDomain bool, ) (*tailcfg.Node, error) { nodeKey, err := machine.NodePublicKey() if err != nil { @@ -107,7 +104,7 @@ func tailNode( online := machine.IsOnline() - tags, _ := pol.GetTagsOfMachine(machine, stripEmailDomain) + tags, _ := pol.GetTagsOfMachine(machine) tags = lo.Uniq(append(tags, machine.ForcedTags...)) node := tailcfg.Node{ diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 3ffff7d3..9874a779 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -44,24 +44,22 @@ func TestTailNode(t *testing.T) { expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) tests := []struct { - name string - machine types.Machine - pol *policy.ACLPolicy - dnsConfig *tailcfg.DNSConfig - baseDomain string - stripEmailDomain bool - want *tailcfg.Node - wantErr bool + name string + machine types.Machine + pol *policy.ACLPolicy + dnsConfig *tailcfg.DNSConfig + baseDomain string + want *tailcfg.Node + wantErr bool }{ { - name: "empty-machine", - machine: types.Machine{}, - pol: &policy.ACLPolicy{}, - dnsConfig: &tailcfg.DNSConfig{}, - baseDomain: "", - stripEmailDomain: false, - want: nil, - wantErr: true, + name: "empty-machine", + machine: types.Machine{}, + pol: &policy.ACLPolicy{}, + dnsConfig: &tailcfg.DNSConfig{}, + baseDomain: "", + want: nil, + wantErr: true, }, { name: "minimal-machine", @@ -108,10 +106,9 @@ func TestTailNode(t *testing.T) { }, CreatedAt: created, }, - pol: &policy.ACLPolicy{}, - dnsConfig: &tailcfg.DNSConfig{}, - baseDomain: "", - stripEmailDomain: false, + pol: &policy.ACLPolicy{}, + dnsConfig: &tailcfg.DNSConfig{}, + baseDomain: "", want: &tailcfg.Node{ ID: 0, StableID: "0", @@ -172,7 +169,6 @@ func TestTailNode(t *testing.T) { tt.pol, tt.dnsConfig, tt.baseDomain, - tt.stripEmailDomain, ) if (err != nil) != tt.wantErr { diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 35580aac..d667c724 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -121,14 +121,13 @@ func GenerateFilterRules( policy *ACLPolicy, machine *types.Machine, peers types.Machines, - stripEmailDomain bool, ) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { // If there is no policy defined, we default to allow all if policy == nil { return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil } - rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain) + rules, err := policy.generateFilterRules(machine, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -136,7 +135,7 @@ func GenerateFilterRules( log.Trace().Interface("ACL", rules).Msg("ACL rules generated") var sshPolicy *tailcfg.SSHPolicy - sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain) + sshRules, err := policy.generateSSHRules(machine, peers) if err != nil { return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err } @@ -154,7 +153,6 @@ func GenerateFilterRules( func (pol *ACLPolicy) generateFilterRules( machine *types.Machine, peers types.Machines, - stripEmailDomain bool, ) ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} machines := append(peers, *machine) @@ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules( srcIPs := []string{} for srcIndex, src := range acl.Sources { - srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain) + srcs, err := pol.getIPsFromSource(src, machines) if err != nil { log.Error(). Interface("src", src). @@ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules( dest, machines, needsWildcard, - stripEmailDomain, ) if err != nil { log.Error(). @@ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules( func (pol *ACLPolicy) generateSSHRules( machine *types.Machine, peers types.Machines, - stripEmailDomain bool, ) ([]*tailcfg.SSHRule, error) { rules := []*tailcfg.SSHRule{} @@ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules( for index, sshACL := range pol.SSHs { var dest netipx.IPSetBuilder for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain) + expanded, err := pol.ExpandAlias(append(peers, *machine), src) if err != nil { return nil, err } @@ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules( Any: true, }) } else if isGroup(rawSrc) { - users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain) + users, err := pol.getUsersInGroup(rawSrc) if err != nil { log.Error(). Msgf("Error parsing SSH %d, Source %d", index, innerIndex) @@ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules( expandedSrcs, err := pol.ExpandAlias( peers, rawSrc, - stripEmailDomain, ) if err != nil { log.Error(). @@ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { func (pol *ACLPolicy) getIPsFromSource( src string, machines types.Machines, - stripEmaildomain bool, ) ([]string, error) { - ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain) + ipSet, err := pol.ExpandAlias(machines, src) if err != nil { return []string{}, err } @@ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( dest string, machines types.Machines, needsWildcard bool, - stripEmaildomain bool, ) ([]tailcfg.NetPortRange, error) { var tokens []string @@ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination( expanded, err := pol.ExpandAlias( machines, alias, - stripEmaildomain, ) if err != nil { return nil, err @@ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) { func (pol *ACLPolicy) ExpandAlias( machines types.Machines, alias string, - stripEmailDomain bool, ) (*netipx.IPSet, error) { if isWildcard(alias) { return util.ParseIPSet("*", nil) @@ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias( // if alias is a group if isGroup(alias) { - return pol.getIPsFromGroup(alias, machines, stripEmailDomain) + return pol.getIPsFromGroup(alias, machines) } // if alias is a tag if isTag(alias) { - return pol.getIPsFromTag(alias, machines, stripEmailDomain) + return pol.getIPsFromTag(alias, machines) } // if alias is a user - if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil { + if ips, err := pol.getIPsForUser(alias, machines); ips != nil { return ips, err } @@ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias( if h, ok := pol.Hosts[alias]; ok { log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - return pol.ExpandAlias(machines, h.String(), stripEmailDomain) + return pol.ExpandAlias(machines, h.String()) } // if alias is an IP @@ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes( aclPolicy *ACLPolicy, nodes types.Machines, user string, - stripEmailDomain bool, ) types.Machines { out := types.Machines{} tags := []string{} for tag := range aclPolicy.TagOwners { - owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain) + owners, _ := getTagOwners(aclPolicy, user) ns := append(owners, user) if util.StringOrPrefixListContains(ns, user) { tags = append(tags, tag) @@ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines { func getTagOwners( pol *ACLPolicy, tag string, - stripEmailDomain bool, ) ([]string, error) { var owners []string ows, ok := pol.TagOwners[tag] @@ -687,7 +676,7 @@ func getTagOwners( } for _, owner := range ows { if isGroup(owner) { - gs, err := pol.getUsersInGroup(owner, stripEmailDomain) + gs, err := pol.getUsersInGroup(owner) if err != nil { return []string{}, err } @@ -704,7 +693,6 @@ func getTagOwners( // after some validation. func (pol *ACLPolicy) getUsersInGroup( group string, - stripEmailDomain bool, ) ([]string, error) { users := []string{} log.Trace().Caller().Interface("pol", pol).Msg("test") @@ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup( ErrInvalidGroup, ) } - grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain) + grp, err := util.NormalizeToFQDNRulesConfigFromViper(group) if err != nil { return []string{}, fmt.Errorf( "failed to normalize group %q, err: %w", @@ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup( func (pol *ACLPolicy) getIPsFromGroup( group string, machines types.Machines, - stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} - users, err := pol.getUsersInGroup(group, stripEmailDomain) + users, err := pol.getUsersInGroup(group) if err != nil { return &netipx.IPSet{}, err } @@ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup( func (pol *ACLPolicy) getIPsFromTag( alias string, machines types.Machines, - stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} @@ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag( } // find tag owners - owners, err := getTagOwners(pol, alias, stripEmailDomain) + owners, err := getTagOwners(pol, alias) if err != nil { if errors.Is(err, ErrInvalidTag) { ipSet, _ := build.IPSet() @@ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag( func (pol *ACLPolicy) getIPsForUser( user string, machines types.Machines, - stripEmailDomain bool, ) (*netipx.IPSet, error) { build := netipx.IPSetBuilder{} filteredMachines := filterMachinesByUser(machines, user) - filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain) + filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user) // shortcurcuit if we have no machines to get ips from. if len(filteredMachines) == 0 { @@ -885,7 +870,6 @@ func isTag(str string) bool { // Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. func (pol *ACLPolicy) GetTagsOfMachine( machine types.Machine, - stripEmailDomain bool, ) ([]string, []string) { validTags := make([]string, 0) invalidTags := make([]string, 0) @@ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine( validTagMap := make(map[string]bool) invalidTagMap := make(map[string]bool) for _, tag := range machine.HostInfo.RequestTags { - owners, err := getTagOwners(pol, tag, stripEmailDomain) + owners, err := getTagOwners(pol, tag) if errors.Is(err, ErrInvalidTag) { invalidTagMap[tag] = true diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index 5652e8c6..e220297f 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "go4.org/netipx" "gopkg.in/check.v1" @@ -199,7 +200,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) { c.Assert(pol.ACLs, check.HasLen, 6) c.Assert(err, check.IsNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.NotNil) c.Assert(rules, check.IsNil) } @@ -230,7 +231,7 @@ func (s *Suite) TestBasicRule(c *check.C) { pol, err := LoadACLPolicyFromBytes(acl, "hujson") c.Assert(err, check.IsNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } @@ -246,7 +247,7 @@ func (s *Suite) TestInvalidAction(c *check.C) { }, }, } - _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) } @@ -265,7 +266,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) { }, }, } - _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) } @@ -281,7 +282,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) { }, } - _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false) + _, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}) c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) } @@ -310,7 +311,7 @@ func (s *Suite) TestPortRange(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -366,7 +367,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -401,7 +402,7 @@ func (s *Suite) TestPortWildcard(c *check.C) { c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -428,7 +429,7 @@ acls: c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -459,7 +460,7 @@ acls: c.Assert(err, check.IsNil) c.Assert(pol, check.NotNil) - rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false) + rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}) c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -483,8 +484,8 @@ func Test_expandGroup(t *testing.T) { pol ACLPolicy } type args struct { - group string - stripEmailDomain bool + group string + stripEmail bool } tests := []struct { name string @@ -504,8 +505,7 @@ func Test_expandGroup(t *testing.T) { }, }, args: args{ - group: "group:test", - stripEmailDomain: true, + group: "group:test", }, want: []string{"user1", "user2", "user3"}, wantErr: false, @@ -521,14 +521,13 @@ func Test_expandGroup(t *testing.T) { }, }, args: args{ - group: "group:undefined", - stripEmailDomain: true, + group: "group:undefined", }, want: []string{}, wantErr: true, }, { - name: "Expand emails in group", + name: "Expand emails in group strip domains", field: field{ pol: ACLPolicy{ Groups: Groups{ @@ -540,8 +539,8 @@ func Test_expandGroup(t *testing.T) { }, }, args: args{ - group: "group:admin", - stripEmailDomain: true, + group: "group:admin", + stripEmail: true, }, want: []string{"joe.bar", "john.doe"}, wantErr: false, @@ -559,8 +558,7 @@ func Test_expandGroup(t *testing.T) { }, }, args: args{ - group: "group:admin", - stripEmailDomain: false, + group: "group:admin", }, want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"}, wantErr: false, @@ -568,17 +566,20 @@ func Test_expandGroup(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + viper.Set("oidc.strip_email_domain", test.args.stripEmail) + got, err := test.field.pol.getUsersInGroup( test.args.group, - test.args.stripEmailDomain, ) + if (err != nil) != test.wantErr { t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) return } - if !reflect.DeepEqual(got, test.want) { - t.Errorf("expandGroup() = %v, want %v", got, test.want) + + if diff := cmp.Diff(test.want, got); diff != "" { + t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff) } }) } @@ -586,9 +587,8 @@ func Test_expandGroup(t *testing.T) { func Test_expandTagOwners(t *testing.T) { type args struct { - aclPolicy *ACLPolicy - tag string - stripEmailDomain bool + aclPolicy *ACLPolicy + tag string } tests := []struct { name string @@ -602,8 +602,7 @@ func Test_expandTagOwners(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:test": []string{"user1"}}, }, - tag: "tag:test", - stripEmailDomain: true, + tag: "tag:test", }, want: []string{"user1"}, wantErr: false, @@ -615,8 +614,7 @@ func Test_expandTagOwners(t *testing.T) { Groups: Groups{"group:foo": []string{"user1", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, }, - tag: "tag:test", - stripEmailDomain: true, + tag: "tag:test", }, want: []string{"user1", "user2"}, wantErr: false, @@ -628,8 +626,7 @@ func Test_expandTagOwners(t *testing.T) { Groups: Groups{"group:foo": []string{"user1", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}}, }, - tag: "tag:test", - stripEmailDomain: true, + tag: "tag:test", }, want: []string{"user1", "user2", "user3"}, wantErr: false, @@ -640,8 +637,7 @@ func Test_expandTagOwners(t *testing.T) { aclPolicy: &ACLPolicy{ TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}}, }, - tag: "tag:test", - stripEmailDomain: true, + tag: "tag:test", }, want: []string{}, wantErr: true, @@ -653,8 +649,7 @@ func Test_expandTagOwners(t *testing.T) { Groups: Groups{"group:bar": []string{"user1", "user2"}}, TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}}, }, - tag: "tag:test", - stripEmailDomain: true, + tag: "tag:test", }, want: []string{}, wantErr: true, @@ -665,7 +660,6 @@ func Test_expandTagOwners(t *testing.T) { got, err := getTagOwners( test.args.aclPolicy, test.args.tag, - test.args.stripEmailDomain, ) if (err != nil) != test.wantErr { t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) @@ -861,10 +855,9 @@ func Test_expandAlias(t *testing.T) { pol ACLPolicy } type args struct { - machines types.Machines - aclPolicy ACLPolicy - alias string - stripEmailDomain bool + machines types.Machines + aclPolicy ACLPolicy + alias string } tests := []struct { name string @@ -888,7 +881,6 @@ func Test_expandAlias(t *testing.T) { }, }, }, - stripEmailDomain: true, }, want: set([]string{}, []string{ "0.0.0.0/0", @@ -931,7 +923,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{ "100.64.0.1", "100.64.0.2", "100.64.0.3", @@ -973,7 +964,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{}, []string{}), wantErr: true, @@ -984,9 +974,8 @@ func Test_expandAlias(t *testing.T) { pol: ACLPolicy{}, }, args: args{ - alias: "10.0.0.3", - machines: types.Machines{}, - stripEmailDomain: true, + alias: "10.0.0.3", + machines: types.Machines{}, }, want: set([]string{ "10.0.0.3", @@ -999,9 +988,8 @@ func Test_expandAlias(t *testing.T) { pol: ACLPolicy{}, }, args: args{ - alias: "10.0.0.1", - machines: types.Machines{}, - stripEmailDomain: true, + alias: "10.0.0.1", + machines: types.Machines{}, }, want: set([]string{ "10.0.0.1", @@ -1023,7 +1011,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{ "10.0.0.1", @@ -1046,7 +1033,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{ "10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", @@ -1069,7 +1055,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{ "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", @@ -1086,9 +1071,8 @@ func Test_expandAlias(t *testing.T) { }, }, args: args{ - alias: "testy", - machines: types.Machines{}, - stripEmailDomain: true, + alias: "testy", + machines: types.Machines{}, }, want: set([]string{}, []string{"10.0.0.132/32"}), wantErr: false, @@ -1103,9 +1087,8 @@ func Test_expandAlias(t *testing.T) { }, }, args: args{ - alias: "homeNetwork", - machines: types.Machines{}, - stripEmailDomain: true, + alias: "homeNetwork", + machines: types.Machines{}, }, want: set([]string{}, []string{"192.168.1.0/24"}), wantErr: false, @@ -1116,10 +1099,9 @@ func Test_expandAlias(t *testing.T) { pol: ACLPolicy{}, }, args: args{ - alias: "10.0.0.0/16", - machines: types.Machines{}, - aclPolicy: ACLPolicy{}, - stripEmailDomain: true, + alias: "10.0.0.0/16", + machines: types.Machines{}, + aclPolicy: ACLPolicy{}, }, want: set([]string{}, []string{"10.0.0.0/16"}), wantErr: false, @@ -1169,7 +1151,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "joe"}, }, }, - stripEmailDomain: true, }, want: set([]string{ "100.64.0.1", "100.64.0.2", @@ -1214,7 +1195,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{}, []string{}), wantErr: true, @@ -1254,7 +1234,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), wantErr: false, @@ -1302,7 +1281,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), wantErr: false, @@ -1352,7 +1330,6 @@ func Test_expandAlias(t *testing.T) { User: types.User{Name: "joe"}, }, }, - stripEmailDomain: true, }, want: set([]string{"100.64.0.4"}, []string{}), wantErr: false, @@ -1363,7 +1340,6 @@ func Test_expandAlias(t *testing.T) { got, err := test.field.pol.ExpandAlias( test.args.machines, test.args.alias, - test.args.stripEmailDomain, ) if (err != nil) != test.wantErr { t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) @@ -1379,10 +1355,9 @@ func Test_expandAlias(t *testing.T) { func Test_excludeCorrectlyTaggedNodes(t *testing.T) { type args struct { - aclPolicy *ACLPolicy - nodes types.Machines - user string - stripEmailDomain bool + aclPolicy *ACLPolicy + nodes types.Machines + user string } tests := []struct { name string @@ -1426,8 +1401,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { User: types.User{Name: "joe"}, }, }, - user: "joe", - stripEmailDomain: true, + user: "joe", }, want: types.Machines{ { @@ -1477,8 +1451,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { User: types.User{Name: "joe"}, }, }, - user: "joe", - stripEmailDomain: true, + user: "joe", }, want: types.Machines{ { @@ -1519,8 +1492,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { User: types.User{Name: "joe"}, }, }, - user: "joe", - stripEmailDomain: true, + user: "joe", }, want: types.Machines{ { @@ -1565,8 +1537,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { User: types.User{Name: "joe"}, }, }, - user: "joe", - stripEmailDomain: true, + user: "joe", }, want: types.Machines{ { @@ -1606,7 +1577,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { test.args.aclPolicy, test.args.nodes, test.args.user, - test.args.stripEmailDomain, ) if !reflect.DeepEqual(got, test.want) { t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want) @@ -1620,9 +1590,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { pol ACLPolicy } type args struct { - machine types.Machine - peers types.Machines - stripEmailDomain bool + machine types.Machine + peers types.Machines } tests := []struct { name string @@ -1652,9 +1621,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { }, }, args: args{ - machine: types.Machine{}, - peers: types.Machines{}, - stripEmailDomain: true, + machine: types.Machine{}, + peers: types.Machines{}, }, want: []tailcfg.FilterRule{ { @@ -1709,7 +1677,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { User: types.User{Name: "mickael"}, }, }, - stripEmailDomain: true, }, want: []tailcfg.FilterRule{ { @@ -1743,7 +1710,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { got, err := tt.field.pol.generateFilterRules( &tt.args.machine, tt.args.peers, - tt.args.stripEmailDomain, ) if (err != nil) != tt.wantErr { t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) @@ -1761,9 +1727,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) { func Test_getTags(t *testing.T) { type args struct { - aclPolicy *ACLPolicy - machine types.Machine - stripEmailDomain bool + aclPolicy *ACLPolicy + machine types.Machine } tests := []struct { name string @@ -1787,7 +1752,6 @@ func Test_getTags(t *testing.T) { RequestTags: []string{"tag:valid"}, }, }, - stripEmailDomain: false, }, wantValid: []string{"tag:valid"}, wantInvalid: nil, @@ -1808,7 +1772,6 @@ func Test_getTags(t *testing.T) { RequestTags: []string{"tag:valid", "tag:invalid"}, }, }, - stripEmailDomain: false, }, wantValid: []string{"tag:valid"}, wantInvalid: []string{"tag:invalid"}, @@ -1833,7 +1796,6 @@ func Test_getTags(t *testing.T) { }, }, }, - stripEmailDomain: false, }, wantValid: []string{"tag:valid"}, wantInvalid: []string{"tag:invalid"}, @@ -1854,7 +1816,6 @@ func Test_getTags(t *testing.T) { RequestTags: []string{"tag:invalid", "very-invalid"}, }, }, - stripEmailDomain: false, }, wantValid: nil, wantInvalid: []string{"tag:invalid", "very-invalid"}, @@ -1871,7 +1832,6 @@ func Test_getTags(t *testing.T) { RequestTags: []string{"tag:invalid", "very-invalid"}, }, }, - stripEmailDomain: false, }, wantValid: nil, wantInvalid: []string{"tag:invalid", "very-invalid"}, @@ -1881,7 +1841,6 @@ func Test_getTags(t *testing.T) { t.Run(test.name, func(t *testing.T) { gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine( test.args.machine, - test.args.stripEmailDomain, ) for _, valid := range gotValid { if !util.StringOrPrefixListContains(test.wantValid, valid) { @@ -2589,7 +2548,7 @@ func TestSSHRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers, false) + got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers) assert.NoError(t, err) if diff := cmp.Diff(tt.want, got); diff != "" { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 56474ac1..d553c417 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -40,7 +40,6 @@ func (h *Headscale) handlePoll( h.cfg.DNSConfig, h.cfg.LogTail.Enabled, h.cfg.RandomizeClientPort, - h.cfg.OIDC.StripEmaildomain, ) machine.Hostname = mapRequest.Hostinfo.Hostname @@ -265,7 +264,6 @@ func (h *Headscale) pollNetMapStream( h.cfg.DNSConfig, h.cfg.LogTail.Enabled, h.cfg.RandomizeClientPort, - h.cfg.OIDC.StripEmaildomain, ) h.pollNetMapStreamWG.Add(1) @@ -656,7 +654,6 @@ func (h *Headscale) scheduledPollWorker( h.cfg.DNSConfig, h.cfg.LogTail.Enabled, h.cfg.RandomizeClientPort, - h.cfg.OIDC.StripEmaildomain, ) keepAliveTicker := time.NewTicker(keepAliveInterval) diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index 5c666436..c6bd2b69 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" + "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/util/dnsname" ) @@ -24,6 +25,12 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") var ErrInvalidUserName = errors.New("invalid user name") +func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) { + strip := viper.GetBool("oidc.strip_email_domain") + + return NormalizeToFQDNRules(name, strip) +} + // NormalizeToFQDNRules will replace forbidden chars in user // it can also return an error if the user doesn't respect RFC 952 and 1123. func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {