diff --git a/api.go b/api.go index cbe48072..c542b3aa 100644 --- a/api.go +++ b/api.go @@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("machine", m.Name). Msg("Machine registration has expired. Sending a authurl to register") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { @@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Str("machine", m.Name). Msg("The node is sending us a new NodeKey, sending auth url") - if h.cfg.OIDCIssuer != "" { + if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", @@ -424,7 +424,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, db.Save(&m) h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? - + pak.Used = true db.Save(&pak) diff --git a/app.go b/app.go index 89c43589..c89856f8 100644 --- a/app.go +++ b/app.go @@ -3,9 +3,6 @@ package headscale import ( "errors" "fmt" - "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" "net/http" "os" "sort" @@ -13,6 +10,10 @@ import ( "sync" "time" + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "github.com/rs/zerolog/log" "github.com/gin-gonic/gin" @@ -57,14 +58,19 @@ type Config struct { DNSConfig *tailcfg.DNSConfig - OIDCIssuer string - OIDCClientID string - OIDCClientSecret string + OIDC OIDCConfig MaxMachineRegistrationDuration time.Duration DefaultMachineRegistrationDuration time.Duration } +type OIDCConfig struct { + Issuer string + ClientID string + ClientSecret string + MatchMap map[string]string +} + // Headscale represents the base app of the service type Headscale struct { cfg Config @@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) { return nil, err } - if cfg.OIDCIssuer != "" { + if cfg.OIDC.Issuer != "" { err = h.initOIDC() if err != nil { return nil, err } - } + } - if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) if err != nil { return nil, err @@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { times = append(times, lastChange) } - } sort.Slice(times, func(i, j int) bool { @@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { if len(times) == 0 { return time.Now().UTC() - } else { return times[0] } diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index f29c389d..4a598e77 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "regexp" "strings" "time" @@ -73,7 +74,6 @@ func LoadConfig(path string) error { } else { return nil } - } func GetDNSConfig() (*tailcfg.DNSConfig, string) { @@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) { ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), - OIDCIssuer: viper.GetString("oidc_issuer"), - OIDCClientID: viper.GetString("oidc_client_id"), - OIDCClientSecret: viper.GetString("oidc_client_secret"), + OIDC: headscale.OIDCConfig{ + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: viper.GetString("oidc.client_secret"), + }, MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration } + cfg.OIDC.MatchMap = loadOIDCMatchMap() + h, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err @@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool { } return false } + +// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in +// the match map is valid regex strings. +func loadOIDCMatchMap() map[string]string { + strMap := viper.GetStringMapString("oidc.domain_map") + + for oidcMatcher := range strMap { + _ = regexp.MustCompile(oidcMatcher) + } + + return strMap +} diff --git a/oidc.go b/oidc.go index 01c54b44..1b13963c 100644 --- a/oidc.go +++ b/oidc.go @@ -5,14 +5,16 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/http" + "regexp" + "strings" + "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "net/http" - "strings" - "time" ) type IDTokenClaims struct { @@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error { var err error // grab oidc config if it hasn't been already if h.oauth2Config == nil { - h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) + h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) if err != nil { log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) @@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error { } h.oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDCClientID, - ClientSecret: h.cfg.OIDCClientSecret, + ClientID: h.cfg.OIDC.ClientID, + ClientSecret: h.cfg.OIDC.ClientSecret, Endpoint: h.oidcProvider.Endpoint(), RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, @@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { b := make([]byte, 16) _, err := rand.Read(b) - if err != nil { log.Error().Msg("could not read 16 bytes from rand") c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") @@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) { // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback func (h *Headscale) OIDCCallback(c *gin.Context) { - code := c.Query("code") state := c.Query("state") @@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) + verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { @@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { return } - //retrieve machinekey from state cache + // retrieve machinekey from state cache mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { @@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { // retrieve machine information m, err := h.GetMachineByMachineKey(mKeyStr) - if err != nil { log.Error().Msg("machine key not found in database") c.String(http.StatusInternalServerError, "could not get machine info from database") @@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { now := time.Now().UTC() - // register the machine if it's new - if !m.Registered { - nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation + if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + // register the machine if it's new + if !m.Registered { - log.Debug().Msg("Registering new machine after successful callback") - - ns, err := h.GetNamespace(nsName) - if err != nil { - ns, err = h.CreateNamespace(nsName) + log.Debug().Msg("Registering new machine after successful callback") + ns, err := h.GetNamespace(nsName) if err != nil { - log.Error().Msgf("could not create new namespace '%s'", claims.Email) - c.String(http.StatusInternalServerError, "could not create new namespace") + ns, err = h.CreateNamespace(nsName) + + if err != nil { + log.Error().Msgf("could not create new namespace '%s'", claims.Email) + c.String(http.StatusInternalServerError, "could not create new namespace") + return + } + } + + ip, err := h.getAvailableIP() + if err != nil { + c.String(http.StatusInternalServerError, "could not get an IP from the pool") return } + + m.IPAddress = ip.String() + m.NamespaceID = ns.ID + m.Registered = true + m.RegisterMethod = "oidc" + m.LastSuccessfulUpdate = &now + h.db.Save(&m) } - ip, err := h.getAvailableIP() - if err != nil { - c.String(http.StatusInternalServerError, "could not get an IP from the pool") - return - } + h.updateMachineExpiry(m) - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "oidc" - m.LastSuccessfulUpdate = &now - h.db.Save(&m) - } - - h.updateMachineExpiry(m) - - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { `, claims.Email))) + + } + + log.Error(). + Str("email", claims.Email). + Str("username", claims.Username). + Str("machine", m.Name). + Msg("Email could not be mapped to a namespace") + c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") +} + +func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) { + for match, namespace := range h.cfg.OIDC.MatchMap { + regex := regexp.MustCompile(match) + if regex.MatchString(email) { + return namespace, true + } + } + + return "", false } diff --git a/oidc_test.go b/oidc_test.go new file mode 100644 index 00000000..ddb44e4c --- /dev/null +++ b/oidc_test.go @@ -0,0 +1,173 @@ +package headscale + +import ( + "sync" + "testing" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/patrickmn/go-cache" + "golang.org/x/oauth2" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/wgkey" +) + +func TestHeadscale_getNamespaceFromEmail(t *testing.T) { + type fields struct { + cfg Config + db *gorm.DB + dbString string + dbType string + dbDebug bool + publicKey *wgkey.Key + privateKey *wgkey.Private + aclPolicy *ACLPolicy + aclRules *[]tailcfg.FilterRule + lastStateChange sync.Map + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config + oidcStateCache *cache.Cache + } + type args struct { + email string + } + tests := []struct { + name string + fields fields + args args + want string + want1 bool + }{ + { + name: "match all", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*": "space", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "space", + want1: true, + }, + { + name: "match user", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + "specific@user\\.no": "user-namespace", + }, + }, + }, + }, + args: args{ + email: "specific@user.no", + }, + want: "user-namespace", + want1: true, + }, + { + name: "match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "example", + }, + }, + }, + }, + args: args{ + email: "test@example.no", + }, + want: "example", + want1: true, + }, + { + name: "multi match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@example\\.no": "exammple", + ".*@gmail\\.com": "gmail", + }, + }, + }, + }, + args: args{ + email: "someuser@gmail.com", + }, + want: "gmail", + want1: true, + }, + { + name: "no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + }, + }, + }, + }, + args: args{ + email: "test@wedontknow.no", + }, + want: "", + want1: false, + }, + { + name: "multi no match domain", + fields: fields{ + cfg: Config{ + OIDC: OIDCConfig{ + MatchMap: map[string]string{ + ".*@dontknow.no": "never", + ".*@wedontknow.no": "other", + ".*\\.no": "stuffy", + }, + }, + }, + }, + args: args{ + email: "tasy@nonofthem.com", + }, + want: "", + want1: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Headscale{ + cfg: tt.fields.cfg, + db: tt.fields.db, + dbString: tt.fields.dbString, + dbType: tt.fields.dbType, + dbDebug: tt.fields.dbDebug, + publicKey: tt.fields.publicKey, + privateKey: tt.fields.privateKey, + aclPolicy: tt.fields.aclPolicy, + aclRules: tt.fields.aclRules, + lastStateChange: tt.fields.lastStateChange, + oidcProvider: tt.fields.oidcProvider, + oauth2Config: tt.fields.oauth2Config, + oidcStateCache: tt.fields.oidcStateCache, + } + got, got1 := h.getNamespaceFromEmail(tt.args.email) + if got != tt.want { + t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1) + } + }) + } +}