diff --git a/api_key.go b/api_key.go new file mode 100644 index 00000000..a968b260 --- /dev/null +++ b/api_key.go @@ -0,0 +1,164 @@ +package headscale + +import ( + "fmt" + "strings" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "golang.org/x/crypto/bcrypt" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + apiPrefixLength = 7 + apiKeyLength = 32 + apiKeyParts = 2 + + errAPIKeyFailedToParse = Error("Failed to parse ApiKey") +) + +// APIKey describes the datamodel for API keys used to remotely authenticate with +// headscale. +type APIKey struct { + ID uint64 `gorm:"primary_key"` + Prefix string `gorm:"uniqueIndex"` + Hash []byte + + CreatedAt *time.Time + Expiration *time.Time + LastSeen *time.Time +} + +// CreateAPIKey creates a new ApiKey in a namespace, and returns it. +func (h *Headscale) CreateAPIKey( + expiration *time.Time, +) (string, *APIKey, error) { + prefix, err := GenerateRandomStringURLSafe(apiPrefixLength) + if err != nil { + return "", nil, err + } + + toBeHashed, err := GenerateRandomStringURLSafe(apiKeyLength) + if err != nil { + return "", nil, err + } + + // Key to return to user, this will only be visible _once_ + keyStr := prefix + "." + toBeHashed + + hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) + if err != nil { + return "", nil, err + } + + key := APIKey{ + Prefix: prefix, + Hash: hash, + Expiration: expiration, + } + h.db.Save(&key) + + return keyStr, &key, nil +} + +// ListAPIKeys returns the list of ApiKeys for a namespace. +func (h *Headscale) ListAPIKeys() ([]APIKey, error) { + keys := []APIKey{} + if err := h.db.Find(&keys).Error; err != nil { + return nil, err + } + + return keys, nil +} + +// GetAPIKey returns a ApiKey for a given key. +func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) { + key := APIKey{} + if result := h.db.First(&key, "prefix = ?", prefix); result.Error != nil { + return nil, result.Error + } + + return &key, nil +} + +// GetAPIKeyByID returns a ApiKey for a given id. +func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) { + key := APIKey{} + if result := h.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil { + return nil, result.Error + } + + return &key, nil +} + +// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey +// does not exist. +func (h *Headscale) DestroyAPIKey(key APIKey) error { + if result := h.db.Unscoped().Delete(key); result.Error != nil { + return result.Error + } + + return nil +} + +// ExpireAPIKey marks a ApiKey as expired. +func (h *Headscale) ExpireAPIKey(key *APIKey) error { + if err := h.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + return err + } + + return nil +} + +func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) { + prefix, hash, err := splitAPIKey(keyStr) + if err != nil { + return false, fmt.Errorf("failed to validate api key: %w", err) + } + + key, err := h.GetAPIKey(prefix) + if err != nil { + return false, fmt.Errorf("failed to validate api key: %w", err) + } + + if key.Expiration.Before(time.Now()) { + return false, nil + } + + if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil { + return false, err + } + + return true, nil +} + +func splitAPIKey(key string) (string, string, error) { + parts := strings.Split(key, ".") + if len(parts) != apiKeyParts { + return "", "", errAPIKeyFailedToParse + } + + return parts[0], parts[1], nil +} + +func (key *APIKey) toProto() *v1.ApiKey { + protoKey := v1.ApiKey{ + Id: key.ID, + Prefix: key.Prefix, + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) + } + + if key.LastSeen != nil { + protoKey.LastSeen = timestamppb.New(*key.LastSeen) + } + + return &protoKey +} diff --git a/api_key_test.go b/api_key_test.go new file mode 100644 index 00000000..2ddbbbc0 --- /dev/null +++ b/api_key_test.go @@ -0,0 +1,89 @@ +package headscale + +import ( + "time" + + "gopkg.in/check.v1" +) + +func (*Suite) TestCreateAPIKey(c *check.C) { + apiKeyStr, apiKey, err := app.CreateAPIKey(nil) + c.Assert(err, check.IsNil) + c.Assert(apiKey, check.NotNil) + + // Did we get a valid key? + c.Assert(apiKey.Prefix, check.NotNil) + c.Assert(apiKey.Hash, check.NotNil) + c.Assert(apiKeyStr, check.Not(check.Equals), "") + + _, err = app.ListAPIKeys() + c.Assert(err, check.IsNil) + + keys, err := app.ListAPIKeys() + c.Assert(err, check.IsNil) + c.Assert(len(keys), check.Equals, 1) +} + +func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { + key, err := app.GetAPIKey("does-not-exist") + c.Assert(err, check.NotNil) + c.Assert(key, check.IsNil) +} + +func (*Suite) TestValidateAPIKeyOk(c *check.C) { + nowPlus2 := time.Now().Add(2 * time.Hour) + apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + c.Assert(err, check.IsNil) + c.Assert(apiKey, check.NotNil) + + valid, err := app.ValidateAPIKey(apiKeyStr) + c.Assert(err, check.IsNil) + c.Assert(valid, check.Equals, true) +} + +func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { + nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) + apiKeyStr, apiKey, err := app.CreateAPIKey(&nowMinus2) + c.Assert(err, check.IsNil) + c.Assert(apiKey, check.NotNil) + + valid, err := app.ValidateAPIKey(apiKeyStr) + c.Assert(err, check.IsNil) + c.Assert(valid, check.Equals, false) + + now := time.Now() + apiKeyStrNow, apiKey, err := app.CreateAPIKey(&now) + c.Assert(err, check.IsNil) + c.Assert(apiKey, check.NotNil) + + validNow, err := app.ValidateAPIKey(apiKeyStrNow) + c.Assert(err, check.IsNil) + c.Assert(validNow, check.Equals, false) + + validSilly, err := app.ValidateAPIKey("nota.validkey") + c.Assert(err, check.NotNil) + c.Assert(validSilly, check.Equals, false) + + validWithErr, err := app.ValidateAPIKey("produceerrorkey") + c.Assert(err, check.NotNil) + c.Assert(validWithErr, check.Equals, false) +} + +func (*Suite) TestExpireAPIKey(c *check.C) { + nowPlus2 := time.Now().Add(2 * time.Hour) + apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2) + c.Assert(err, check.IsNil) + c.Assert(apiKey, check.NotNil) + + valid, err := app.ValidateAPIKey(apiKeyStr) + c.Assert(err, check.IsNil) + c.Assert(valid, check.Equals, true) + + err = app.ExpireAPIKey(apiKey) + c.Assert(err, check.IsNil) + c.Assert(apiKey.Expiration, check.NotNil) + + notValid, err := app.ValidateAPIKey(apiKeyStr) + c.Assert(err, check.IsNil) + c.Assert(notValid, check.Equals, false) +}