diff --git a/api.go b/api.go index cbe48072..c542b3aa 100644 --- a/api.go +++ b/api.go @@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("Machine registration has expired. Sending a authurl to register") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { @@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", @@ -424,7 +424,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, db.Save(&m) h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? - + pak.Used = true db.Save(&pak) diff --git a/app.go b/app.go index 89c43589..c89856f8 100644 --- a/app.go +++ b/app.go @@ -3,9 +3,6 @@ package headscale import ( "errors" "fmt" - "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" "net/http" "os" "sort" @@ -13,6 +10,10 @@ import ( "sync" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "github.com/rs/zerolog/log" "github.com/gin-gonic/gin" @@ -57,14 +58,19 @@ type Config struct { DNSConfig *tailcfg.DNSConfig - OIDCIssuer string - OIDCClientID string - OIDCClientSecret string + OIDC OIDCConfig MaxMachineRegistrationDuration time.Duration DefaultMachineRegistrationDuration time.Duration } +type OIDCConfig struct { + Issuer string + ClientID string + ClientSecret string + MatchMap map[string]string +} + // Headscale represents the base app of the service type Headscale struct { cfg Config @@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } - if cfg.OIDCIssuer != "" { + if cfg.OIDC.Issuer != "" { err = h.initOIDC() if err != nil { return nil, err } - } + } - if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) if err != nil { return nil, err @@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { times = append(times, lastChange) } - } sort.Slice(times, func(i, j int) bool { @@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { if len(times) == 0 { return time.Now().UTC() - } else { return times[0] } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index f29c389d..4a598e77 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "regexp" "strings" "time" @@ -73,7 +74,6 @@ func LoadConfig(path string) error { } else { return nil } - } func GetDNSConfig() (*tailcfg.DNSConfig, string) { @@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) { ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), - OIDCIssuer: viper.GetString("oidc_issuer"), - OIDCClientID: viper.GetString("oidc_client_id"), - OIDCClientSecret: viper.GetString("oidc_client_secret"), + OIDC: headscale.OIDCConfig{ + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + }, MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration } + cfg.OIDC.MatchMap = loadOIDCMatchMap() + h, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool { } return false } + +// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in +// the match map is valid regex strings. +func loadOIDCMatchMap() map[string]string { + strMap := viper.GetStringMapString("oidc.domain_map") + + for oidcMatcher := range strMap { + _ = regexp.MustCompile(oidcMatcher) + } + + return strMap +} diff --git a/oidc.go b/oidc.go index 01c54b44..1b13963c 100644 --- a/oidc.go +++ b/oidc.go @@ -5,14 +5,16 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/http" + "regexp" + "strings" + "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "net/http" - "strings" - "time" ) type IDTokenClaims struct { @@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error { var err error // grab oidc config if it hasn't been already if h.oauth2Config == nil { - h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) if err != nil { log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) @@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error { } h.oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDCClientID, - ClientSecret: h.cfg.OIDCClientSecret, + ClientID: h.cfg.OIDC.ClientID, + ClientSecret: h.cfg.OIDC.ClientSecret, Endpoint: h.oidcProvider.Endpoint(), RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, @@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { b := make([]byte, 16) _, err := rand.Read(b) - if err != nil { log.Error().Msg("could not read 16 bytes from rand") c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") @@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback func (h *Headscale) OIDCCallback(c *gin.Context) { - code := c.Query("code") state := c.Query("state") @@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - //retrieve machinekey from state cache + // retrieve machinekey from state cache mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { @@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { // retrieve machine information m, err := h.GetMachineByMachineKey(mKeyStr) - if err != nil { log.Error().Msg("machine key not found in database") c.String(http.StatusInternalServerError, "could not get machine info from database") @@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { now := time.Now().UTC() - // register the machine if it's new - if !m.Registered { - nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation + if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + // register the machine if it's new + if !m.Registered { - log.Debug().Msg("Registering new machine after successful callback") - - ns, err := h.GetNamespace(nsName) - if err != nil { - ns, err = h.CreateNamespace(nsName) + log.Debug().Msg("Registering new machine after successful callback") + ns, err := h.GetNamespace(nsName) if err != nil { - log.Error().Msgf("could not create new namespace '%s'", claims.Email) - c.String(http.StatusInternalServerError, "could not create new namespace") + ns, err = h.CreateNamespace(nsName) + + if err != nil { + log.Error().Msgf("could not create new namespace '%s'", claims.Email) + c.String(http.StatusInternalServerError, "could not create new namespace") + return + } + } + + ip, err := h.getAvailableIP() + if err != nil { + c.String(http.StatusInternalServerError, "could not get an IP from the pool") return } + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now + h.db.Save(&m) } - ip, err := h.getAvailableIP() - if err != nil { - c.String(http.StatusInternalServerError, "could not get an IP from the pool") - return - } + h.updateMachineExpiry(m) - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "oidc" - m.LastSuccessfulUpdate = &now - h.db.Save(&m) - } - - h.updateMachineExpiry(m) - - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`