mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
Merge pull request #689 from restanrm/fix-duplicate-tags-returned-by-api
Remove duplicate tags if sent by the client
This commit is contained in:
commit
32a8f06486
5 changed files with 110 additions and 7 deletions
22
grpcv1.go
22
grpcv1.go
|
@ -3,6 +3,7 @@ package headscale
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -195,13 +196,11 @@ func (api headscaleV1APIServer) SetTags(
|
|||
}
|
||||
|
||||
for _, tag := range request.GetTags() {
|
||||
if strings.Index(tag, "tag:") != 0 {
|
||||
err := validateTag(tag)
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Machine: nil,
|
||||
}, status.Error(
|
||||
codes.InvalidArgument,
|
||||
"Invalid tag detected. Each tag must start with the string 'tag:'",
|
||||
)
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -220,6 +219,19 @@ func (api headscaleV1APIServer) SetTags(
|
|||
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) error {
|
||||
if strings.Index(tag, "tag:") != 0 {
|
||||
return fmt.Errorf("tag must start with the string 'tag:'")
|
||||
}
|
||||
if strings.ToLower(tag) != tag {
|
||||
return fmt.Errorf("tag should be lowercase")
|
||||
}
|
||||
if len(strings.Fields(tag)) > 1 {
|
||||
return fmt.Errorf("tag should not contains space")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) DeleteMachine(
|
||||
ctx context.Context,
|
||||
request *v1.DeleteMachineRequest,
|
||||
|
|
42
grpcv1_test.go
Normal file
42
grpcv1_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package headscale
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_validateTag(t *testing.T) {
|
||||
type args struct {
|
||||
tag string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid tag",
|
||||
args: args{tag: "tag:test"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "tag without tag prefix",
|
||||
args: args{tag: "test"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase tag",
|
||||
args: args{tag: "tag:tEST"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "tag that contains space",
|
||||
args: args{tag: "tag:this is a spaced tag"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -625,7 +625,7 @@ func (s *IntegrationCLITestSuite) TestNodeTagCommand() {
|
|||
var errorOutput errOutput
|
||||
err = json.Unmarshal([]byte(wrongTagResult), &errorOutput)
|
||||
assert.Nil(s.T(), err)
|
||||
assert.Contains(s.T(), errorOutput.Error, "Invalid tag detected")
|
||||
assert.Contains(s.T(), errorOutput.Error, "tag must start with the string 'tag:'")
|
||||
|
||||
// Test list all nodes after added seconds
|
||||
listAllResult, err := ExecuteCommand(
|
||||
|
|
|
@ -374,7 +374,13 @@ func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
|||
|
||||
// SetTags takes a Machine struct pointer and update the forced tags.
|
||||
func (h *Headscale) SetTags(machine *Machine, tags []string) error {
|
||||
machine.ForcedTags = tags
|
||||
newTags := []string{}
|
||||
for _, tag := range tags {
|
||||
if !contains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
machine.ForcedTags = newTags
|
||||
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -280,6 +280,49 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
NamespaceID: namespace.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.Save(machine)
|
||||
|
||||
// assign simple tags
|
||||
sTags := []string{"tag:test", "tag:foo"}
|
||||
err = app.SetTags(machine, sTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.ForcedTags, check.DeepEquals, StringList(sTags))
|
||||
|
||||
// assign duplicat tags, expect no errors but no doubles in DB
|
||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||
err = app.SetTags(machine, eTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = app.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
machine.ForcedTags,
|
||||
check.DeepEquals,
|
||||
StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
)
|
||||
}
|
||||
|
||||
func Test_getTags(t *testing.T) {
|
||||
type args struct {
|
||||
aclPolicy *ACLPolicy
|
||||
|
|
Loading…
Reference in a new issue