diff --git a/CHANGELOG.md b/CHANGELOG.md index 01d630a7..30043c8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542) - Fixed spurious calls to setLastStateChangeToNow from ephemeral nodes [#566](https://github.com/juanfont/headscale/pull/566) - Add command for moving nodes between namespaces [#362](https://github.com/juanfont/headscale/issues/362) +- Added more configuration parameters for OpenID Connect (scopes, free-form paramters, domain and user allowlist) ## 0.15.0 (2022-03-20) diff --git a/app.go b/app.go index f2f481f8..a96ab83a 100644 --- a/app.go +++ b/app.go @@ -119,6 +119,10 @@ type OIDCConfig struct { Issuer string ClientID string ClientSecret string + Scope []string + ExtraParams map[string]string + AllowedDomains []string + AllowedUsers []string StripEmaildomain bool } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 992d1254..3e7aa2f8 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" @@ -67,6 +68,7 @@ func LoadConfig(path string) error { viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.insecure", false) + viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.strip_email_domain", true) if err := viper.ReadInConfig(); err != nil { @@ -367,6 +369,10 @@ func getHeadscaleConfig() headscale.Config { Issuer: viper.GetString("oidc.issuer"), ClientID: viper.GetString("oidc.client_id"), ClientSecret: viper.GetString("oidc.client_secret"), + Scope: viper.GetStringSlice("oidc.scope"), + ExtraParams: viper.GetStringMapString("oidc.extra_params"), + AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), + AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), }, diff --git a/config-example.yaml b/config-example.yaml index dee25cb3..ebaa7101 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -214,6 +214,21 @@ unix_socket_permission: "0770" # client_id: "your-oidc-client-id" # client_secret: "your-oidc-client-secret" # +# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query +# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". +# +# scope: ["openid", "profile", "email", "custom"] +# extra_params: +# domain_hint: example.com +# +# List allowed principal domains and/or users. If an authenticated user's domain is not in this list, the +# authentication request will be rejected. +# +# allowed_domains: +# - example.com +# allowed_users: +# - alice@example.com +# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # This will transform `first-name.last-name@example.com` to the namespace `first-name.last-name` # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following diff --git a/oidc.go b/oidc.go index 598a208d..73a59778 100644 --- a/oidc.go +++ b/oidc.go @@ -53,7 +53,7 @@ func (h *Headscale) initOIDC() error { "%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/"), ), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + Scopes: h.cfg.OIDC.Scope, } } @@ -91,7 +91,14 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { // place the machine key into the state cache, so it can be retrieved later h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) - authURL := h.oauth2Config.AuthCodeURL(stateStr) + // Add any extra parameter provided in the configuration to the Authorize Endpoint request + extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) + + for k, v := range h.cfg.OIDC.ExtraParams { + extras = append(extras, oauth2.SetAuthURLParam(k, v)) + } + + authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...) log.Debug().Msgf("Redirecting to %s for authentication", authURL) ctx.Redirect(http.StatusFound, authURL) @@ -187,6 +194,29 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } + // If AllowedDomains is provided, check that the authenticated principal ends with @. + if len(h.cfg.OIDC.AllowedDomains) > 0 { + if at := strings.LastIndex(claims.Email, "@"); at < 0 || + !IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) { + log.Error().Msg("authenticated principal does not match any allowed domain") + ctx.String( + http.StatusBadRequest, + "unauthorized principal (domain mismatch)", + ) + + return + } + } + + // If AllowedUsers is provided, check that the authenticated princial is part of that list. + if len(h.cfg.OIDC.AllowedUsers) > 0 && + !IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) { + log.Error().Msg("authenticated principal does not match any allowed user") + ctx.String(http.StatusBadRequest, "unauthorized principal (user mismatch)") + + return + } + // retrieve machinekey from state cache machineKeyIf, machineKeyFound := h.registrationCache.Get(state) diff --git a/utils.go b/utils.go index af267eb3..3dace875 100644 --- a/utils.go +++ b/utils.go @@ -317,3 +317,13 @@ func GenerateRandomStringURLSafe(n int) (string, error) { return base64.RawURLEncoding.EncodeToString(b), err } + +func IsStringInSlice(slice []string, str string) bool { + for _, s := range slice { + if s == str { + return true + } + } + + return false +}