diff --git a/app.go b/app.go index 3d24d81c..cfed67aa 100644 --- a/app.go +++ b/app.go @@ -402,7 +402,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id", h.RegistrationHandler) router.GET("/oidc/register/:mkey", h.RegisterOIDC) - router.GET("/oidc/callback", h.OIDCCallback) + router.GET("/oidc/callback", gin.WrapF(h.OIDCCallback)) router.GET("/apple", gin.WrapF(h.AppleConfigMessage)) router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig)) router.GET("/windows", gin.WrapF(h.WindowsConfigMessage)) diff --git a/oidc.go b/oidc.go index 38a1eb36..477fe78c 100644 --- a/oidc.go +++ b/oidc.go @@ -125,12 +125,17 @@ var oidcCallbackTemplate = template.Must( // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback. -func (h *Headscale) OIDCCallback(ctx *gin.Context) { - code := ctx.Query("code") - state := ctx.Query("state") +func (h *Headscale) OIDCCallback( + w http.ResponseWriter, + r *http.Request, +) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") if code == "" || state == "" { - ctx.String(http.StatusBadRequest, "Wrong params") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Wrong params")) return } @@ -141,7 +146,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("Could not exchange code for token") - ctx.String(http.StatusBadRequest, "Could not exchange code for token") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Could not exchange code for token")) return } @@ -154,7 +161,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { - ctx.String(http.StatusBadRequest, "Could not extract ID Token") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Could not extract ID Token")) return } @@ -167,7 +176,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("failed to verify id token") - ctx.String(http.StatusBadRequest, "Failed to verify id token") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Failed to verify id token")) return } @@ -186,10 +197,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msg("Failed to decode id token claims") - ctx.String( - http.StatusBadRequest, - "Failed to decode id token claims", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("Failed to decode id token claims")) return } @@ -199,10 +209,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if at := strings.LastIndex(claims.Email, "@"); at < 0 || !IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) { log.Error().Msg("authenticated principal does not match any allowed domain") - ctx.String( - http.StatusBadRequest, - "unauthorized principal (domain mismatch)", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("unauthorized principal (domain mismatch)")) return } @@ -212,7 +221,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if len(h.cfg.OIDC.AllowedUsers) > 0 && !IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) { log.Error().Msg("authenticated principal does not match any allowed user") - ctx.String(http.StatusBadRequest, "unauthorized principal (user mismatch)") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("unauthorized principal (user mismatch)")) return } @@ -223,7 +234,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if !machineKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") - ctx.String(http.StatusBadRequest, "state has expired") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("state has expired")) return } @@ -237,17 +250,18 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { if err != nil { log.Error(). Msg("could not parse machine public key") - ctx.String(http.StatusBadRequest, "could not parse public key") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("could not parse public key")) return } if !machineKeyOK { log.Error().Msg("could not get machine key from cache") - ctx.String( - http.StatusInternalServerError, - "could not get machine key from cache", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not get machine key from cache")) return } @@ -276,14 +290,17 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("type", "reauthenticate"). Err(err). Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Could not render OIDC callback template")) + + return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(content.Bytes()) return } @@ -294,10 +311,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { ) if err != nil { log.Error().Err(err).Caller().Msgf("couldn't normalize email") - ctx.String( - http.StatusInternalServerError, - "couldn't normalize email", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("couldn't normalize email")) return } @@ -314,10 +330,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Caller(). Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not create namespace")) return } @@ -327,10 +342,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Err(err). Str("namespace", namespaceName). Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not find or create namespace")) return } @@ -347,10 +361,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Caller(). Err(err). Msg("could not register machine") - ctx.String( - http.StatusInternalServerError, - "could not register machine", - ) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not register machine")) return } @@ -365,12 +378,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { Str("type", "authenticate"). Err(err). Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Could not render OIDC callback template")) + + return } - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(content.Bytes()) }