mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
remove "stripEmailDomain" argument
This commit makes a wrapper function round the normalisation requiring "stripEmailDomain" which has to be passed in almost all functions of headscale by loading it from Viper instead. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
161243c787
commit
717abe89c1
16 changed files with 127 additions and 220 deletions
|
@ -169,7 +169,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
database, err := db.NewHeadscaleDatabase(
|
database, err := db.NewHeadscaleDatabase(
|
||||||
cfg.DBtype,
|
cfg.DBtype,
|
||||||
dbString,
|
dbString,
|
||||||
cfg.OIDC.StripEmaildomain,
|
|
||||||
app.dbDebug,
|
app.dbDebug,
|
||||||
app.stateUpdateChan,
|
app.stateUpdateChan,
|
||||||
cfg.IPPrefixes,
|
cfg.IPPrefixes,
|
||||||
|
|
|
@ -53,7 +53,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -107,7 +107,7 @@ func TestInvalidTagValidUser(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -169,7 +169,7 @@ func TestPortGroup(t *testing.T) {
|
||||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -224,7 +224,7 @@ func TestPortUser(t *testing.T) {
|
||||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -285,7 +285,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
|
||||||
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||||
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
@ -361,7 +361,7 @@ func TestValidTagInvalidUser(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
|
got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
want := []tailcfg.FilterRule{
|
want := []tailcfg.FilterRule{
|
||||||
|
|
|
@ -41,16 +41,15 @@ type HSDatabase struct {
|
||||||
|
|
||||||
ipAllocationMutex sync.Mutex
|
ipAllocationMutex sync.Mutex
|
||||||
|
|
||||||
ipPrefixes []netip.Prefix
|
ipPrefixes []netip.Prefix
|
||||||
baseDomain string
|
baseDomain string
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): assemble this struct from toptions or something typed
|
// TODO(kradalby): assemble this struct from toptions or something typed
|
||||||
// rather than arguments.
|
// rather than arguments.
|
||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(
|
||||||
dbType, connectionAddr string,
|
dbType, connectionAddr string,
|
||||||
stripEmailDomain, debug bool,
|
debug bool,
|
||||||
notifyStateChan chan<- struct{},
|
notifyStateChan chan<- struct{},
|
||||||
ipPrefixes []netip.Prefix,
|
ipPrefixes []netip.Prefix,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
|
@ -64,9 +63,8 @@ func NewHeadscaleDatabase(
|
||||||
db: dbConn,
|
db: dbConn,
|
||||||
notifyStateChan: notifyStateChan,
|
notifyStateChan: notifyStateChan,
|
||||||
|
|
||||||
ipPrefixes: ipPrefixes,
|
ipPrefixes: ipPrefixes,
|
||||||
baseDomain: baseDomain,
|
baseDomain: baseDomain,
|
||||||
stripEmailDomain: stripEmailDomain,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("database %#v", dbConn)
|
log.Debug().Msgf("database %#v", dbConn)
|
||||||
|
@ -202,9 +200,8 @@ func NewHeadscaleDatabase(
|
||||||
|
|
||||||
for item, machine := range machines {
|
for item, machine := range machines {
|
||||||
if machine.GivenName == "" {
|
if machine.GivenName == "" {
|
||||||
normalizedHostname, err := util.NormalizeToFQDNRules(
|
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
||||||
machine.Hostname,
|
machine.Hostname,
|
||||||
stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
|
@ -632,9 +632,8 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||||
normalizedHostname, err := util.NormalizeToFQDNRules(
|
normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper(
|
||||||
suppliedName,
|
suppliedName,
|
||||||
hsdb.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|
|
@ -293,10 +293,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
testPeers, err := db.ListPeers(testMachine)
|
testPeers, err := db.ListPeers(testMachine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers, false)
|
adminRules, _, err := policy.GenerateFilterRules(aclPolicy, adminMachine, adminPeers)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers, false)
|
testRules, _, err := policy.GenerateFilterRules(aclPolicy, testMachine, testPeers)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
|
peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, adminPeers, adminRules)
|
||||||
|
@ -482,9 +482,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "simple machine name generation",
|
name: "simple machine name generation",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "testmachine",
|
suppliedName: "testmachine",
|
||||||
randomSuffix: false,
|
randomSuffix: false,
|
||||||
|
@ -494,9 +492,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 53 chars",
|
name: "machine name with 53 chars",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
||||||
randomSuffix: false,
|
randomSuffix: false,
|
||||||
|
@ -506,9 +502,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 63 chars",
|
name: "machine name with 63 chars",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
randomSuffix: false,
|
randomSuffix: false,
|
||||||
|
@ -518,9 +512,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 64 chars",
|
name: "machine name with 64 chars",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
||||||
randomSuffix: false,
|
randomSuffix: false,
|
||||||
|
@ -530,9 +522,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 73 chars",
|
name: "machine name with 73 chars",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
||||||
randomSuffix: false,
|
randomSuffix: false,
|
||||||
|
@ -542,9 +532,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with random suffix",
|
name: "machine name with random suffix",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "test",
|
suppliedName: "test",
|
||||||
randomSuffix: true,
|
randomSuffix: true,
|
||||||
|
@ -554,9 +542,7 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "machine name with 63 chars with random suffix",
|
name: "machine name with 63 chars with random suffix",
|
||||||
db: &HSDatabase{
|
db: &HSDatabase{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
|
||||||
args: args{
|
args: args{
|
||||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||||
randomSuffix: true,
|
randomSuffix: true,
|
||||||
|
|
|
@ -424,7 +424,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
} else {
|
} else {
|
||||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||||
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Err(err).
|
log.Err(err).
|
||||||
Str("alias", approvedAlias).
|
Str("alias", approvedAlias).
|
||||||
|
|
|
@ -60,7 +60,6 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
"sqlite3",
|
"sqlite3",
|
||||||
tmpDir+"/headscale_test.db",
|
tmpDir+"/headscale_test.db",
|
||||||
false,
|
false,
|
||||||
false,
|
|
||||||
sink,
|
sink,
|
||||||
[]netip.Prefix{
|
[]netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
|
|
|
@ -340,7 +340,6 @@ func (api headscaleV1APIServer) ListMachines(
|
||||||
m := machine.Proto()
|
m := machine.Proto()
|
||||||
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
|
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
|
||||||
machine,
|
machine,
|
||||||
api.h.cfg.OIDC.StripEmaildomain,
|
|
||||||
)
|
)
|
||||||
m.InvalidTags = invalidTags
|
m.InvalidTags = invalidTags
|
||||||
m.ValidTags = validTags
|
m.ValidTags = validTags
|
||||||
|
|
|
@ -41,7 +41,6 @@ type Mapper struct {
|
||||||
dnsCfg *tailcfg.DNSConfig
|
dnsCfg *tailcfg.DNSConfig
|
||||||
logtail bool
|
logtail bool
|
||||||
randomClientPort bool
|
randomClientPort bool
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMapper(
|
func NewMapper(
|
||||||
|
@ -53,7 +52,6 @@ func NewMapper(
|
||||||
dnsCfg *tailcfg.DNSConfig,
|
dnsCfg *tailcfg.DNSConfig,
|
||||||
logtail bool,
|
logtail bool,
|
||||||
randomClientPort bool,
|
randomClientPort bool,
|
||||||
stripEmailDomain bool,
|
|
||||||
) *Mapper {
|
) *Mapper {
|
||||||
return &Mapper{
|
return &Mapper{
|
||||||
db: db,
|
db: db,
|
||||||
|
@ -66,7 +64,6 @@ func NewMapper(
|
||||||
dnsCfg: dnsCfg,
|
dnsCfg: dnsCfg,
|
||||||
logtail: logtail,
|
logtail: logtail,
|
||||||
randomClientPort: randomClientPort,
|
randomClientPort: randomClientPort,
|
||||||
stripEmailDomain: stripEmailDomain,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,14 +84,13 @@ func fullMapResponse(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
peers types.Machines,
|
peers types.Machines,
|
||||||
|
|
||||||
stripEmailDomain bool,
|
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
dnsCfg *tailcfg.DNSConfig,
|
dnsCfg *tailcfg.DNSConfig,
|
||||||
derpMap *tailcfg.DERPMap,
|
derpMap *tailcfg.DERPMap,
|
||||||
logtail bool,
|
logtail bool,
|
||||||
randomClientPort bool,
|
randomClientPort bool,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain, stripEmailDomain)
|
tailnode, err := tailNode(*machine, pol, dnsCfg, baseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -103,7 +99,6 @@ func fullMapResponse(
|
||||||
pol,
|
pol,
|
||||||
machine,
|
machine,
|
||||||
peers,
|
peers,
|
||||||
stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -129,7 +124,7 @@ func fullMapResponse(
|
||||||
peers,
|
peers,
|
||||||
)
|
)
|
||||||
|
|
||||||
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain, stripEmailDomain)
|
tailPeers, err := tailNodes(peers, pol, dnsCfg, baseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -296,7 +291,6 @@ func (m Mapper) CreateMapResponse(
|
||||||
pol,
|
pol,
|
||||||
machine,
|
machine,
|
||||||
peers,
|
peers,
|
||||||
m.stripEmailDomain,
|
|
||||||
m.baseDomain,
|
m.baseDomain,
|
||||||
m.dnsCfg,
|
m.dnsCfg,
|
||||||
m.derpMap,
|
m.derpMap,
|
||||||
|
|
|
@ -320,7 +320,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
machine *types.Machine
|
machine *types.Machine
|
||||||
peers types.Machines
|
peers types.Machines
|
||||||
|
|
||||||
stripEmailDomain bool
|
|
||||||
baseDomain string
|
baseDomain string
|
||||||
dnsConfig *tailcfg.DNSConfig
|
dnsConfig *tailcfg.DNSConfig
|
||||||
derpMap *tailcfg.DERPMap
|
derpMap *tailcfg.DERPMap
|
||||||
|
@ -335,7 +334,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
// pol: &policy.ACLPolicy{},
|
// pol: &policy.ACLPolicy{},
|
||||||
// dnsConfig: &tailcfg.DNSConfig{},
|
// dnsConfig: &tailcfg.DNSConfig{},
|
||||||
// baseDomain: "",
|
// baseDomain: "",
|
||||||
// stripEmailDomain: false,
|
|
||||||
// want: nil,
|
// want: nil,
|
||||||
// wantErr: true,
|
// wantErr: true,
|
||||||
// },
|
// },
|
||||||
|
@ -344,7 +342,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
pol: &policy.ACLPolicy{},
|
pol: &policy.ACLPolicy{},
|
||||||
machine: mini,
|
machine: mini,
|
||||||
peers: []types.Machine{},
|
peers: []types.Machine{},
|
||||||
stripEmailDomain: false,
|
|
||||||
baseDomain: "",
|
baseDomain: "",
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
dnsConfig: &tailcfg.DNSConfig{},
|
||||||
derpMap: &tailcfg.DERPMap{},
|
derpMap: &tailcfg.DERPMap{},
|
||||||
|
@ -375,7 +372,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
peers: []types.Machine{
|
peers: []types.Machine{
|
||||||
peer1,
|
peer1,
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
baseDomain: "",
|
baseDomain: "",
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
dnsConfig: &tailcfg.DNSConfig{},
|
||||||
derpMap: &tailcfg.DERPMap{},
|
derpMap: &tailcfg.DERPMap{},
|
||||||
|
@ -417,7 +413,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
peer1,
|
peer1,
|
||||||
peer2,
|
peer2,
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
baseDomain: "",
|
baseDomain: "",
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
dnsConfig: &tailcfg.DNSConfig{},
|
||||||
derpMap: &tailcfg.DERPMap{},
|
derpMap: &tailcfg.DERPMap{},
|
||||||
|
@ -458,7 +453,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
tt.pol,
|
tt.pol,
|
||||||
tt.machine,
|
tt.machine,
|
||||||
tt.peers,
|
tt.peers,
|
||||||
tt.stripEmailDomain,
|
|
||||||
tt.baseDomain,
|
tt.baseDomain,
|
||||||
tt.dnsConfig,
|
tt.dnsConfig,
|
||||||
tt.derpMap,
|
tt.derpMap,
|
||||||
|
|
|
@ -18,7 +18,6 @@ func tailNodes(
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
nodes := make([]*tailcfg.Node, len(machines))
|
nodes := make([]*tailcfg.Node, len(machines))
|
||||||
|
|
||||||
|
@ -28,7 +27,6 @@ func tailNodes(
|
||||||
pol,
|
pol,
|
||||||
dnsConfig,
|
dnsConfig,
|
||||||
baseDomain,
|
baseDomain,
|
||||||
stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -47,7 +45,6 @@ func tailNode(
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
dnsConfig *tailcfg.DNSConfig,
|
dnsConfig *tailcfg.DNSConfig,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
nodeKey, err := machine.NodePublicKey()
|
nodeKey, err := machine.NodePublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -107,7 +104,7 @@ func tailNode(
|
||||||
|
|
||||||
online := machine.IsOnline()
|
online := machine.IsOnline()
|
||||||
|
|
||||||
tags, _ := pol.GetTagsOfMachine(machine, stripEmailDomain)
|
tags, _ := pol.GetTagsOfMachine(machine)
|
||||||
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
||||||
|
|
||||||
node := tailcfg.Node{
|
node := tailcfg.Node{
|
||||||
|
|
|
@ -44,24 +44,22 @@ func TestTailNode(t *testing.T) {
|
||||||
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
machine types.Machine
|
machine types.Machine
|
||||||
pol *policy.ACLPolicy
|
pol *policy.ACLPolicy
|
||||||
dnsConfig *tailcfg.DNSConfig
|
dnsConfig *tailcfg.DNSConfig
|
||||||
baseDomain string
|
baseDomain string
|
||||||
stripEmailDomain bool
|
want *tailcfg.Node
|
||||||
want *tailcfg.Node
|
wantErr bool
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty-machine",
|
name: "empty-machine",
|
||||||
machine: types.Machine{},
|
machine: types.Machine{},
|
||||||
pol: &policy.ACLPolicy{},
|
pol: &policy.ACLPolicy{},
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
dnsConfig: &tailcfg.DNSConfig{},
|
||||||
baseDomain: "",
|
baseDomain: "",
|
||||||
stripEmailDomain: false,
|
want: nil,
|
||||||
want: nil,
|
wantErr: true,
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "minimal-machine",
|
name: "minimal-machine",
|
||||||
|
@ -108,10 +106,9 @@ func TestTailNode(t *testing.T) {
|
||||||
},
|
},
|
||||||
CreatedAt: created,
|
CreatedAt: created,
|
||||||
},
|
},
|
||||||
pol: &policy.ACLPolicy{},
|
pol: &policy.ACLPolicy{},
|
||||||
dnsConfig: &tailcfg.DNSConfig{},
|
dnsConfig: &tailcfg.DNSConfig{},
|
||||||
baseDomain: "",
|
baseDomain: "",
|
||||||
stripEmailDomain: false,
|
|
||||||
want: &tailcfg.Node{
|
want: &tailcfg.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
StableID: "0",
|
StableID: "0",
|
||||||
|
@ -172,7 +169,6 @@ func TestTailNode(t *testing.T) {
|
||||||
tt.pol,
|
tt.pol,
|
||||||
tt.dnsConfig,
|
tt.dnsConfig,
|
||||||
tt.baseDomain,
|
tt.baseDomain,
|
||||||
tt.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
|
|
@ -121,14 +121,13 @@ func GenerateFilterRules(
|
||||||
policy *ACLPolicy,
|
policy *ACLPolicy,
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
peers types.Machines,
|
peers types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||||
// If there is no policy defined, we default to allow all
|
// If there is no policy defined, we default to allow all
|
||||||
if policy == nil {
|
if policy == nil {
|
||||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
|
rules, err := policy.generateFilterRules(machine, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
@ -136,7 +135,7 @@ func GenerateFilterRules(
|
||||||
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
||||||
|
|
||||||
var sshPolicy *tailcfg.SSHPolicy
|
var sshPolicy *tailcfg.SSHPolicy
|
||||||
sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain)
|
sshRules, err := policy.generateSSHRules(machine, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
@ -154,7 +153,6 @@ func GenerateFilterRules(
|
||||||
func (pol *ACLPolicy) generateFilterRules(
|
func (pol *ACLPolicy) generateFilterRules(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
peers types.Machines,
|
peers types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
rules := []tailcfg.FilterRule{}
|
rules := []tailcfg.FilterRule{}
|
||||||
machines := append(peers, *machine)
|
machines := append(peers, *machine)
|
||||||
|
@ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||||
|
|
||||||
srcIPs := []string{}
|
srcIPs := []string{}
|
||||||
for srcIndex, src := range acl.Sources {
|
for srcIndex, src := range acl.Sources {
|
||||||
srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain)
|
srcs, err := pol.getIPsFromSource(src, machines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Interface("src", src).
|
Interface("src", src).
|
||||||
|
@ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||||
dest,
|
dest,
|
||||||
machines,
|
machines,
|
||||||
needsWildcard,
|
needsWildcard,
|
||||||
stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||||
func (pol *ACLPolicy) generateSSHRules(
|
func (pol *ACLPolicy) generateSSHRules(
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
peers types.Machines,
|
peers types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]*tailcfg.SSHRule, error) {
|
) ([]*tailcfg.SSHRule, error) {
|
||||||
rules := []*tailcfg.SSHRule{}
|
rules := []*tailcfg.SSHRule{}
|
||||||
|
|
||||||
|
@ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||||
for index, sshACL := range pol.SSHs {
|
for index, sshACL := range pol.SSHs {
|
||||||
var dest netipx.IPSetBuilder
|
var dest netipx.IPSetBuilder
|
||||||
for _, src := range sshACL.Destinations {
|
for _, src := range sshACL.Destinations {
|
||||||
expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain)
|
expanded, err := pol.ExpandAlias(append(peers, *machine), src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||||
Any: true,
|
Any: true,
|
||||||
})
|
})
|
||||||
} else if isGroup(rawSrc) {
|
} else if isGroup(rawSrc) {
|
||||||
users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain)
|
users, err := pol.getUsersInGroup(rawSrc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
||||||
|
@ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||||
expandedSrcs, err := pol.ExpandAlias(
|
expandedSrcs, err := pol.ExpandAlias(
|
||||||
peers,
|
peers,
|
||||||
rawSrc,
|
rawSrc,
|
||||||
stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||||
func (pol *ACLPolicy) getIPsFromSource(
|
func (pol *ACLPolicy) getIPsFromSource(
|
||||||
src string,
|
src string,
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
stripEmaildomain bool,
|
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
|
ipSet, err := pol.ExpandAlias(machines, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
}
|
}
|
||||||
|
@ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||||
dest string,
|
dest string,
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
needsWildcard bool,
|
needsWildcard bool,
|
||||||
stripEmaildomain bool,
|
|
||||||
) ([]tailcfg.NetPortRange, error) {
|
) ([]tailcfg.NetPortRange, error) {
|
||||||
var tokens []string
|
var tokens []string
|
||||||
|
|
||||||
|
@ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||||
expanded, err := pol.ExpandAlias(
|
expanded, err := pol.ExpandAlias(
|
||||||
machines,
|
machines,
|
||||||
alias,
|
alias,
|
||||||
stripEmaildomain,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||||
func (pol *ACLPolicy) ExpandAlias(
|
func (pol *ACLPolicy) ExpandAlias(
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
alias string,
|
alias string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
if isWildcard(alias) {
|
if isWildcard(alias) {
|
||||||
return util.ParseIPSet("*", nil)
|
return util.ParseIPSet("*", nil)
|
||||||
|
@ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
|
|
||||||
// if alias is a group
|
// if alias is a group
|
||||||
if isGroup(alias) {
|
if isGroup(alias) {
|
||||||
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
|
return pol.getIPsFromGroup(alias, machines)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is a tag
|
// if alias is a tag
|
||||||
if isTag(alias) {
|
if isTag(alias) {
|
||||||
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
|
return pol.getIPsFromTag(alias, machines)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is a user
|
// if alias is a user
|
||||||
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
|
if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
|
||||||
return ips, err
|
return ips, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
if h, ok := pol.Hosts[alias]; ok {
|
if h, ok := pol.Hosts[alias]; ok {
|
||||||
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||||
|
|
||||||
return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
|
return pol.ExpandAlias(machines, h.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is an IP
|
// if alias is an IP
|
||||||
|
@ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes(
|
||||||
aclPolicy *ACLPolicy,
|
aclPolicy *ACLPolicy,
|
||||||
nodes types.Machines,
|
nodes types.Machines,
|
||||||
user string,
|
user string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) types.Machines {
|
) types.Machines {
|
||||||
out := types.Machines{}
|
out := types.Machines{}
|
||||||
tags := []string{}
|
tags := []string{}
|
||||||
for tag := range aclPolicy.TagOwners {
|
for tag := range aclPolicy.TagOwners {
|
||||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
owners, _ := getTagOwners(aclPolicy, user)
|
||||||
ns := append(owners, user)
|
ns := append(owners, user)
|
||||||
if util.StringOrPrefixListContains(ns, user) {
|
if util.StringOrPrefixListContains(ns, user) {
|
||||||
tags = append(tags, tag)
|
tags = append(tags, tag)
|
||||||
|
@ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||||
func getTagOwners(
|
func getTagOwners(
|
||||||
pol *ACLPolicy,
|
pol *ACLPolicy,
|
||||||
tag string,
|
tag string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
var owners []string
|
var owners []string
|
||||||
ows, ok := pol.TagOwners[tag]
|
ows, ok := pol.TagOwners[tag]
|
||||||
|
@ -687,7 +676,7 @@ func getTagOwners(
|
||||||
}
|
}
|
||||||
for _, owner := range ows {
|
for _, owner := range ows {
|
||||||
if isGroup(owner) {
|
if isGroup(owner) {
|
||||||
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
|
gs, err := pol.getUsersInGroup(owner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
}
|
}
|
||||||
|
@ -704,7 +693,6 @@ func getTagOwners(
|
||||||
// after some validation.
|
// after some validation.
|
||||||
func (pol *ACLPolicy) getUsersInGroup(
|
func (pol *ACLPolicy) getUsersInGroup(
|
||||||
group string,
|
group string,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
users := []string{}
|
users := []string{}
|
||||||
log.Trace().Caller().Interface("pol", pol).Msg("test")
|
log.Trace().Caller().Interface("pol", pol).Msg("test")
|
||||||
|
@ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||||
ErrInvalidGroup,
|
ErrInvalidGroup,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
|
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, fmt.Errorf(
|
return []string{}, fmt.Errorf(
|
||||||
"failed to normalize group %q, err: %w",
|
"failed to normalize group %q, err: %w",
|
||||||
|
@ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||||
func (pol *ACLPolicy) getIPsFromGroup(
|
func (pol *ACLPolicy) getIPsFromGroup(
|
||||||
group string,
|
group string,
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
|
||||||
users, err := pol.getUsersInGroup(group, stripEmailDomain)
|
users, err := pol.getUsersInGroup(group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &netipx.IPSet{}, err
|
return &netipx.IPSet{}, err
|
||||||
}
|
}
|
||||||
|
@ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup(
|
||||||
func (pol *ACLPolicy) getIPsFromTag(
|
func (pol *ACLPolicy) getIPsFromTag(
|
||||||
alias string,
|
alias string,
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
|
||||||
|
@ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
}
|
}
|
||||||
|
|
||||||
// find tag owners
|
// find tag owners
|
||||||
owners, err := getTagOwners(pol, alias, stripEmailDomain)
|
owners, err := getTagOwners(pol, alias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrInvalidTag) {
|
if errors.Is(err, ErrInvalidTag) {
|
||||||
ipSet, _ := build.IPSet()
|
ipSet, _ := build.IPSet()
|
||||||
|
@ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||||
func (pol *ACLPolicy) getIPsForUser(
|
func (pol *ACLPolicy) getIPsForUser(
|
||||||
user string,
|
user string,
|
||||||
machines types.Machines,
|
machines types.Machines,
|
||||||
stripEmailDomain bool,
|
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
build := netipx.IPSetBuilder{}
|
build := netipx.IPSetBuilder{}
|
||||||
|
|
||||||
filteredMachines := filterMachinesByUser(machines, user)
|
filteredMachines := filterMachinesByUser(machines, user)
|
||||||
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
|
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user)
|
||||||
|
|
||||||
// shortcurcuit if we have no machines to get ips from.
|
// shortcurcuit if we have no machines to get ips from.
|
||||||
if len(filteredMachines) == 0 {
|
if len(filteredMachines) == 0 {
|
||||||
|
@ -885,7 +870,6 @@ func isTag(str string) bool {
|
||||||
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
||||||
func (pol *ACLPolicy) GetTagsOfMachine(
|
func (pol *ACLPolicy) GetTagsOfMachine(
|
||||||
machine types.Machine,
|
machine types.Machine,
|
||||||
stripEmailDomain bool,
|
|
||||||
) ([]string, []string) {
|
) ([]string, []string) {
|
||||||
validTags := make([]string, 0)
|
validTags := make([]string, 0)
|
||||||
invalidTags := make([]string, 0)
|
invalidTags := make([]string, 0)
|
||||||
|
@ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
|
||||||
validTagMap := make(map[string]bool)
|
validTagMap := make(map[string]bool)
|
||||||
invalidTagMap := make(map[string]bool)
|
invalidTagMap := make(map[string]bool)
|
||||||
for _, tag := range machine.HostInfo.RequestTags {
|
for _, tag := range machine.HostInfo.RequestTags {
|
||||||
owners, err := getTagOwners(pol, tag, stripEmailDomain)
|
owners, err := getTagOwners(pol, tag)
|
||||||
if errors.Is(err, ErrInvalidTag) {
|
if errors.Is(err, ErrInvalidTag) {
|
||||||
invalidTagMap[tag] = true
|
invalidTagMap[tag] = true
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/spf13/viper"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
@ -199,7 +200,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
|
||||||
c.Assert(pol.ACLs, check.HasLen, 6)
|
c.Assert(pol.ACLs, check.HasLen, 6)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(rules, check.IsNil)
|
c.Assert(rules, check.IsNil)
|
||||||
}
|
}
|
||||||
|
@ -230,7 +231,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
|
||||||
pol, err := LoadACLPolicyFromBytes(acl, "hujson")
|
pol, err := LoadACLPolicyFromBytes(acl, "hujson")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
}
|
}
|
||||||
|
@ -246,7 +247,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,7 +266,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -281,7 +282,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{}, false)
|
_, _, err := GenerateFilterRules(pol, &types.Machine{}, types.Machines{})
|
||||||
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,7 +311,7 @@ func (s *Suite) TestPortRange(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pol, check.NotNil)
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -366,7 +367,7 @@ func (s *Suite) TestProtocolParsing(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pol, check.NotNil)
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -401,7 +402,7 @@ func (s *Suite) TestPortWildcard(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pol, check.NotNil)
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -428,7 +429,7 @@ acls:
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pol, check.NotNil)
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -459,7 +460,7 @@ acls:
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pol, check.NotNil)
|
c.Assert(pol, check.NotNil)
|
||||||
|
|
||||||
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{}, false)
|
rules, err := pol.generateFilterRules(&types.Machine{}, types.Machines{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(rules, check.NotNil)
|
c.Assert(rules, check.NotNil)
|
||||||
|
|
||||||
|
@ -483,8 +484,8 @@ func Test_expandGroup(t *testing.T) {
|
||||||
pol ACLPolicy
|
pol ACLPolicy
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
group string
|
group string
|
||||||
stripEmailDomain bool
|
stripEmail bool
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -504,8 +505,7 @@ func Test_expandGroup(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
group: "group:test",
|
group: "group:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{"user1", "user2", "user3"},
|
want: []string{"user1", "user2", "user3"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -521,14 +521,13 @@ func Test_expandGroup(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
group: "group:undefined",
|
group: "group:undefined",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{},
|
want: []string{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Expand emails in group",
|
name: "Expand emails in group strip domains",
|
||||||
field: field{
|
field: field{
|
||||||
pol: ACLPolicy{
|
pol: ACLPolicy{
|
||||||
Groups: Groups{
|
Groups: Groups{
|
||||||
|
@ -540,8 +539,8 @@ func Test_expandGroup(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
group: "group:admin",
|
group: "group:admin",
|
||||||
stripEmailDomain: true,
|
stripEmail: true,
|
||||||
},
|
},
|
||||||
want: []string{"joe.bar", "john.doe"},
|
want: []string{"joe.bar", "john.doe"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -559,8 +558,7 @@ func Test_expandGroup(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
group: "group:admin",
|
group: "group:admin",
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
|
want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -568,17 +566,20 @@ func Test_expandGroup(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
viper.Set("oidc.strip_email_domain", test.args.stripEmail)
|
||||||
|
|
||||||
got, err := test.field.pol.getUsersInGroup(
|
got, err := test.field.pol.getUsersInGroup(
|
||||||
test.args.group,
|
test.args.group,
|
||||||
test.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (err != nil) != test.wantErr {
|
if (err != nil) != test.wantErr {
|
||||||
t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr)
|
t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(got, test.want) {
|
|
||||||
t.Errorf("expandGroup() = %v, want %v", got, test.want)
|
if diff := cmp.Diff(test.want, got); diff != "" {
|
||||||
|
t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -586,9 +587,8 @@ func Test_expandGroup(t *testing.T) {
|
||||||
|
|
||||||
func Test_expandTagOwners(t *testing.T) {
|
func Test_expandTagOwners(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
tag string
|
tag string
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -602,8 +602,7 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
aclPolicy: &ACLPolicy{
|
aclPolicy: &ACLPolicy{
|
||||||
TagOwners: TagOwners{"tag:test": []string{"user1"}},
|
TagOwners: TagOwners{"tag:test": []string{"user1"}},
|
||||||
},
|
},
|
||||||
tag: "tag:test",
|
tag: "tag:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{"user1"},
|
want: []string{"user1"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -615,8 +614,7 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
Groups: Groups{"group:foo": []string{"user1", "user2"}},
|
Groups: Groups{"group:foo": []string{"user1", "user2"}},
|
||||||
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
TagOwners: TagOwners{"tag:test": []string{"group:foo"}},
|
||||||
},
|
},
|
||||||
tag: "tag:test",
|
tag: "tag:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{"user1", "user2"},
|
want: []string{"user1", "user2"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -628,8 +626,7 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
Groups: Groups{"group:foo": []string{"user1", "user2"}},
|
Groups: Groups{"group:foo": []string{"user1", "user2"}},
|
||||||
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}},
|
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}},
|
||||||
},
|
},
|
||||||
tag: "tag:test",
|
tag: "tag:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{"user1", "user2", "user3"},
|
want: []string{"user1", "user2", "user3"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -640,8 +637,7 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
aclPolicy: &ACLPolicy{
|
aclPolicy: &ACLPolicy{
|
||||||
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}},
|
TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}},
|
||||||
},
|
},
|
||||||
tag: "tag:test",
|
tag: "tag:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{},
|
want: []string{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -653,8 +649,7 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
Groups: Groups{"group:bar": []string{"user1", "user2"}},
|
Groups: Groups{"group:bar": []string{"user1", "user2"}},
|
||||||
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}},
|
TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}},
|
||||||
},
|
},
|
||||||
tag: "tag:test",
|
tag: "tag:test",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []string{},
|
want: []string{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -665,7 +660,6 @@ func Test_expandTagOwners(t *testing.T) {
|
||||||
got, err := getTagOwners(
|
got, err := getTagOwners(
|
||||||
test.args.aclPolicy,
|
test.args.aclPolicy,
|
||||||
test.args.tag,
|
test.args.tag,
|
||||||
test.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if (err != nil) != test.wantErr {
|
if (err != nil) != test.wantErr {
|
||||||
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr)
|
t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr)
|
||||||
|
@ -861,10 +855,9 @@ func Test_expandAlias(t *testing.T) {
|
||||||
pol ACLPolicy
|
pol ACLPolicy
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
machines types.Machines
|
machines types.Machines
|
||||||
aclPolicy ACLPolicy
|
aclPolicy ACLPolicy
|
||||||
alias string
|
alias string
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -888,7 +881,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{
|
want: set([]string{}, []string{
|
||||||
"0.0.0.0/0",
|
"0.0.0.0/0",
|
||||||
|
@ -931,7 +923,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"100.64.0.1", "100.64.0.2", "100.64.0.3",
|
"100.64.0.1", "100.64.0.2", "100.64.0.3",
|
||||||
|
@ -973,7 +964,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{}),
|
want: set([]string{}, []string{}),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -984,9 +974,8 @@ func Test_expandAlias(t *testing.T) {
|
||||||
pol: ACLPolicy{},
|
pol: ACLPolicy{},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
alias: "10.0.0.3",
|
alias: "10.0.0.3",
|
||||||
machines: types.Machines{},
|
machines: types.Machines{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"10.0.0.3",
|
"10.0.0.3",
|
||||||
|
@ -999,9 +988,8 @@ func Test_expandAlias(t *testing.T) {
|
||||||
pol: ACLPolicy{},
|
pol: ACLPolicy{},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
alias: "10.0.0.1",
|
alias: "10.0.0.1",
|
||||||
machines: types.Machines{},
|
machines: types.Machines{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"10.0.0.1",
|
"10.0.0.1",
|
||||||
|
@ -1023,7 +1011,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"10.0.0.1",
|
"10.0.0.1",
|
||||||
|
@ -1046,7 +1033,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
|
"10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222",
|
||||||
|
@ -1069,7 +1055,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1",
|
"fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1",
|
||||||
|
@ -1086,9 +1071,8 @@ func Test_expandAlias(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
alias: "testy",
|
alias: "testy",
|
||||||
machines: types.Machines{},
|
machines: types.Machines{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{"10.0.0.132/32"}),
|
want: set([]string{}, []string{"10.0.0.132/32"}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1103,9 +1087,8 @@ func Test_expandAlias(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
alias: "homeNetwork",
|
alias: "homeNetwork",
|
||||||
machines: types.Machines{},
|
machines: types.Machines{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{"192.168.1.0/24"}),
|
want: set([]string{}, []string{"192.168.1.0/24"}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1116,10 +1099,9 @@ func Test_expandAlias(t *testing.T) {
|
||||||
pol: ACLPolicy{},
|
pol: ACLPolicy{},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
alias: "10.0.0.0/16",
|
alias: "10.0.0.0/16",
|
||||||
machines: types.Machines{},
|
machines: types.Machines{},
|
||||||
aclPolicy: ACLPolicy{},
|
aclPolicy: ACLPolicy{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{"10.0.0.0/16"}),
|
want: set([]string{}, []string{"10.0.0.0/16"}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1169,7 +1151,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{
|
want: set([]string{
|
||||||
"100.64.0.1", "100.64.0.2",
|
"100.64.0.1", "100.64.0.2",
|
||||||
|
@ -1214,7 +1195,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{}, []string{}),
|
want: set([]string{}, []string{}),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
|
@ -1254,7 +1234,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
|
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1302,7 +1281,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
|
want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1352,7 +1330,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: set([]string{"100.64.0.4"}, []string{}),
|
want: set([]string{"100.64.0.4"}, []string{}),
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
|
@ -1363,7 +1340,6 @@ func Test_expandAlias(t *testing.T) {
|
||||||
got, err := test.field.pol.ExpandAlias(
|
got, err := test.field.pol.ExpandAlias(
|
||||||
test.args.machines,
|
test.args.machines,
|
||||||
test.args.alias,
|
test.args.alias,
|
||||||
test.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if (err != nil) != test.wantErr {
|
if (err != nil) != test.wantErr {
|
||||||
t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr)
|
t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr)
|
||||||
|
@ -1379,10 +1355,9 @@ func Test_expandAlias(t *testing.T) {
|
||||||
|
|
||||||
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
nodes types.Machines
|
nodes types.Machines
|
||||||
user string
|
user string
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -1426,8 +1401,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
user: "joe",
|
user: "joe",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: types.Machines{
|
want: types.Machines{
|
||||||
{
|
{
|
||||||
|
@ -1477,8 +1451,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
user: "joe",
|
user: "joe",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: types.Machines{
|
want: types.Machines{
|
||||||
{
|
{
|
||||||
|
@ -1519,8 +1492,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
user: "joe",
|
user: "joe",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: types.Machines{
|
want: types.Machines{
|
||||||
{
|
{
|
||||||
|
@ -1565,8 +1537,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
User: types.User{Name: "joe"},
|
User: types.User{Name: "joe"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
user: "joe",
|
user: "joe",
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: types.Machines{
|
want: types.Machines{
|
||||||
{
|
{
|
||||||
|
@ -1606,7 +1577,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
|
||||||
test.args.aclPolicy,
|
test.args.aclPolicy,
|
||||||
test.args.nodes,
|
test.args.nodes,
|
||||||
test.args.user,
|
test.args.user,
|
||||||
test.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if !reflect.DeepEqual(got, test.want) {
|
if !reflect.DeepEqual(got, test.want) {
|
||||||
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)
|
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)
|
||||||
|
@ -1620,9 +1590,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
pol ACLPolicy
|
pol ACLPolicy
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
machine types.Machine
|
machine types.Machine
|
||||||
peers types.Machines
|
peers types.Machines
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -1652,9 +1621,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
machine: types.Machine{},
|
machine: types.Machine{},
|
||||||
peers: types.Machines{},
|
peers: types.Machines{},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
{
|
{
|
||||||
|
@ -1709,7 +1677,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
User: types.User{Name: "mickael"},
|
User: types.User{Name: "mickael"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: true,
|
|
||||||
},
|
},
|
||||||
want: []tailcfg.FilterRule{
|
want: []tailcfg.FilterRule{
|
||||||
{
|
{
|
||||||
|
@ -1743,7 +1710,6 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
got, err := tt.field.pol.generateFilterRules(
|
got, err := tt.field.pol.generateFilterRules(
|
||||||
&tt.args.machine,
|
&tt.args.machine,
|
||||||
tt.args.peers,
|
tt.args.peers,
|
||||||
tt.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
@ -1761,9 +1727,8 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
|
||||||
|
|
||||||
func Test_getTags(t *testing.T) {
|
func Test_getTags(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
machine types.Machine
|
machine types.Machine
|
||||||
stripEmailDomain bool
|
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -1787,7 +1752,6 @@ func Test_getTags(t *testing.T) {
|
||||||
RequestTags: []string{"tag:valid"},
|
RequestTags: []string{"tag:valid"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
wantValid: []string{"tag:valid"},
|
wantValid: []string{"tag:valid"},
|
||||||
wantInvalid: nil,
|
wantInvalid: nil,
|
||||||
|
@ -1808,7 +1772,6 @@ func Test_getTags(t *testing.T) {
|
||||||
RequestTags: []string{"tag:valid", "tag:invalid"},
|
RequestTags: []string{"tag:valid", "tag:invalid"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
wantValid: []string{"tag:valid"},
|
wantValid: []string{"tag:valid"},
|
||||||
wantInvalid: []string{"tag:invalid"},
|
wantInvalid: []string{"tag:invalid"},
|
||||||
|
@ -1833,7 +1796,6 @@ func Test_getTags(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
wantValid: []string{"tag:valid"},
|
wantValid: []string{"tag:valid"},
|
||||||
wantInvalid: []string{"tag:invalid"},
|
wantInvalid: []string{"tag:invalid"},
|
||||||
|
@ -1854,7 +1816,6 @@ func Test_getTags(t *testing.T) {
|
||||||
RequestTags: []string{"tag:invalid", "very-invalid"},
|
RequestTags: []string{"tag:invalid", "very-invalid"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
wantValid: nil,
|
wantValid: nil,
|
||||||
wantInvalid: []string{"tag:invalid", "very-invalid"},
|
wantInvalid: []string{"tag:invalid", "very-invalid"},
|
||||||
|
@ -1871,7 +1832,6 @@ func Test_getTags(t *testing.T) {
|
||||||
RequestTags: []string{"tag:invalid", "very-invalid"},
|
RequestTags: []string{"tag:invalid", "very-invalid"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
stripEmailDomain: false,
|
|
||||||
},
|
},
|
||||||
wantValid: nil,
|
wantValid: nil,
|
||||||
wantInvalid: []string{"tag:invalid", "very-invalid"},
|
wantInvalid: []string{"tag:invalid", "very-invalid"},
|
||||||
|
@ -1881,7 +1841,6 @@ func Test_getTags(t *testing.T) {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
|
gotValid, gotInvalid := test.args.aclPolicy.GetTagsOfMachine(
|
||||||
test.args.machine,
|
test.args.machine,
|
||||||
test.args.stripEmailDomain,
|
|
||||||
)
|
)
|
||||||
for _, valid := range gotValid {
|
for _, valid := range gotValid {
|
||||||
if !util.StringOrPrefixListContains(test.wantValid, valid) {
|
if !util.StringOrPrefixListContains(test.wantValid, valid) {
|
||||||
|
@ -2589,7 +2548,7 @@ func TestSSHRules(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers, false)
|
got, err := tt.pol.generateSSHRules(&tt.machine, tt.peers)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
|
|
@ -40,7 +40,6 @@ func (h *Headscale) handlePoll(
|
||||||
h.cfg.DNSConfig,
|
h.cfg.DNSConfig,
|
||||||
h.cfg.LogTail.Enabled,
|
h.cfg.LogTail.Enabled,
|
||||||
h.cfg.RandomizeClientPort,
|
h.cfg.RandomizeClientPort,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||||
|
@ -265,7 +264,6 @@ func (h *Headscale) pollNetMapStream(
|
||||||
h.cfg.DNSConfig,
|
h.cfg.DNSConfig,
|
||||||
h.cfg.LogTail.Enabled,
|
h.cfg.LogTail.Enabled,
|
||||||
h.cfg.RandomizeClientPort,
|
h.cfg.RandomizeClientPort,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Add(1)
|
h.pollNetMapStreamWG.Add(1)
|
||||||
|
@ -656,7 +654,6 @@ func (h *Headscale) scheduledPollWorker(
|
||||||
h.cfg.DNSConfig,
|
h.cfg.DNSConfig,
|
||||||
h.cfg.LogTail.Enabled,
|
h.cfg.LogTail.Enabled,
|
||||||
h.cfg.RandomizeClientPort,
|
h.cfg.RandomizeClientPort,
|
||||||
h.cfg.OIDC.StripEmaildomain,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
"go4.org/netipx"
|
"go4.org/netipx"
|
||||||
"tailscale.com/util/dnsname"
|
"tailscale.com/util/dnsname"
|
||||||
)
|
)
|
||||||
|
@ -24,6 +25,12 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||||
|
|
||||||
var ErrInvalidUserName = errors.New("invalid user name")
|
var ErrInvalidUserName = errors.New("invalid user name")
|
||||||
|
|
||||||
|
func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) {
|
||||||
|
strip := viper.GetBool("oidc.strip_email_domain")
|
||||||
|
|
||||||
|
return NormalizeToFQDNRules(name, strip)
|
||||||
|
}
|
||||||
|
|
||||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||||
|
|
Loading…
Reference in a new issue