headscale/hscontrol/db/preauth_keys.go
Kristoffer Dalby 45752db0f6
Some checks are pending
Build / build-nix (push) Waiting to run
Build / build-cross (GOARCH=386 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=5) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=6) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=7) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
Tests / test (push) Waiting to run
Return better web errors to the user (#2398)
* add dedicated http error to propagate to user

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* classify user errors in http handlers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* move validation of pre auth key out of db

This move separates the logic a bit and allows us to
write specific errors for the caller, in this case the web
layer, so we can present the user with the correct error
codes without bleeding web stuff into a generic validate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* update changelog

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2025-02-01 15:25:18 +01:00

173 lines
4.3 KiB
Go

package db
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/util/set"
)
// Sentinel errors related to PreAuthKeys.
var (
// ErrPreAuthKeyNotFound is returned by GetPreAuthKey when the lookup fails.
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
// ErrPreAuthKeyExpired is not returned by the code visible in this part of
// the file; presumably used where a key's expiration is validated —
// TODO confirm against callers.
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
// ErrSingleUseAuthKeyHasBeenUsed is not returned by the code visible in this
// part of the file; presumably used where a non-reusable key is rejected
// after its first use — TODO confirm against callers.
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
// ErrUserMismatch is not returned by the code visible in this part of the
// file — TODO confirm where it is used.
ErrUserMismatch = errors.New("user mismatch")
// ErrPreAuthKeyACLTagInvalid is wrapped by CreatePreAuthKey when a tag does
// not begin with "tag:".
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
)
// CreatePreAuthKey creates a new PreAuthKey for the user identified by uid.
// It runs the package-level CreatePreAuthKey through Write on hsdb.DB and
// returns whatever that call yields.
func (hsdb *HSDatabase) CreatePreAuthKey(
	uid types.UserID,
	reusable bool,
	ephemeral bool,
	expiration *time.Time,
	aclTags []string,
) (*types.PreAuthKey, error) {
	create := func(tx *gorm.DB) (*types.PreAuthKey, error) {
		return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
	}

	return Write(hsdb.DB, create)
}
// CreatePreAuthKey stores a new PreAuthKey owned by the user with ID uid and
// returns it.
//
// The user is looked up first; a failed lookup is returned unchanged. aclTags
// is de-duplicated, and every remaining tag must begin with "tag:", otherwise
// an error wrapping ErrPreAuthKeyACLTagInvalid is returned. The key string is
// produced by generateKey and CreatedAt is set to the current time in UTC.
// A failure to save the key is returned wrapped with additional context.
func CreatePreAuthKey(
	tx *gorm.DB,
	uid types.UserID,
	reusable bool,
	ephemeral bool,
	expiration *time.Time,
	aclTags []string,
) (*types.PreAuthKey, error) {
	owner, err := GetUserByID(tx, uid)
	if err != nil {
		return nil, err
	}

	// Collapse repeated tags before validating them.
	aclTags = set.SetOf(aclTags).Slice()

	// TODO(kradalby): factor out and create a reusable tag validation,
	// check if there is one in Tailscale's lib.
	for _, tag := range aclTags {
		if strings.HasPrefix(tag, "tag:") {
			continue
		}

		return nil, fmt.Errorf(
			"%w: '%s' did not begin with 'tag:'",
			ErrPreAuthKeyACLTagInvalid,
			tag,
		)
	}

	createdAt := time.Now().UTC()

	// TODO(kradalby): unify the key generations spread all over the code.
	secret, err := generateKey()
	if err != nil {
		return nil, err
	}

	pak := &types.PreAuthKey{
		Key:        secret,
		UserID:     owner.ID,
		User:       *owner,
		Reusable:   reusable,
		Ephemeral:  ephemeral,
		CreatedAt:  &createdAt,
		Expiration: expiration,
		Tags:       aclTags,
	}

	if result := tx.Save(pak); result.Error != nil {
		return nil, fmt.Errorf("failed to create key in the database: %w", result.Error)
	}

	return pak, nil
}
// ListPreAuthKeys returns the PreAuthKeys of the user identified by uid. It
// runs ListPreAuthKeysByUser through Read on hsdb.DB.
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
	list := func(rx *gorm.DB) ([]types.PreAuthKey, error) {
		return ListPreAuthKeysByUser(rx, uid)
	}

	return Read(hsdb.DB, list)
}
// ListPreAuthKeysByUser returns the PreAuthKeys whose UserID matches the user
// with ID uid, each with its User preloaded. The user is looked up first, and
// an error from that lookup or from the query is returned unchanged.
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
	owner, err := GetUserByID(tx, uid)
	if err != nil {
		return nil, err
	}

	// Start from an empty (non-nil) slice for Find to fill.
	found := make([]types.PreAuthKey, 0)
	filter := &types.PreAuthKey{UserID: owner.ID}

	result := tx.Preload("User").Where(filter).Find(&found)
	if result.Error != nil {
		return nil, result.Error
	}

	return found, nil
}
// GetPreAuthKey returns the PreAuthKey for the given key string. It runs the
// package-level GetPreAuthKey through Read on hsdb.DB.
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
	lookup := func(rx *gorm.DB) (*types.PreAuthKey, error) {
		return GetPreAuthKey(rx, key)
	}

	return Read(hsdb.DB, lookup)
}
// GetPreAuthKey looks up the PreAuthKey whose key column equals key, with its
// User preloaded. Any error from the lookup is reported as
// ErrPreAuthKeyNotFound; the underlying error is not included. The caller is
// responsible for checking if the key is usable (expired or used).
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
	var found types.PreAuthKey

	result := tx.Preload("User").First(&found, "key = ?", key)
	if result.Error != nil {
		return nil, ErrPreAuthKeyNotFound
	}

	return &found, nil
}
// DestroyPreAuthKey deletes pak inside a transaction started on tx. The delete
// is issued through Unscoped, and any error it reports is returned.
//
// NOTE(review): only the error reported by Delete is forwarded; whether
// deleting a key that does not exist counts as an error depends on gorm —
// confirm before relying on it.
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
	return tx.Transaction(func(inner *gorm.DB) error {
		return inner.Unscoped().Delete(pak).Error
	})
}
// ExpirePreAuthKey marks k as expired. It runs the package-level
// ExpirePreAuthKey through hsdb.Write.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
	expire := func(tx *gorm.DB) error {
		return ExpirePreAuthKey(tx, k)
	}

	return hsdb.Write(expire)
}
// UsePreAuthKey marks a PreAuthKey as used: it sets k.Used to true and saves
// k through tx. k.Used is set before the save, so it stays true on the
// in-memory value even when saving fails; the save error is returned wrapped.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
	k.Used = true

	result := tx.Save(k)
	if result.Error == nil {
		return nil
	}

	return fmt.Errorf("failed to update key used status in the database: %w", result.Error)
}
// ExpirePreAuthKey marks a PreAuthKey as expired by updating its Expiration
// column to the current time through tx. Any error from the update is
// returned unchanged.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
	// k is already a pointer to the model, so hand it to Model as is rather
	// than wrapping it in a second pointer.
	return tx.Model(k).Update("Expiration", time.Now()).Error
}
// generateKey returns a fresh random key: 24 bytes read from crypto/rand,
// hex-encoded into a 48-character string. If the random source fails, the
// error is returned together with an empty string.
func generateKey() (string, error) {
	const keyLen = 24

	var raw [keyLen]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}

	return hex.EncodeToString(raw[:]), nil
}