allow @ and Log if OIDC username is not consider valid (#2340)

This commit is contained in:
Kristoffer Dalby 2025-01-16 18:04:54 +01:00 committed by GitHub
parent 1ab7b315a2
commit 38aef77e54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 4 deletions

View file

@ -10,6 +10,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -173,9 +174,11 @@ func (c *OIDCClaims) Identifier() string {
// FromClaim overrides a User from OIDC claims. // FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID. // All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) { func (u *User) FromClaim(claims *OIDCClaims) {
err := util.CheckForFQDNRules(claims.Username) err := util.ValidateUsername(claims.Username)
if err == nil { if err == nil {
u.Name = claims.Username u.Name = claims.Username
} else {
log.Debug().Err(err).Msgf("Username %s is not valid", claims.Username)
} }
if claims.EmailVerified { if claims.EmailVerified {

View file

@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"regexp" "regexp"
"strings" "strings"
"unicode"
"go4.org/netipx" "go4.org/netipx"
"tailscale.com/util/dnsname" "tailscale.com/util/dnsname"
@ -20,10 +21,40 @@ const (
LabelHostnameLength = 63 LabelHostnameLength = 63
) )
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidUserName = errors.New("invalid user name") var ErrInvalidUserName = errors.New("invalid user name")
func ValidateUsername(username string) error {
// Ensure the username meets the minimum length requirement
if len(username) < 2 {
return errors.New("username must be at least 2 characters long")
}
// Ensure the username does not start with a number
if unicode.IsDigit(rune(username[0])) {
return errors.New("username cannot start with a number")
}
atCount := 0
for _, char := range username {
switch {
case unicode.IsLetter(char), unicode.IsDigit(char), char == '-':
// Valid characters
case char == '@':
atCount++
if atCount > 1 {
return errors.New("username cannot contain more than one '@'")
}
default:
return fmt.Errorf("username contains invalid character: '%c'", char)
}
}
return nil
}
func CheckForFQDNRules(name string) error { func CheckForFQDNRules(name string) error {
if len(name) > LabelHostnameLength { if len(name) > LabelHostnameLength {
return fmt.Errorf( return fmt.Errorf(
@ -39,7 +70,7 @@ func CheckForFQDNRules(name string) error {
ErrInvalidUserName, ErrInvalidUserName,
) )
} }
if invalidCharsInUserRegex.MatchString(name) { if invalidDNSRegex.MatchString(name) {
return fmt.Errorf( return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name, name,
@ -52,7 +83,7 @@ func CheckForFQDNRules(name string) error {
func ConvertWithFQDNRules(name string) string { func ConvertWithFQDNRules(name string) string {
name = strings.ToLower(name) name = strings.ToLower(name)
name = invalidCharsInUserRegex.ReplaceAllString(name, "") name = invalidDNSRegex.ReplaceAllString(name, "")
return name return name
} }
@ -197,7 +228,7 @@ func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
} else { } else {
name = strings.ReplaceAll(name, "@", ".") name = strings.ReplaceAll(name, "@", ".")
} }
name = invalidCharsInUserRegex.ReplaceAllString(name, "-") name = invalidDNSRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") { for _, elt := range strings.Split(name, ".") {
if len(elt) > LabelHostnameLength { if len(elt) > LabelHostnameLength {