Merge pull request #707 from restanrm/fix-bug-in-excludecorrectlytaggednodes

Fix bug in excludeCorrectlyTaggedNodes
This commit is contained in:
Juan Font 2022-08-11 13:08:43 +02:00 committed by GitHub
commit 6c8445988c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 199 additions and 44 deletions

View file

@ -3,6 +3,7 @@
## 0.17.0 (2022-xx-xx) ## 0.17.0 (2022-xx-xx)
- Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722)
- Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563)
## 0.16.0 (2022-07-25) ## 0.16.0 (2022-07-25)

23
acls.go
View file

@ -162,7 +162,12 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations { for innerIndex, dest := range acl.Destinations {
dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest, needsWildcard) dests, err := h.generateACLPolicyDest(
machines,
*h.aclPolicy,
dest,
needsWildcard,
)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Destination %d", index, innerIndex) Msgf("Error parsing ACL %d, Destination %d", index, innerIndex)
@ -255,7 +260,12 @@ func (h *Headscale) generateACLPolicyDest(
func parseProtocol(protocol string) ([]int, bool, error) { func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol { switch protocol {
case "": case "":
return []int{protocolICMP, protocolIPv6ICMP, protocolTCP, protocolUDP}, false, nil return []int{
protocolICMP,
protocolIPv6ICMP,
protocolTCP,
protocolUDP,
}, false, nil
case "igmp": case "igmp":
return []int{protocolIGMP}, true, nil return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip": case "ipv4", "ip-in-ip":
@ -284,7 +294,9 @@ func parseProtocol(protocol string) ([]int, bool, error) {
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
needsWildcard := protocolNumber != protocolTCP && protocolNumber != protocolUDP && protocolNumber != protocolSCTP needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil return []int{protocolNumber}, needsWildcard, nil
} }
@ -367,7 +379,7 @@ func expandAlias(
// if alias is a namespace // if alias is a namespace
nodes := filterMachinesByNamespace(machines, alias) nodes := filterMachinesByNamespace(machines, alias)
nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias) nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias, stripEmailDomain)
for _, n := range nodes { for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...) ips = append(ips, n.IPAddresses.ToStringSlice()...)
@ -405,10 +417,13 @@ func excludeCorrectlyTaggedNodes(
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
nodes []Machine, nodes []Machine,
namespace string, namespace string,
stripEmailDomain bool,
) []Machine { ) []Machine {
out := []Machine{} out := []Machine{}
tags := []string{} tags := []string{}
for tag, ns := range aclPolicy.TagOwners { for tag, ns := range aclPolicy.TagOwners {
owners, _ := expandTagOwners(aclPolicy, namespace, stripEmailDomain)
ns = append(owners, namespace)
if contains(ns, namespace) { if contains(ns, namespace) {
tags = append(tags, tag) tags = append(tags, tag)
} }

View file

@ -62,7 +62,11 @@ func (s *Suite) TestBasicRule(c *check.C) {
func (s *Suite) TestInvalidAction(c *check.C) { func (s *Suite) TestInvalidAction(c *check.C) {
app.aclPolicy = &ACLPolicy{ app.aclPolicy = &ACLPolicy{
ACLs: []ACL{ ACLs: []ACL{
{Action: "invalidAction", Sources: []string{"*"}, Destinations: []string{"*:*"}}, {
Action: "invalidAction",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
}, },
} }
err := app.UpdateACLRules() err := app.UpdateACLRules()
@ -77,7 +81,11 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
"group:error": []string{"foo", "group:test"}, "group:error": []string{"foo", "group:test"},
}, },
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"group:error"}, Destinations: []string{"*:*"}}, {
Action: "accept",
Sources: []string{"group:error"},
Destinations: []string{"*:*"},
},
}, },
} }
err := app.UpdateACLRules() err := app.UpdateACLRules()
@ -88,7 +96,11 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
// this ACL is wrong because no tagOwners own the requested tag for the server // this ACL is wrong because no tagOwners own the requested tag for the server
app.aclPolicy = &ACLPolicy{ app.aclPolicy = &ACLPolicy{
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"tag:foo"}, Destinations: []string{"*:*"}}, {
Action: "accept",
Sources: []string{"tag:foo"},
Destinations: []string{"*:*"},
},
}, },
} }
err := app.UpdateACLRules() err := app.UpdateACLRules()
@ -131,7 +143,11 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
Groups: Groups{"group:test": []string{"user1", "user2"}}, Groups: Groups{"group:test": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"tag:test"}, Destinations: []string{"*:*"}}, {
Action: "accept",
Sources: []string{"tag:test"},
Destinations: []string{"*:*"},
},
}, },
} }
err = app.UpdateACLRules() err = app.UpdateACLRules()
@ -177,7 +193,11 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
Groups: Groups{"group:test": []string{"user1", "user2"}}, Groups: Groups{"group:test": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"*"}, Destinations: []string{"tag:test:*"}}, {
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
},
}, },
} }
err = app.UpdateACLRules() err = app.UpdateACLRules()
@ -222,7 +242,11 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) {
app.aclPolicy = &ACLPolicy{ app.aclPolicy = &ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"user1"}}, TagOwners: TagOwners{"tag:test": []string{"user1"}},
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"user1"}, Destinations: []string{"*:*"}}, {
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"*:*"},
},
}, },
} }
err = app.UpdateACLRules() err = app.UpdateACLRules()
@ -1204,6 +1228,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
aclPolicy ACLPolicy aclPolicy ACLPolicy
nodes []Machine nodes []Machine
namespace string namespace string
stripEmailDomain bool
} }
tests := []struct { tests := []struct {
name string name string
@ -1248,6 +1273,58 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
namespace: "joe", namespace: "joe",
stripEmailDomain: true,
},
want: []Machine{
{
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")},
Namespace: Namespace{Name: "joe"},
},
},
},
{
name: "exclude nodes with valid tags, and owner is in a group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{
"group:accountant": []string{"joe", "bar"},
},
TagOwners: TagOwners{
"tag:accountant-webserver": []string{"group:accountant"},
},
},
nodes: []Machine{
{
IPAddresses: MachineAddresses{
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
},
},
{
IPAddresses: MachineAddresses{
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
},
},
{
IPAddresses: MachineAddresses{
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
},
namespace: "joe",
stripEmailDomain: true,
}, },
want: []Machine{ want: []Machine{
{ {
@ -1289,6 +1366,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
namespace: "joe", namespace: "joe",
stripEmailDomain: true,
}, },
want: []Machine{ want: []Machine{
{ {
@ -1334,6 +1412,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
}, },
}, },
namespace: "joe", namespace: "joe",
stripEmailDomain: true,
}, },
want: []Machine{ want: []Machine{
{ {
@ -1373,6 +1452,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.aclPolicy, test.args.aclPolicy,
test.args.nodes, test.args.nodes,
test.args.namespace, test.args.namespace,
test.args.stripEmailDomain,
) )
if !reflect.DeepEqual(got, test.want) { if !reflect.DeepEqual(got, test.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want) t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)

11
api.go
View file

@ -271,7 +271,8 @@ func (h *Headscale) RegistrationHandler(
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && registerRequest.Expiry.UTC().Before(now) { if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOut(writer, req, machineKey, *machine) h.handleMachineLogOut(writer, req, machineKey, *machine)
return return
@ -289,7 +290,13 @@ func (h *Headscale) RegistrationHandler(
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() { !machine.isExpired() {
h.handleMachineRefreshKey(writer, req, machineKey, registerRequest, *machine) h.handleMachineRefreshKey(
writer,
req,
machineKey,
registerRequest,
*machine,
)
return return
} }

22
app.go
View file

@ -418,16 +418,20 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).
Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet) router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).
Methods(http.MethodGet)
router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).
Methods(http.MethodGet)
router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet)
router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).
Methods(http.MethodGet)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPHandler) router.HandleFunc("/derp", h.DERPHandler)
@ -692,7 +696,10 @@ func (h *Headscale) Serve() error {
h.pollNetMapStreamWG.Wait() h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers // Gracefully shut down servers
ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout) ctx, cancel := context.WithTimeout(
context.Background(),
HTTPShutdownTimeout,
)
if err := promHTTPServer.Shutdown(ctx); err != nil { if err := promHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http") log.Error().Err(err).Msg("Failed to shutdown prometheus http")
} }
@ -819,7 +826,10 @@ func (h *Headscale) setLastStateChangeToNow(namespaces ...string) {
if len(namespaces) == 0 { if len(namespaces) == 0 {
namespaces, err = h.ListNamespacesStr() namespaces, err = h.ListNamespacesStr()
if err != nil { if err != nil {
log.Error().Caller().Err(err).Msg("failed to fetch all namespaces, failing to update last changed state.") log.Error().
Caller().
Err(err).
Msg("failed to fetch all namespaces, failing to update last changed state.")
} }
} }

View file

@ -134,7 +134,9 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
expiration := time.Now().UTC().Add(time.Duration(duration)) expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().Dur("expiration", time.Duration(duration)).Msg("expiration has been set") log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration) request.Expiration = timestamppb.New(expiration)

View file

@ -164,7 +164,9 @@ var createPreAuthKeyCmd = &cobra.Command{
expiration := time.Now().UTC().Add(time.Duration(duration)) expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().Dur("expiration", time.Duration(duration)).Msg("expiration has been set") log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration) request.Expiration = timestamppb.New(expiration)

View file

@ -24,7 +24,10 @@ const (
func getHeadscaleApp() (*headscale.Headscale, error) { func getHeadscaleApp() (*headscale.Headscale, error) {
cfg, err := headscale.GetHeadscaleConfig() cfg, err := headscale.GetHeadscaleConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load configuration while creating headscale instance: %w", err) return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w",
err,
)
} }
app, err := headscale.NewHeadscale(cfg) app, err := headscale.NewHeadscale(cfg)

View file

@ -69,7 +69,13 @@ func (s *IntegrationCLITestSuite) SetupTest() {
err = s.pool.RemoveContainerByName(headscaleHostname) err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "") s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
} }
fmt.Println("Creating headscale container for CLI tests") fmt.Println("Creating headscale container for CLI tests")

View file

@ -227,7 +227,6 @@ func getIPs(
func getDNSNames( func getDNSNames(
headscale *dockertest.Resource, headscale *dockertest.Resource,
) ([]string, error) { ) ([]string, error) {
listAllResult, err := ExecuteCommand( listAllResult, err := ExecuteCommand(
headscale, headscale,
[]string{ []string{
@ -261,7 +260,6 @@ func getDNSNames(
func getMagicFQDN( func getMagicFQDN(
headscale *dockertest.Resource, headscale *dockertest.Resource,
) ([]string, error) { ) ([]string, error) {
listAllResult, err := ExecuteCommand( listAllResult, err := ExecuteCommand(
headscale, headscale,
[]string{ []string{
@ -286,7 +284,11 @@ func getMagicFQDN(
hostnames := make([]string, len(listAll)) hostnames := make([]string, len(listAll))
for index := range listAll { for index := range listAll {
hostnames[index] = fmt.Sprintf("%s.%s.headscale.net", listAll[index].GetGivenName(), listAll[index].GetNamespace().GetName()) hostnames[index] = fmt.Sprintf(
"%s.%s.headscale.net",
listAll[index].GetGivenName(),
listAll[index].GetNamespace().GetName(),
)
} }
return hostnames, nil return hostnames, nil

View file

@ -129,7 +129,13 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
err = s.pool.RemoveContainerByName(headscaleHostname) err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "") s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
} }
log.Println("Creating headscale container for DERP integration tests") log.Println("Creating headscale container for DERP integration tests")

View file

@ -246,7 +246,13 @@ func (s *IntegrationTestSuite) SetupSuite() {
err = s.pool.RemoveContainerByName(headscaleHostname) err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil { if err != nil {
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "") s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
} }
log.Println("Creating headscale container for core integration tests") log.Println("Creating headscale container for core integration tests")

View file

@ -188,8 +188,16 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Hosts: map[string]netaddr.IPPrefix{}, Hosts: map[string]netaddr.IPPrefix{},
TagOwners: map[string][]string{}, TagOwners: map[string][]string{},
ACLs: []ACL{ ACLs: []ACL{
{Action: "accept", Sources: []string{"admin"}, Destinations: []string{"*:*"}}, {
{Action: "accept", Sources: []string{"test"}, Destinations: []string{"test:*"}}, Action: "accept",
Sources: []string{"admin"},
Destinations: []string{"*:*"},
},
{
Action: "accept",
Sources: []string{"test"},
Destinations: []string{"test:*"},
},
}, },
Tests: []ACLTest{}, Tests: []ACLTest{},
} }

View file

@ -474,7 +474,11 @@ func (h *Headscale) validateMachineForOIDCCallback(
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to refresh machine") Msg("Failed to refresh machine")
http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError) http.Error(
writer,
"Failed to refresh machine",
http.StatusInternalServerError,
)
return nil, true, err return nil, true, err
} }

View file

@ -325,7 +325,9 @@ func (h *Headscale) ApplePlatformConfig(
default: default:
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Invalid platform, only ios and macos is supported")) _, err := writer.Write(
[]byte("Invalid platform, only ios and macos is supported"),
)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -362,7 +364,8 @@ func (h *Headscale) ApplePlatformConfig(
return return
} }
writer.Header().Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes()) _, err = writer.Write(content.Bytes())
if err != nil { if err != nil {