diff --git a/grpcv1.go b/grpcv1.go index 1c891b94..b1e5c1ee 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -3,6 +3,7 @@ package headscale import ( "context" + "fmt" "strings" "time" @@ -195,13 +196,11 @@ func (api headscaleV1APIServer) SetTags( } for _, tag := range request.GetTags() { - if strings.Index(tag, "tag:") != 0 { + err := validateTag(tag) + if err != nil { return &v1.SetTagsResponse{ Machine: nil, - }, status.Error( - codes.InvalidArgument, - "Invalid tag detected. Each tag must start with the string 'tag:'", - ) + }, status.Error(codes.InvalidArgument, err.Error()) } } @@ -220,6 +219,19 @@ func (api headscaleV1APIServer) SetTags( return &v1.SetTagsResponse{Machine: machine.toProto()}, nil } +func validateTag(tag string) error { + if strings.Index(tag, "tag:") != 0 { + return fmt.Errorf("tag must start with the string 'tag:'") + } + if strings.ToLower(tag) != tag { + return fmt.Errorf("tag should be lowercase") + } + if len(strings.Fields(tag)) > 1 { + return fmt.Errorf("tag should not contains space") + } + return nil +} + func (api headscaleV1APIServer) DeleteMachine( ctx context.Context, request *v1.DeleteMachineRequest, diff --git a/grpcv1_test.go b/grpcv1_test.go new file mode 100644 index 00000000..e48ae1ef --- /dev/null +++ b/grpcv1_test.go @@ -0,0 +1,42 @@ +package headscale + +import "testing" + +func Test_validateTag(t *testing.T) { + type args struct { + tag string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "valid tag", + args: args{tag: "tag:test"}, + wantErr: false, + }, + { + name: "tag without tag prefix", + args: args{tag: "test"}, + wantErr: true, + }, + { + name: "uppercase tag", + args: args{tag: "tag:tEST"}, + wantErr: true, + }, + { + name: "tag that contains space", + args: args{tag: "tag:this is a spaced tag"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr { + t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/integration_cli_test.go b/integration_cli_test.go index f9ff5ec0..2f58e71d 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -625,7 +625,7 @@ func (s *IntegrationCLITestSuite) TestNodeTagCommand() { var errorOutput errOutput err = json.Unmarshal([]byte(wrongTagResult), &errorOutput) assert.Nil(s.T(), err) - assert.Contains(s.T(), errorOutput.Error, "Invalid tag detected") + assert.Contains(s.T(), errorOutput.Error, "tag must start with the string 'tag:'") // Test list all nodes after added seconds listAllResult, err := ExecuteCommand( diff --git a/machine.go b/machine.go index 1bed2955..dda49020 100644 --- a/machine.go +++ b/machine.go @@ -374,7 +374,13 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { // SetTags takes a Machine struct pointer and update the forced tags. func (h *Headscale) SetTags(machine *Machine, tags []string) error { - machine.ForcedTags = tags + newTags := []string{} + for _, tag := range tags { + if !contains(newTags, tag) { + newTags = append(newTags, tag) + } + } + machine.ForcedTags = newTags if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) { return err } diff --git a/machine_test.go b/machine_test.go index a06d0db2..35c3eed9 100644 --- a/machine_test.go +++ b/machine_test.go @@ -280,6 +280,49 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { } } +func (s *Suite) TestSetTags(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.NotNil) + + machine := &Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "testmachine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + app.db.Save(machine) + + // assign simple tags + sTags := []string{"tag:test", "tag:foo"} + err = app.SetTags(machine, sTags) + c.Assert(err, check.IsNil) + machine, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags)) + + // assign duplicat tags, expect no errors but no doubles in DB + eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} + err = app.SetTags(machine, eTags) + c.Assert(err, check.IsNil) + machine, err = app.GetMachine("test", "testmachine") + c.Assert(err, check.IsNil) + c.Assert( + machine.ForcedTags, + check.DeepEquals, + StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), + ) +} + func Test_getTags(t *testing.T) { type args struct { aclPolicy *ACLPolicy