diff --git a/oidc.go b/oidc.go index a47863ff..495832a5 100644 --- a/oidc.go +++ b/oidc.go @@ -9,7 +9,6 @@ import ( "fmt" "html/template" "net/http" - "regexp" "strings" "time" @@ -282,109 +281,90 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { now := time.Now().UTC() - if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { - // register the machine if it's new - if !machine.Registered { - log.Debug().Msg("Registering new machine after successful callback") - - namespace, err := h.GetNamespace(namespaceName) - if errors.Is(err, gorm.ErrRecordNotFound) { - namespace, err = h.CreateNamespace(namespaceName) - - if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create new namespace '%s'", namespaceName) - ctx.String( - http.StatusInternalServerError, - "could not create new namespace", - ) - - return - } - } else if err != nil { - log.Error(). - Caller(). - Err(err). - Str("namespace", namespaceName). - Msg("could not find or create namespace") - ctx.String( - http.StatusInternalServerError, - "could not find or create namespace", - ) - - return - } - - ips, err := h.getAvailableIPs() - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not get an IP from the pool") - ctx.String( - http.StatusInternalServerError, - "could not get an IP from the pool", - ) - - return - } - - machine.IPAddresses = ips - machine.NamespaceID = namespace.ID - machine.Registered = true - machine.RegisterMethod = RegisterMethodOIDC - machine.LastSuccessfulUpdate = &now - machine.Expiry = &requestedTime - h.db.Save(&machine) - } - - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: claims.Email, - Verb: "Authenticated", - }); err != nil { - log.Error(). - Str("func", "OIDCCallback"). - Str("type", "authenticate"). - Err(err). - Msg("Could not render OIDC callback template") - ctx.Data( - http.StatusInternalServerError, - "text/html; charset=utf-8", - []byte("Could not render OIDC callback template"), - ) - } - - ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) - + namespaceName, err := NormalizeNamespaceName(claims.Email) + if err != nil { + log.Error().Err(err).Caller().Msgf("couldn't normalize email") + ctx.String( + http.StatusInternalServerError, + "couldn't normalize email", + ) return } + // register the machine if it's new + if !machine.Registered { + log.Debug().Msg("Registering new machine after successful callback") - log.Error(). - Caller(). - Str("email", claims.Email). - Str("username", claims.Username). - Str("machine", machine.Name). - Msg("Email could not be mapped to a namespace") - ctx.String( - http.StatusBadRequest, - "email from claim could not be mapped to a namespace", - ) -} + namespace, err := h.GetNamespace(namespaceName) + if errors.Is(err, gorm.ErrRecordNotFound) { + namespace, err = h.CreateNamespace(namespaceName) -// getNamespaceFromEmail passes the users email through a list of "matchers" -// and iterates through them until it matches and returns a namespace. -// If no match is found, an empty string will be returned. -// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration. -func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { - for match, namespace := range h.cfg.OIDC.MatchMap { - regex := regexp.MustCompile(match) - if regex.MatchString(email) { - return namespace, true + if err != nil { + log.Error(). + Err(err). + Caller(). + Msgf("could not create new namespace '%s'", namespaceName) + ctx.String( + http.StatusInternalServerError, + "could not create new namespace", + ) + + return + } + } else if err != nil { + log.Error(). + Caller(). + Err(err). + Str("namespace", namespaceName). + Msg("could not find or create namespace") + ctx.String( + http.StatusInternalServerError, + "could not find or create namespace", + ) + + return } + + ips, err := h.getAvailableIPs() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("could not get an IP from the pool") + ctx.String( + http.StatusInternalServerError, + "could not get an IP from the pool", + ) + + return + } + + machine.IPAddresses = ips + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = RegisterMethodOIDC + machine.LastSuccessfulUpdate = &now + machine.Expiry = &requestedTime + h.db.Save(&machine) } - return "", false + var content bytes.Buffer + if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ + User: claims.Email, + Verb: "Authenticated", + }); err != nil { + log.Error(). + Str("func", "OIDCCallback"). + Str("type", "authenticate"). + Err(err). + Msg("Could not render OIDC callback template") + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render OIDC callback template"), + ) + } + + ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) + + return }