diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 55cd8fb1..8cae0016 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -10,6 +10,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -173,9 +174,11 @@ func (c *OIDCClaims) Identifier() string { // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - err := util.CheckForFQDNRules(claims.Username) + err := util.ValidateUsername(claims.Username) if err == nil { u.Name = claims.Username + } else { + log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username) } if claims.EmailVerified { diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index c6861c9e..d55d6e8a 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -6,6 +6,7 @@ import ( "net/netip" "regexp" "strings" + "unicode" "go4.org/netipx" "tailscale.com/util/dnsname" @@ -20,10 +21,40 @@ const ( LabelHostnameLength = 63 ) +var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") var ErrInvalidUserName = errors.New("invalid user name") +func ValidateUsername(username string) error { + // Ensure the username meets the minimum length requirement + if len(username) < 2 { + return errors.New("username must be at least 2 characters long") + } + + // Ensure the username does not start with a number + if unicode.IsDigit(rune(username[0])) { + return errors.New("username cannot start with a number") + } + + atCount := 0 + for _, char := range username { + switch { + case unicode.IsLetter(char), unicode.IsDigit(char), char == '-': + // Valid characters + case char == '@': + atCount++ + if atCount > 1 { + return errors.New("username cannot contain more than one '@'") + } + default: + return fmt.Errorf("username contains invalid character: '%c'", char) + } + } + + return nil +} + func CheckForFQDNRules(name string) error { if len(name) > LabelHostnameLength { return fmt.Errorf( @@ -39,7 +70,7 @@ func CheckForFQDNRules(name string) error { ErrInvalidUserName, ) } - if invalidCharsInUserRegex.MatchString(name) { + if invalidDNSRegex.MatchString(name) { return fmt.Errorf( "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", name, @@ -52,7 +83,7 @@ func CheckForFQDNRules(name string) error { func ConvertWithFQDNRules(name string) string { name = strings.ToLower(name) - name = invalidCharsInUserRegex.ReplaceAllString(name, "") + name = invalidDNSRegex.ReplaceAllString(name, "") return name } @@ -197,7 +228,7 @@ func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { } else { name = strings.ReplaceAll(name, "@", ".") } - name = invalidCharsInUserRegex.ReplaceAllString(name, "-") + name = invalidDNSRegex.ReplaceAllString(name, "-") for _, elt := range strings.Split(name, ".") { if len(elt) > LabelHostnameLength {