From 45752db0f6ed7f2e72ba4946ff8462682198dae8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 1 Feb 2025 15:25:18 +0100 Subject: [PATCH] Return better web errors to the user (#2398) * add dedicated http error to propagate to user Signed-off-by: Kristoffer Dalby * classify user errors in http handlers Signed-off-by: Kristoffer Dalby * move validation of pre auth key out of db This move separates the logic a bit and allow us to write specific errors for the caller, in this case the web layer so we can present the user with the correct error codes without bleeding web stuff into a generic validate. Signed-off-by: Kristoffer Dalby * update changelog Signed-off-by: Kristoffer Dalby --------- Signed-off-by: Kristoffer Dalby --- CHANGELOG.md | 2 + hscontrol/auth.go | 37 ++++++++- hscontrol/auth_test.go | 130 ++++++++++++++++++++++++++++++ hscontrol/db/preauth_keys.go | 75 +++++------------ hscontrol/db/preauth_keys_test.go | 121 +-------------------------- hscontrol/grpcv1.go | 6 +- hscontrol/handlers.go | 41 +++++++--- hscontrol/noise.go | 14 +++- hscontrol/oidc.go | 55 +++++++------ hscontrol/platform_config.go | 16 ++-- 10 files changed, 268 insertions(+), 229 deletions(-) create mode 100644 hscontrol/auth_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d5c2245..eee03861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ [#2396](https://github.com/juanfont/headscale/pull/2396) - Pre auth keys that are used by a node can no longer be deleted [#2396](https://github.com/juanfont/headscale/pull/2396) +- Rehaul HTTP errors, return better status code and errors to users + [#2398](https://github.com/juanfont/headscale/pull/2398) ## 0.24.2 (2025-01-30) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 3fa5fa4b..7695f1ae 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode( machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { if node.MachineKey != machineKey { - return nil, errors.New("node already exists with different machine key") + return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() @@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode( // The client is trying to extend their key, this is not allowed. if requestExpiry.After(time.Now()) { - return nil, errors.New("extending key is not allowed") + return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil) } // If the request expiry is in the past, we consider it a logout. @@ -155,13 +155,42 @@ func (h *Headscale) waitForFollowup( } } +// canUsePreAuthKey checks if a pre auth key can be used. +func canUsePreAuthKey(pak *types.PreAuthKey) error { + if pak == nil { + return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) + } + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { + return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil) + } + + // we don't need to check if has been used before + if pak.Reusable { + return nil + } + + if pak.Used { + return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil) + } + + return nil +} + func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey) + pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return nil, fmt.Errorf("invalid pre auth key: %w", err) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) + } + return nil, err + } + + err = canUsePreAuthKey(pak) + if err != nil { + return nil, err } nodeToRegister := types.Node{ diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go new file mode 100644 index 00000000..7c0c0d42 --- /dev/null +++ b/hscontrol/auth_test.go @@ -0,0 +1,130 @@ +package hscontrol + +import ( + "net/http" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" +) + +func TestCanUsePreAuthKey(t *testing.T) { + now := time.Now() + past := now.Add(-time.Hour) + future := now.Add(time.Hour) + + tests := []struct { + name string + pak *types.PreAuthKey + wantErr bool + err HTTPError + }{ + { + name: "valid reusable key", + pak: &types.PreAuthKey{ + Reusable: true, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "valid non-reusable key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "expired key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &past, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil), + }, + { + name: "used non-reusable key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &future, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil), + }, + { + name: "used reusable key", + pak: &types.PreAuthKey{ + Reusable: true, + Used: true, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "no expiration date", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: nil, + }, + wantErr: false, + }, + { + name: "nil preauth key", + pak: nil, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil), + }, + { + name: "expired and used key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &past, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil), + }, + { + name: "no expiration and used key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: nil, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := canUsePreAuthKey(tt.pak) + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } else { + httpErr, ok := err.(HTTPError) + if !ok { + t.Errorf("expected HTTPError but got %T", err) + } else { + if diff := cmp.Diff(tt.err, httpErr); diff != "" { + t.Errorf("unexpected error (-want +got):\n%s", diff) + } + } + } + } else { + if err != nil { + t.Errorf("expected no error but got %v", err) + } + } + }) + } +} diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index aeee5b52..ee977ae3 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -10,7 +10,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" - "tailscale.com/types/ptr" "tailscale.com/util/set" ) @@ -64,6 +63,7 @@ func CreatePreAuthKey( } now := time.Now().UTC() + // TODO(kradalby): unify the key generations spread all over the code. kstr, err := generateKey() if err != nil { return nil, err @@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e return keys, nil } -// GetPreAuthKey returns a PreAuthKey for a given key. -func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) { - pak, err := ValidatePreAuthKey(tx, key) - if err != nil { - return nil, err +func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { + return GetPreAuthKey(rx, key) + }) +} + +// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible +// for checking if the key is usable (expired or used). +func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { + pak := types.PreAuthKey{} + if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil { + return nil, ErrPreAuthKeyNotFound } - if pak.User.Name != user { - return nil, ErrUserMismatch - } - - return pak, nil + return &pak, nil } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey @@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. -func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { - if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { - return err - } - - return nil -} - // UsePreAuthKey marks a PreAuthKey as used. func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { k.Used = true @@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { return nil } -func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { - return ValidatePreAuthKey(rx, k) - }) -} - -// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node -// If returns no error and a PreAuthKey, it can be used. -func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { - pak := types.PreAuthKey{} - if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrPreAuthKeyNotFound +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { + return err } - if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return nil, ErrPreAuthKeyExpired - } - - if pak.Reusable { // we don't need to check if has been used before - return &pak, nil - } - - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}). - Find(&nodes).Error; err != nil { - return nil, err - } - - if len(nodes) != 0 || pak.Used { - return nil, ErrSingleUseAuthKeyHasBeenUsed - } - - return &pak, nil + return nil } func generateKey() (string, error) { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index ec7f75a8..5ace968a 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -3,14 +3,14 @@ package db import ( "sort" "testing" - "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/check.v1" "tailscale.com/types/ptr" + + "gopkg.in/check.v1" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -43,123 +43,6 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { c.Assert((keys)[0].User.ID, check.Equals, user.ID) } -func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test2"}) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := db.ValidatePreAuthKey("potatoKey") - c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test3"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test4"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test5"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 1, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test6"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test3"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.IsNil) - - err = db.ExpirePreAuthKey(pak) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.NotNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test6"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - pak.Used = true - db.DB.Save(&pak) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) -} - func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser(types.User{Name: "test8"}) c.Assert(err, check.IsNil) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 51fb9869..7eadd0a7 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -181,11 +181,15 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { err := api.h.db.Write(func(tx *gorm.DB) error { - preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) + preAuthKey, err := db.GetPreAuthKey(tx, request.Key) if err != nil { return err } + if preAuthKey.User.Name != request.GetUser() { + return fmt.Errorf("preauth key does not belong to user") + } + return db.ExpirePreAuthKey(tx, preAuthKey) }) if err != nil { diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index c310aedf..e55fce49 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -33,11 +33,34 @@ const ( ) // httpError logs an error and sends an HTTP error response with the given -func httpError(w http.ResponseWriter, err error, userError string, code int) { - log.Error().Err(err).Msg(userError) - http.Error(w, userError, code) +func httpError(w http.ResponseWriter, err error) { + var herr HTTPError + if errors.As(err, &herr) { + http.Error(w, herr.Msg, herr.Code) + log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg) + } else { + http.Error(w, "internal server error", http.StatusInternalServerError) + log.Error().Err(err).Int("code", http.StatusInternalServerError).Msg("http internal server error") + } } +// HTTPError represents an error that is surfaced to the user via web. +type HTTPError struct { + Code int // HTTP response code to send to client; 0 means 500 + Msg string // Response body to send to client + Err error // Detailed error to log on the server +} + +func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } +func (e HTTPError) Unwrap() error { return e.Err } + +// Error returns an HTTPError containing the given information. +func NewHTTPError(code int, msg string, err error) HTTPError { + return HTTPError{Code: code, Msg: msg, Err: err} +} + +var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil) + var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( "machines registered with CLI does not support expire", ) @@ -47,12 +70,12 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) clientCapabilityStr := req.URL.Query().Get("v") if clientCapabilityStr == "" { - return 0, ErrNoCapabilityVersion + return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil) } clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr) if err != nil { - return 0, fmt.Errorf("failed to parse capability version: %w", err) + return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err)) } return tailcfg.CapabilityVersion(clientCapabilityVersion), nil @@ -85,13 +108,13 @@ func (h *Headscale) VerifyHandler( req *http.Request, ) { if req.Method != http.MethodPost { - httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed) + httpError(writer, errMethodNotAllowed) return } allow, err := h.derpRequestIsAllowed(req) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -112,7 +135,7 @@ func (h *Headscale) KeyHandler( // New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion capVer, err := parseCabailityVersion(req) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -199,7 +222,7 @@ func (a *AuthProviderWeb) RegisterHandler( // the template and log an error. registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - httpError(writer, err, "invalid registration ID", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) return } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 318cf5e4..a8aa6e13 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -3,6 +3,7 @@ package hscontrol import ( "encoding/binary" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" + "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" @@ -81,7 +83,7 @@ func (h *Headscale) NoiseUpgradeHandler( noiseServer.earlyNoise, ) if err != nil { - httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError) + httpError(writer, fmt.Errorf("noise upgrade failed: %w", err)) return } @@ -198,7 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( var mapRequest tailcfg.MapRequest if err := json.Unmarshal(body, &mapRequest); err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -211,7 +213,11 @@ func (ns *noiseServer) NoisePollNetMapHandler( node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + if errors.Is(err, gorm.ErrRecordNotFound) { + httpError(writer, NewHTTPError(http.StatusNotFound, "node not found", nil)) + return + } + httpError(writer, err) return } @@ -230,7 +236,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( req *http.Request, ) { if req.Method != http.MethodPost { - httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed) + httpError(writer, errMethodNotAllowed) return } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 42032f79..29c1141e 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler( // the template and log an error. registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - httpError(writer, err, "invalid registration ID", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) return } // Set the state and nonce cookies to protect against CSRF attacks state, err := setCSRFCookie(writer, req, "state") if err != nil { - httpError(writer, err, "Internal server error", http.StatusInternalServerError) + httpError(writer, err) return } // Set the state and nonce cookies to protect against CSRF attacks nonce, err := setCSRFCookie(writer, req, "nonce") if err != nil { - httpError(writer, err, "Internal server error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -219,34 +219,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( ) { code, state, err := extractCodeAndStateParamFromRequest(req) if err != nil { - httpError(writer, err, err.Error(), http.StatusBadRequest) + httpError(writer, err) return } cookieState, err := req.Cookie("state") if err != nil { - httpError(writer, err, "state not found", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) return } if state != cookieState.Value { - httpError(writer, err, "state did not match", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil)) return } idToken, err := a.extractIDToken(req.Context(), code, state) if err != nil { - httpError(writer, err, err.Error(), http.StatusBadRequest) + httpError(writer, err) return } nonce, err := req.Cookie("nonce") if err != nil { - httpError(writer, err, "nonce not found", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } if idToken.Nonce != nonce.Value { - httpError(writer, err, "nonce did not match", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return } @@ -254,29 +254,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( var claims types.OIDCClaims if err := idToken.Claims(&claims); err != nil { - err = fmt.Errorf("decoding ID token claims: %w", err) - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return } if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } user, err := a.createOrUpdateUserFromClaim(&claims) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" - newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry) + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // TODO(kradalby): replace with go-elem content, err := renderOIDCCallbackTemplate(user, verb) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Neither node nor machine key was found in the state cache meaning // that we could not reauth nor register the node. - httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError) + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) return } @@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest( state := req.URL.Query().Get("state") if code == "" || state == "" { - return "", "", errEmptyOIDCCallbackParams + return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams) } return code, state, nil @@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken( if a.cfg.PKCE.Enabled { regInfo, ok := a.registrationCache.Get(state) if !ok { - return nil, errNoOIDCRegistrationInfo + return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} @@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken( oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) if err != nil { - return nil, fmt.Errorf("could not exchange code for token: %w", err) + return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) } rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { - return nil, errNoOIDCIDToken + return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken) } verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { - return nil, fmt.Errorf("failed to verify ID token: %w", err) + return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err)) } return idToken, nil @@ -381,7 +380,7 @@ func validateOIDCAllowedDomains( if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !slices.Contains(allowedDomains, claims.Email[at+1:]) { - return errOIDCAllowedDomains + return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains) } } @@ -403,7 +402,7 @@ func validateOIDCAllowedGroups( } } - return errOIDCAllowedGroups + return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups) } return nil @@ -417,7 +416,7 @@ func validateOIDCAllowedUsers( ) error { if len(allowedUsers) > 0 && !slices.Contains(allowedUsers, claims.Email) { - return errOIDCAllowedUsers + return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers) } return nil @@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return user, nil } -func (a *AuthProviderOIDC) handleRegistrationID( +func (a *AuthProviderOIDC) handleRegistration( user *types.User, registrationID types.RegistrationID, expiry time.Time, diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 1855ee24..23c4d25d 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -39,19 +39,19 @@ func (h *Headscale) ApplePlatformConfig( vars := mux.Vars(req) platform, ok := vars["platform"] if !ok { - httpError(writer, nil, "No platform specified", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil)) return } id, err := uuid.NewV4() if err != nil { - httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError) + httpError(writer, err) return } contentID, err := uuid.NewV4() if err != nil { - httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError) + httpError(writer, err) return } @@ -65,21 +65,21 @@ func (h *Headscale) ApplePlatformConfig( switch platform { case "macos-standalone": if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError) + httpError(writer, err) return } case "macos-app-store": if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError) + httpError(writer, err) return } case "ios": if err := iosTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError) + httpError(writer, err) return } default: - httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError) + httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil)) return } @@ -91,7 +91,7 @@ func (h *Headscale) ApplePlatformConfig( var content bytes.Buffer if err := commonTemplate.Execute(&content, config); err != nil { - httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError) + httpError(writer, err) return }