diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d5c2245..eee03861 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ [#2396](https://github.com/juanfont/headscale/pull/2396) - Pre auth keys that are used by a node can no longer be deleted [#2396](https://github.com/juanfont/headscale/pull/2396) +- Rehaul HTTP errors, return better status code and errors to users + [#2398](https://github.com/juanfont/headscale/pull/2398) ## 0.24.2 (2025-01-30) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 3fa5fa4b..7695f1ae 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode( machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { if node.MachineKey != machineKey { - return nil, errors.New("node already exists with different machine key") + return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } expired := node.IsExpired() @@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode( // The client is trying to extend their key, this is not allowed. if requestExpiry.After(time.Now()) { - return nil, errors.New("extending key is not allowed") + return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil) } // If the request expiry is in the past, we consider it a logout. @@ -155,13 +155,42 @@ func (h *Headscale) waitForFollowup( } } +// canUsePreAuthKey checks if a pre auth key can be used. +func canUsePreAuthKey(pak *types.PreAuthKey) error { + if pak == nil { + return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil) + } + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { + return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil) + } + + // we don't need to check if has been used before + if pak.Reusable { + return nil + } + + if pak.Used { + return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil) + } + + return nil +} + func (h *Headscale) handleRegisterWithAuthKey( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) (*tailcfg.RegisterResponse, error) { - pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey) + pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey) if err != nil { - return nil, fmt.Errorf("invalid pre auth key: %w", err) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) + } + return nil, err + } + + err = canUsePreAuthKey(pak) + if err != nil { + return nil, err } nodeToRegister := types.Node{ diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go new file mode 100644 index 00000000..7c0c0d42 --- /dev/null +++ b/hscontrol/auth_test.go @@ -0,0 +1,130 @@ +package hscontrol + +import ( + "net/http" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" +) + +func TestCanUsePreAuthKey(t *testing.T) { + now := time.Now() + past := now.Add(-time.Hour) + future := now.Add(time.Hour) + + tests := []struct { + name string + pak *types.PreAuthKey + wantErr bool + err HTTPError + }{ + { + name: "valid reusable key", + pak: &types.PreAuthKey{ + Reusable: true, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "valid non-reusable key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "expired key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &past, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil), + }, + { + name: "used non-reusable key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &future, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil), + }, + { + name: "used reusable key", + pak: &types.PreAuthKey{ + Reusable: true, + Used: true, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "no expiration date", + pak: &types.PreAuthKey{ + Reusable: false, + Used: false, + Expiration: nil, + }, + wantErr: false, + }, + { + name: "nil preauth key", + pak: nil, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil), + }, + { + name: "expired and used key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &past, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil), + }, + { + name: "no expiration and used key", + pak: &types.PreAuthKey{ + Reusable: false, + Used: true, + Expiration: nil, + }, + wantErr: true, + err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := canUsePreAuthKey(tt.pak) + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } else { + httpErr, ok := err.(HTTPError) + if !ok { + t.Errorf("expected HTTPError but got %T", err) + } else { + if diff := cmp.Diff(tt.err, httpErr); diff != "" { + t.Errorf("unexpected error (-want +got):\n%s", diff) + } + } + } + } else { + if err != nil { + t.Errorf("expected no error but got %v", err) + } + } + }) + } +} diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index aeee5b52..ee977ae3 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -10,7 +10,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "gorm.io/gorm" - "tailscale.com/types/ptr" "tailscale.com/util/set" ) @@ -64,6 +63,7 @@ func CreatePreAuthKey( } now := time.Now().UTC() + // TODO(kradalby): unify the key generations spread all over the code. kstr, err := generateKey() if err != nil { return nil, err @@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e return keys, nil } -// GetPreAuthKey returns a PreAuthKey for a given key. -func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) { - pak, err := ValidatePreAuthKey(tx, key) - if err != nil { - return nil, err +func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { + return GetPreAuthKey(rx, key) + }) +} + +// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible +// for checking if the key is usable (expired or used). +func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { + pak := types.PreAuthKey{} + if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil { + return nil, ErrPreAuthKeyNotFound } - if pak.User.Name != user { - return nil, ErrUserMismatch - } - - return pak, nil + return &pak, nil } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey @@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. -func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { - if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { - return err - } - - return nil -} - // UsePreAuthKey marks a PreAuthKey as used. func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { k.Used = true @@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { return nil } -func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) { - return ValidatePreAuthKey(rx, k) - }) -} - -// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node -// If returns no error and a PreAuthKey, it can be used. -func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) { - pak := types.PreAuthKey{} - if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrPreAuthKeyNotFound +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil { + return err } - if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return nil, ErrPreAuthKeyExpired - } - - if pak.Reusable { // we don't need to check if has been used before - return &pak, nil - } - - nodes := types.Nodes{} - if err := tx. - Preload("AuthKey"). - Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}). - Find(&nodes).Error; err != nil { - return nil, err - } - - if len(nodes) != 0 || pak.Used { - return nil, ErrSingleUseAuthKeyHasBeenUsed - } - - return &pak, nil + return nil } func generateKey() (string, error) { diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index ec7f75a8..5ace968a 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -3,14 +3,14 @@ package db import ( "sort" "testing" - "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "gopkg.in/check.v1" "tailscale.com/types/ptr" + + "gopkg.in/check.v1" ) func (*Suite) TestCreatePreAuthKey(c *check.C) { @@ -43,123 +43,6 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { c.Assert((keys)[0].User.ID, check.Equals, user.ID) } -func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test2"}) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := db.ValidatePreAuthKey("potatoKey") - c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test3"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test4"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test5"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 1, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), - } - trx := db.DB.Save(&node) - c.Assert(trx.Error, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test6"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test3"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.IsNil) - - err = db.ExpirePreAuthKey(pak) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.NotNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := db.CreateUser(types.User{Name: "test6"}) - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) - c.Assert(err, check.IsNil) - pak.Used = true - db.DB.Save(&pak) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) -} - func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser(types.User{Name: "test8"}) c.Assert(err, check.IsNil) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 51fb9869..7eadd0a7 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -181,11 +181,15 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { err := api.h.db.Write(func(tx *gorm.DB) error { - preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) + preAuthKey, err := db.GetPreAuthKey(tx, request.Key) if err != nil { return err } + if preAuthKey.User.Name != request.GetUser() { + return fmt.Errorf("preauth key does not belong to user") + } + return db.ExpirePreAuthKey(tx, preAuthKey) }) if err != nil { diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index c310aedf..e55fce49 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -33,11 +33,34 @@ const ( ) // httpError logs an error and sends an HTTP error response with the given -func httpError(w http.ResponseWriter, err error, userError string, code int) { - log.Error().Err(err).Msg(userError) - http.Error(w, userError, code) +func httpError(w http.ResponseWriter, err error) { + var herr HTTPError + if errors.As(err, &herr) { + http.Error(w, herr.Msg, herr.Code) + log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg) + } else { + http.Error(w, "internal server error", http.StatusInternalServerError) + log.Error().Err(err).Int("code", http.StatusInternalServerError).Msg("http internal server error") + } } +// HTTPError represents an error that is surfaced to the user via web. +type HTTPError struct { + Code int // HTTP response code to send to client; 0 means 500 + Msg string // Response body to send to client + Err error // Detailed error to log on the server +} + +func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } +func (e HTTPError) Unwrap() error { return e.Err } + +// Error returns an HTTPError containing the given information. +func NewHTTPError(code int, msg string, err error) HTTPError { + return HTTPError{Code: code, Msg: msg, Err: err} +} + +var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil) + var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( "machines registered with CLI does not support expire", ) @@ -47,12 +70,12 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) clientCapabilityStr := req.URL.Query().Get("v") if clientCapabilityStr == "" { - return 0, ErrNoCapabilityVersion + return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil) } clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr) if err != nil { - return 0, fmt.Errorf("failed to parse capability version: %w", err) + return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err)) } return tailcfg.CapabilityVersion(clientCapabilityVersion), nil @@ -85,13 +108,13 @@ func (h *Headscale) VerifyHandler( req *http.Request, ) { if req.Method != http.MethodPost { - httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed) + httpError(writer, errMethodNotAllowed) return } allow, err := h.derpRequestIsAllowed(req) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -112,7 +135,7 @@ func (h *Headscale) KeyHandler( // New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion capVer, err := parseCabailityVersion(req) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -199,7 +222,7 @@ func (a *AuthProviderWeb) RegisterHandler( // the template and log an error. registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - httpError(writer, err, "invalid registration ID", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) return } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 318cf5e4..a8aa6e13 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -3,6 +3,7 @@ package hscontrol import ( "encoding/binary" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -12,6 +13,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" + "gorm.io/gorm" "tailscale.com/control/controlbase" "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" @@ -81,7 +83,7 @@ func (h *Headscale) NoiseUpgradeHandler( noiseServer.earlyNoise, ) if err != nil { - httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError) + httpError(writer, fmt.Errorf("noise upgrade failed: %w", err)) return } @@ -198,7 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( var mapRequest tailcfg.MapRequest if err := json.Unmarshal(body, &mapRequest); err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -211,7 +213,11 @@ func (ns *noiseServer) NoisePollNetMapHandler( node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey) if err != nil { - httpError(writer, err, "Internal error", http.StatusInternalServerError) + if errors.Is(err, gorm.ErrRecordNotFound) { + httpError(writer, NewHTTPError(http.StatusNotFound, "node not found", nil)) + return + } + httpError(writer, err) return } @@ -230,7 +236,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( req *http.Request, ) { if req.Method != http.MethodPost { - httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed) + httpError(writer, errMethodNotAllowed) return } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 42032f79..29c1141e 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler( // the template and log an error. registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - httpError(writer, err, "invalid registration ID", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) return } // Set the state and nonce cookies to protect against CSRF attacks state, err := setCSRFCookie(writer, req, "state") if err != nil { - httpError(writer, err, "Internal server error", http.StatusInternalServerError) + httpError(writer, err) return } // Set the state and nonce cookies to protect against CSRF attacks nonce, err := setCSRFCookie(writer, req, "nonce") if err != nil { - httpError(writer, err, "Internal server error", http.StatusInternalServerError) + httpError(writer, err) return } @@ -219,34 +219,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( ) { code, state, err := extractCodeAndStateParamFromRequest(req) if err != nil { - httpError(writer, err, err.Error(), http.StatusBadRequest) + httpError(writer, err) return } cookieState, err := req.Cookie("state") if err != nil { - httpError(writer, err, "state not found", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) return } if state != cookieState.Value { - httpError(writer, err, "state did not match", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil)) return } idToken, err := a.extractIDToken(req.Context(), code, state) if err != nil { - httpError(writer, err, err.Error(), http.StatusBadRequest) + httpError(writer, err) return } nonce, err := req.Cookie("nonce") if err != nil { - httpError(writer, err, "nonce not found", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } if idToken.Nonce != nonce.Value { - httpError(writer, err, "nonce did not match", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return } @@ -254,29 +254,28 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( var claims types.OIDCClaims if err := idToken.Claims(&claims); err != nil { - err = fmt.Errorf("decoding ID token claims: %w", err) - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return } if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil { - httpError(writer, err, err.Error(), http.StatusUnauthorized) + httpError(writer, err) return } user, err := a.createOrUpdateUserFromClaim(&claims) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" - newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry) + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // TODO(kradalby): replace with go-elem content, err := renderOIDCCallbackTemplate(user, verb) if err != nil { - httpError(writer, err, err.Error(), http.StatusInternalServerError) + httpError(writer, err) return } @@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Neither node nor machine key was found in the state cache meaning // that we could not reauth nor register the node. - httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError) + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) return } @@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest( state := req.URL.Query().Get("state") if code == "" || state == "" { - return "", "", errEmptyOIDCCallbackParams + return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams) } return code, state, nil @@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken( if a.cfg.PKCE.Enabled { regInfo, ok := a.registrationCache.Get(state) if !ok { - return nil, errNoOIDCRegistrationInfo + return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} @@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken( oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) if err != nil { - return nil, fmt.Errorf("could not exchange code for token: %w", err) + return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) } rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { - return nil, errNoOIDCIDToken + return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken) } verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { - return nil, fmt.Errorf("failed to verify ID token: %w", err) + return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err)) } return idToken, nil @@ -381,7 +380,7 @@ func validateOIDCAllowedDomains( if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !slices.Contains(allowedDomains, claims.Email[at+1:]) { - return errOIDCAllowedDomains + return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains) } } @@ -403,7 +402,7 @@ func validateOIDCAllowedGroups( } } - return errOIDCAllowedGroups + return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups) } return nil @@ -417,7 +416,7 @@ func validateOIDCAllowedUsers( ) error { if len(allowedUsers) > 0 && !slices.Contains(allowedUsers, claims.Email) { - return errOIDCAllowedUsers + return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers) } return nil @@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( return user, nil } -func (a *AuthProviderOIDC) handleRegistrationID( +func (a *AuthProviderOIDC) handleRegistration( user *types.User, registrationID types.RegistrationID, expiry time.Time, diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 1855ee24..23c4d25d 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -39,19 +39,19 @@ func (h *Headscale) ApplePlatformConfig( vars := mux.Vars(req) platform, ok := vars["platform"] if !ok { - httpError(writer, nil, "No platform specified", http.StatusBadRequest) + httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil)) return } id, err := uuid.NewV4() if err != nil { - httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError) + httpError(writer, err) return } contentID, err := uuid.NewV4() if err != nil { - httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError) + httpError(writer, err) return } @@ -65,21 +65,21 @@ func (h *Headscale) ApplePlatformConfig( switch platform { case "macos-standalone": if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError) + httpError(writer, err) return } case "macos-app-store": if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError) + httpError(writer, err) return } case "ios": if err := iosTemplate.Execute(&payload, platformConfig); err != nil { - httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError) + httpError(writer, err) return } default: - httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError) + httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil)) return } @@ -91,7 +91,7 @@ func (h *Headscale) ApplePlatformConfig( var content bytes.Buffer if err := commonTemplate.Execute(&content, config); err != nil { - httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError) + httpError(writer, err) return }