Remove Gin from OIDC callback

This commit is contained in:
Juan Font Alonso 2022-06-17 17:42:17 +02:00
parent 367da0fcc2
commit d5e331a2fb
2 changed files with 67 additions and 51 deletions

2
app.go
View file

@ -402,7 +402,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", gin.WrapF(h.OIDCCallback))
router.GET("/apple", gin.WrapF(h.AppleConfigMessage)) router.GET("/apple", gin.WrapF(h.AppleConfigMessage))
router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig)) router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig))
router.GET("/windows", gin.WrapF(h.WindowsConfigMessage)) router.GET("/windows", gin.WrapF(h.WindowsConfigMessage))

116
oidc.go
View file

@ -125,12 +125,17 @@ var oidcCallbackTemplate = template.Must(
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(ctx *gin.Context) { func (h *Headscale) OIDCCallback(
code := ctx.Query("code") w http.ResponseWriter,
state := ctx.Query("state") r *http.Request,
) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" || state == "" { if code == "" || state == "" {
ctx.String(http.StatusBadRequest, "Wrong params") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Wrong params"))
return return
} }
@ -141,7 +146,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("Could not exchange code for token") Msg("Could not exchange code for token")
ctx.String(http.StatusBadRequest, "Could not exchange code for token") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Could not exchange code for token"))
return return
} }
@ -154,7 +161,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
ctx.String(http.StatusBadRequest, "Could not extract ID Token") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Could not extract ID Token"))
return return
} }
@ -167,7 +176,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("failed to verify id token") Msg("failed to verify id token")
ctx.String(http.StatusBadRequest, "Failed to verify id token") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Failed to verify id token"))
return return
} }
@ -186,10 +197,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msg("Failed to decode id token claims") Msg("Failed to decode id token claims")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusBadRequest, w.WriteHeader(http.StatusBadRequest)
"Failed to decode id token claims", w.Write([]byte("Failed to decode id token claims"))
)
return return
} }
@ -199,10 +209,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if at := strings.LastIndex(claims.Email, "@"); at < 0 || if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) { !IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain") log.Error().Msg("authenticated principal does not match any allowed domain")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusBadRequest, w.WriteHeader(http.StatusBadRequest)
"unauthorized principal (domain mismatch)", w.Write([]byte("unauthorized principal (domain mismatch)"))
)
return return
} }
@ -212,7 +221,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if len(h.cfg.OIDC.AllowedUsers) > 0 && if len(h.cfg.OIDC.AllowedUsers) > 0 &&
!IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) { !IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user") log.Error().Msg("authenticated principal does not match any allowed user")
ctx.String(http.StatusBadRequest, "unauthorized principal (user mismatch)") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("unauthorized principal (user mismatch)"))
return return
} }
@ -223,7 +234,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("state has expired"))
return return
} }
@ -237,17 +250,18 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse machine public key")
ctx.String(http.StatusBadRequest, "could not parse public key") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("could not parse public key"))
return return
} }
if !machineKeyOK { if !machineKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get machine key from cache")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, w.WriteHeader(http.StatusInternalServerError)
"could not get machine key from cache", w.Write([]byte("could not get machine key from cache"))
)
return return
} }
@ -276,14 +290,17 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("type", "reauthenticate"). Str("type", "reauthenticate").
Err(err). Err(err).
Msg("Could not render OIDC callback template") Msg("Could not render OIDC callback template")
ctx.Data(
http.StatusInternalServerError, w.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", w.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render OIDC callback template"), w.Write([]byte("Could not render OIDC callback template"))
)
return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(content.Bytes())
return return
} }
@ -294,10 +311,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
) )
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email") log.Error().Err(err).Caller().Msgf("couldn't normalize email")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, w.WriteHeader(http.StatusInternalServerError)
"couldn't normalize email", w.Write([]byte("couldn't normalize email"))
)
return return
} }
@ -314,10 +330,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Caller(). Caller().
Msgf("could not create new namespace '%s'", namespaceName) Msgf("could not create new namespace '%s'", namespaceName)
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, w.WriteHeader(http.StatusInternalServerError)
"could not create new namespace", w.Write([]byte("could not create namespace"))
)
return return
} }
@ -327,10 +342,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Err(err). Err(err).
Str("namespace", namespaceName). Str("namespace", namespaceName).
Msg("could not find or create namespace") Msg("could not find or create namespace")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, w.WriteHeader(http.StatusInternalServerError)
"could not find or create namespace", w.Write([]byte("could not find or create namespace"))
)
return return
} }
@ -347,10 +361,9 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Caller(). Caller().
Err(err). Err(err).
Msg("could not register machine") Msg("could not register machine")
ctx.String( w.Header().Set("Content-Type", "text/plain; charset=utf-8")
http.StatusInternalServerError, w.WriteHeader(http.StatusInternalServerError)
"could not register machine", w.Write([]byte("could not register machine"))
)
return return
} }
@ -365,12 +378,15 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("type", "authenticate"). Str("type", "authenticate").
Err(err). Err(err).
Msg("Could not render OIDC callback template") Msg("Could not render OIDC callback template")
ctx.Data(
http.StatusInternalServerError, w.Header().Set("Content-Type", "text/plain; charset=utf-8")
"text/html; charset=utf-8", w.WriteHeader(http.StatusInternalServerError)
[]byte("Could not render OIDC callback template"), w.Write([]byte("Could not render OIDC callback template"))
)
return
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(content.Bytes())
} }