diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index accf439e..e5a47953 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -3,6 +3,7 @@ package db import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "net/netip" @@ -19,6 +20,7 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" + "tailscale.com/util/set" ) var errDatabaseNotSupported = errors.New("database type not supported") @@ -291,7 +293,12 @@ func NewHeadscaleDatabase( return err } - err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) + type preAuthKeyACLTag struct { + ID uint64 `gorm:"primary_key"` + PreAuthKeyID uint64 + Tag string + } + err = tx.AutoMigrate(&preAuthKeyACLTag{}) if err != nil { return err } @@ -413,6 +420,54 @@ func NewHeadscaleDatabase( }, Rollback: func(db *gorm.DB) error { return nil }, }, + // denormalise the ACL tags for preauth keys back onto + // the preauth key table. We dont normalise or reuse and + // it is just a bunch of work for extra work. + { + ID: "202409271400", + Migrate: func(tx *gorm.DB) error { + preauthkeyTags := map[uint64]set.Set[string]{} + + type preAuthKeyACLTag struct { + ID uint64 `gorm:"primary_key"` + PreAuthKeyID uint64 + Tag string + } + + var aclTags []preAuthKeyACLTag + if err := tx.Find(&aclTags).Error; err != nil { + return err + } + + // Store the current tags. + for _, tag := range aclTags { + if preauthkeyTags[tag.PreAuthKeyID] == nil { + preauthkeyTags[tag.PreAuthKeyID] = set.SetOf([]string{tag.Tag}) + } else { + preauthkeyTags[tag.PreAuthKeyID].Add(tag.Tag) + } + } + + // Add tags column and restore the tags. + _ = tx.Migrator().AddColumn(&types.PreAuthKey{}, "tags") + for keyID, tags := range preauthkeyTags { + s := tags.Slice() + j, err := json.Marshal(s) + if err != nil { + return err + } + if err := tx.Model(&types.PreAuthKey{}).Where("id = ?", keyID).Update("tags", string(j)).Error; err != nil { + return err + } + } + + // Drop the old table. + _ = tx.Migrator().DropTable(&preAuthKeyACLTag{}) + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, }, ) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index b32d93ce..157ede8b 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -6,6 +6,8 @@ import ( "net/netip" "os" "path/filepath" + "slices" + "sort" "testing" "github.com/google/go-cmp/cmp" @@ -108,6 +110,68 @@ func TestMigrations(t *testing.T) { } }, }, + // at 14:15:06 ❯ go run ./cmd/headscale preauthkeys list + // ID | Key | Reusable | Ephemeral | Used | Expiration | Created | Tags + // 1 | 09b28f.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp + // 2 | 3112b9.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp + // 3 | 7c23b9.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp,tag:merp + // 4 | f20155.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:test + // 5 | b212b9.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:test,tag:woop,tag:dedu + { + dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite", + wantFunc: func(t *testing.T, h *HSDatabase) { + keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { + kratest, err := ListPreAuthKeys(rx, "kratest") + if err != nil { + return nil, err + } + + testkra, err := ListPreAuthKeys(rx, "testkra") + if err != nil { + return nil, err + } + + return append(kratest, testkra...), nil + }) + assert.NoError(t, err) + + assert.Len(t, keys, 5) + want := []types.PreAuthKey{ + { + ID: 1, + Tags: []string{"tag:derp"}, + }, + { + ID: 2, + Tags: []string{"tag:derp"}, + }, + { + ID: 3, + Tags: []string{"tag:derp", "tag:merp"}, + }, + { + ID: 4, + Tags: []string{"tag:test"}, + }, + { + ID: 5, + Tags: []string{"tag:test", "tag:woop", "tag:dedu"}, + }, + } + + if diff := cmp.Diff(want, keys, cmp.Comparer(func(a, b []string) bool { + sort.Sort(sort.StringSlice(a)) + sort.Sort(sort.StringSlice(b)) + return slices.Equal(a, b) + }), cmpopts.IgnoreFields(types.PreAuthKey{}, "Key", "UserID", "User", "CreatedAt", "Expiration")); diff != "" { + t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff) + } + + if h.DB.Migrator().HasTable("pre_auth_key_acl_tags") { + t.Errorf("TestMigrations() table pre_auth_key_acl_tags should not exist") + } + }, + }, } for _, tt := range tests { diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 5ea59a9c..96420211 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" "tailscale.com/types/ptr" + "tailscale.com/util/set" ) var ( @@ -47,6 +48,11 @@ func CreatePreAuthKey( return nil, err } + // Remove duplicates + aclTags = set.SetOf(aclTags).Slice() + + // TODO(kradalby): factor out and create a reusable tag validation, + // check if there is one in Tailscale's lib. for _, tag := range aclTags { if !strings.HasPrefix(tag, "tag:") { return nil, fmt.Errorf( @@ -71,28 +77,13 @@ func CreatePreAuthKey( Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, + Tags: types.StringList(aclTags), } if err := tx.Save(&key).Error; err != nil { return nil, fmt.Errorf("failed to create key in the database: %w", err) } - if len(aclTags) > 0 { - seenTags := map[string]bool{} - - for _, tag := range aclTags { - if !seenTags[tag] { - if err := tx.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { - return nil, fmt.Errorf( - "failed to create key tag in the database: %w", - err, - ) - } - seenTags[tag] = true - } - } - } - return &key, nil } @@ -110,7 +101,7 @@ func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { } keys := []types.PreAuthKey{} - if err := tx.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { + if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { return nil, err } @@ -135,10 +126,6 @@ func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, err // does not exist. func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error { return tx.Transaction(func(db *gorm.DB) error { - if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { - return result.Error - } - if result := db.Unscoped().Delete(pak); result.Error != nil { return result.Error } @@ -182,7 +169,7 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) // If returns no error and a PreAuthKey, it can be used. func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { pak := types.PreAuthKey{} - if result := tx.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( + if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { diff --git a/hscontrol/db/testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite b/hscontrol/db/testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite new file mode 100644 index 00000000..512c4879 Binary files /dev/null and b/hscontrol/db/testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite differ diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 8b02569a..ba3b597b 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -16,21 +16,14 @@ type PreAuthKey struct { UserID uint User User `gorm:"constraint:OnDelete:CASCADE;"` Reusable bool - Ephemeral bool `gorm:"default:false"` - Used bool `gorm:"default:false"` - ACLTags []PreAuthKeyACLTag `gorm:"constraint:OnDelete:CASCADE;"` + Ephemeral bool `gorm:"default:false"` + Used bool `gorm:"default:false"` + Tags []string `gorm:"serializer:json"` CreatedAt *time.Time Expiration *time.Time } -// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. -type PreAuthKeyACLTag struct { - ID uint64 `gorm:"primary_key"` - PreAuthKeyID uint64 - Tag string -} - func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ User: key.User.Name, @@ -39,7 +32,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { Ephemeral: key.Ephemeral, Reusable: key.Reusable, Used: key.Used, - AclTags: make([]string, len(key.ACLTags)), + AclTags: key.Tags, } if key.Expiration != nil { @@ -50,9 +43,5 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) } - for idx := range key.ACLTags { - protoKey.AclTags[idx] = key.ACLTags[idx].Tag - } - return &protoKey }