Return better web errors to the user (#2398)
Some checks are pending
Build / build-nix (push) Waiting to run
Build / build-cross (GOARCH=386 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=5) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=6) (push) Waiting to run
Build / build-cross (GOARCH=arm GOOS=linux GOARM=7) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Waiting to run
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Waiting to run
Tests / test (push) Waiting to run

* add dedicated http error to propagate to user

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* classify user errors in http handlers

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* move validation of pre auth key out of db

This move separates the logic a bit and allow us to
write specific errors for the caller, in this case the web
layer so we can present the user with the correct error
codes without bleeding web stuff into a generic validate.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* update changelog

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-02-01 15:25:18 +01:00 committed by GitHub
parent 1c7f3bc440
commit 45752db0f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 268 additions and 229 deletions

View file

@ -28,6 +28,8 @@
[#2396](https://github.com/juanfont/headscale/pull/2396)
- Pre auth keys that are used by a node can no longer be deleted
[#2396](https://github.com/juanfont/headscale/pull/2396)
- Rehaul HTTP errors, return better status code and errors to users
[#2398](https://github.com/juanfont/headscale/pull/2398)
## 0.24.2 (2025-01-30)

View file

@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode(
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
return nil, errors.New("node already exists with different machine key")
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}
expired := node.IsExpired()
@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode(
// The client is trying to extend their key, this is not allowed.
if requestExpiry.After(time.Now()) {
return nil, errors.New("extending key is not allowed")
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}
// If the request expiry is in the past, we consider it a logout.
@ -155,13 +155,42 @@ func (h *Headscale) waitForFollowup(
}
}
// canUsePreAuthKey checks if a pre auth key can be used.
func canUsePreAuthKey(pak *types.PreAuthKey) error {
if pak == nil {
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
}
// we don't need to check if has been used before
if pak.Reusable {
return nil
}
if pak.Used {
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
}
return nil
}
func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, fmt.Errorf("invalid pre auth key: %w", err)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
return nil, err
}
err = canUsePreAuthKey(pak)
if err != nil {
return nil, err
}
nodeToRegister := types.Node{

130
hscontrol/auth_test.go Normal file
View file

@ -0,0 +1,130 @@
package hscontrol
import (
"net/http"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
)
func TestCanUsePreAuthKey(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
future := now.Add(time.Hour)
tests := []struct {
name string
pak *types.PreAuthKey
wantErr bool
err HTTPError
}{
{
name: "valid reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "valid non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "expired key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "used non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &future,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
{
name: "used reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: true,
Expiration: &future,
},
wantErr: false,
},
{
name: "no expiration date",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: nil,
},
wantErr: false,
},
{
name: "nil preauth key",
pak: nil,
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
},
{
name: "expired and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "no expiration and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: nil,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := canUsePreAuthKey(tt.pak)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(HTTPError)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {
if diff := cmp.Diff(tt.err, httpErr); diff != "" {
t.Errorf("unexpected error (-want +got):\n%s", diff)
}
}
}
} else {
if err != nil {
t.Errorf("expected no error but got %v", err)
}
}
})
}
}

View file

@ -10,7 +10,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)
@ -64,6 +63,7 @@ func CreatePreAuthKey(
}
now := time.Now().UTC()
// TODO(kradalby): unify the key generations spread all over the code.
kstr, err := generateKey()
if err != nil {
return nil, err
@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
return keys, nil
}
// GetPreAuthKey returns a PreAuthKey for a given key.
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
pak, err := ValidatePreAuthKey(tx, key)
if err != nil {
return nil, err
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return GetPreAuthKey(rx, key)
})
}
if pak.User.Name != user {
return nil, ErrUserMismatch
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
// for checking if the key is usable (expired or used).
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
return nil, ErrPreAuthKeyNotFound
}
return pak, nil
return &pak, nil
}
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
})
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
}
// UsePreAuthKey marks a PreAuthKey as used.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
return nil
}
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrPreAuthKeyNotFound
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, ErrPreAuthKeyExpired
}
if pak.Reusable { // we don't need to check if has been used before
return &pak, nil
}
nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
Find(&nodes).Error; err != nil {
return nil, err
}
if len(nodes) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed
}
return &pak, nil
return nil
}
func generateKey() (string, error) {

View file

@ -3,14 +3,14 @@ package db
import (
"sort"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/check.v1"
"tailscale.com/types/ptr"
"gopkg.in/check.v1"
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
@ -43,123 +43,6 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
c.Assert((keys)[0].User.ID, check.Equals, user.ID)
}
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test2"})
c.Assert(err, check.IsNil)
now := time.Now().Add(-5 * time.Second)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
key, err := db.ValidatePreAuthKey("potatoKey")
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
c.Assert(key, check.IsNil)
}
func (*Suite) TestValidateKeyOk(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test3"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestAlreadyUsedKey(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test4"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
node := types.Node{
ID: 0,
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
c.Assert(key, check.IsNil)
}
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test5"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil)
node := types.Node{
ID: 1,
Hostname: "testest",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test6"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test3"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil)
err = db.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil)
key, err := db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test6"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
pak.Used = true
db.DB.Save(&pak)
_, err = db.ValidatePreAuthKey(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
}
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test8"})
c.Assert(err, check.IsNil)

View file

@ -181,11 +181,15 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
err := api.h.db.Write(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
preAuthKey, err := db.GetPreAuthKey(tx, request.Key)
if err != nil {
return err
}
if preAuthKey.User.Name != request.GetUser() {
return fmt.Errorf("preauth key does not belong to user")
}
return db.ExpirePreAuthKey(tx, preAuthKey)
})
if err != nil {

View file

@ -33,10 +33,33 @@ const (
)
// httpError logs an error and sends an HTTP error response with the given
func httpError(w http.ResponseWriter, err error, userError string, code int) {
log.Error().Err(err).Msg(userError)
http.Error(w, userError, code)
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
http.Error(w, herr.Msg, herr.Code)
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
} else {
http.Error(w, "internal server error", http.StatusInternalServerError)
log.Error().Err(err).Int("code", http.StatusInternalServerError).Msg("http internal server error")
}
}
// HTTPError represents an error that is surfaced to the user via web.
type HTTPError struct {
Code int // HTTP response code to send to client; 0 means 500
Msg string // Response body to send to client
Err error // Detailed error to log on the server
}
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
func (e HTTPError) Unwrap() error { return e.Err }
// Error returns an HTTPError containing the given information.
func NewHTTPError(code int, msg string, err error) HTTPError {
return HTTPError{Code: code, Msg: msg, Err: err}
}
var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil)
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
"machines registered with CLI does not support expire",
@ -47,12 +70,12 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
clientCapabilityStr := req.URL.Query().Get("v")
if clientCapabilityStr == "" {
return 0, ErrNoCapabilityVersion
return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil)
}
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
if err != nil {
return 0, fmt.Errorf("failed to parse capability version: %w", err)
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
}
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
@ -85,13 +108,13 @@ func (h *Headscale) VerifyHandler(
req *http.Request,
) {
if req.Method != http.MethodPost {
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
httpError(writer, errMethodNotAllowed)
return
}
allow, err := h.derpRequestIsAllowed(req)
if err != nil {
httpError(writer, err, "Internal error", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -112,7 +135,7 @@ func (h *Headscale) KeyHandler(
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
capVer, err := parseCabailityVersion(req)
if err != nil {
httpError(writer, err, "Internal error", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -199,7 +222,7 @@ func (a *AuthProviderWeb) RegisterHandler(
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
return
}

View file

@ -3,6 +3,7 @@ package hscontrol
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@ -12,6 +13,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"gorm.io/gorm"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp/controlhttpserver"
"tailscale.com/tailcfg"
@ -81,7 +83,7 @@ func (h *Headscale) NoiseUpgradeHandler(
noiseServer.earlyNoise,
)
if err != nil {
httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError)
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err))
return
}
@ -198,7 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
var mapRequest tailcfg.MapRequest
if err := json.Unmarshal(body, &mapRequest); err != nil {
httpError(writer, err, "Internal error", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -211,7 +213,11 @@ func (ns *noiseServer) NoisePollNetMapHandler(
node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
httpError(writer, err, "Internal error", http.StatusInternalServerError)
if errors.Is(err, gorm.ErrRecordNotFound) {
httpError(writer, NewHTTPError(http.StatusNotFound, "node not found", nil))
return
}
httpError(writer, err)
return
}
@ -230,7 +236,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
req *http.Request,
) {
if req.Method != http.MethodPost {
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
httpError(writer, errMethodNotAllowed)
return
}

View file

@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler(
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
return
}
// Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state")
if err != nil {
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
httpError(writer, err)
return
}
// Set the state and nonce cookies to protect against CSRF attacks
nonce, err := setCSRFCookie(writer, req, "nonce")
if err != nil {
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -219,34 +219,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
) {
code, state, err := extractCodeAndStateParamFromRequest(req)
if err != nil {
httpError(writer, err, err.Error(), http.StatusBadRequest)
httpError(writer, err)
return
}
cookieState, err := req.Cookie("state")
if err != nil {
httpError(writer, err, "state not found", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
return
}
if state != cookieState.Value {
httpError(writer, err, "state did not match", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil))
return
}
idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil {
httpError(writer, err, err.Error(), http.StatusBadRequest)
httpError(writer, err)
return
}
nonce, err := req.Cookie("nonce")
if err != nil {
httpError(writer, err, "nonce not found", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
}
if idToken.Nonce != nonce.Value {
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
return
}
@ -254,29 +254,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
err = fmt.Errorf("decoding ID token claims: %w", err)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
return
}
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
httpError(writer, err, err.Error(), http.StatusUnauthorized)
httpError(writer, err)
return
}
user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
httpError(writer, err, err.Error(), http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest(
state := req.URL.Query().Get("state")
if code == "" || state == "" {
return "", "", errEmptyOIDCCallbackParams
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams)
}
return code, state, nil
@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken(
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, errNoOIDCRegistrationInfo
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken(
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err)
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
}
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return nil, errNoOIDCIDToken
return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken)
}
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, fmt.Errorf("failed to verify ID token: %w", err)
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
}
return idToken, nil
@ -381,7 +380,7 @@ func validateOIDCAllowedDomains(
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
return errOIDCAllowedDomains
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains)
}
}
@ -403,7 +402,7 @@ func validateOIDCAllowedGroups(
}
}
return errOIDCAllowedGroups
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
}
return nil
@ -417,7 +416,7 @@ func validateOIDCAllowedUsers(
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
return errOIDCAllowedUsers
return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers)
}
return nil
@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return user, nil
}
func (a *AuthProviderOIDC) handleRegistrationID(
func (a *AuthProviderOIDC) handleRegistration(
user *types.User,
registrationID types.RegistrationID,
expiry time.Time,

View file

@ -39,19 +39,19 @@ func (h *Headscale) ApplePlatformConfig(
vars := mux.Vars(req)
platform, ok := vars["platform"]
if !ok {
httpError(writer, nil, "No platform specified", http.StatusBadRequest)
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
return
}
id, err := uuid.NewV4()
if err != nil {
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
httpError(writer, err)
return
}
contentID, err := uuid.NewV4()
if err != nil {
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
httpError(writer, err)
return
}
@ -65,21 +65,21 @@ func (h *Headscale) ApplePlatformConfig(
switch platform {
case "macos-standalone":
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
httpError(writer, err)
return
}
case "macos-app-store":
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
httpError(writer, err)
return
}
case "ios":
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError)
httpError(writer, err)
return
}
default:
httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError)
httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil))
return
}
@ -91,7 +91,7 @@ func (h *Headscale) ApplePlatformConfig(
var content bytes.Buffer
if err := commonTemplate.Execute(&content, config); err != nil {
httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError)
httpError(writer, err)
return
}