Merge pull request #350 from restanrm/feat-oidc-login-as-namespace

This commit is contained in:
Kristoffer Dalby 2022-02-25 13:22:26 +01:00 committed by GitHub
commit b1bd17f316
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 342 additions and 324 deletions

View file

@ -45,6 +45,10 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh
- Add API Key support - Add API Key support
- Enable remote control of `headscale` via CLI [docs](docs/remote-cli.md) - Enable remote control of `headscale` via CLI [docs](docs/remote-cli.md)
- Enable HTTP API (beta, subject to change) - Enable HTTP API (beta, subject to change)
- OpenID Connect users will be mapped per namespaces
- Each user will get its own namespace, created if it does not exist
- `oidc.domain_map` option has been removed
- `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml))
**Changes**: **Changes**:

2
app.go
View file

@ -115,7 +115,7 @@ type OIDCConfig struct {
Issuer string Issuer string
ClientID string ClientID string
ClientSecret string ClientSecret string
MatchMap map[string]string StripEmaildomain bool
} }
type DERPConfig struct { type DERPConfig struct {

View file

@ -10,7 +10,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -65,6 +64,8 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.timeout", "5s")
viper.SetDefault("cli.insecure", false) viper.SetDefault("cli.insecure", false)
viper.SetDefault("oidc.strip_email_domain", true)
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
return fmt.Errorf("fatal error reading config file: %w", err) return fmt.Errorf("fatal error reading config file: %w", err)
} }
@ -347,6 +348,7 @@ func getHeadscaleConfig() headscale.Config {
Issuer: viper.GetString("oidc.issuer"), Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"), ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"), ClientSecret: viper.GetString("oidc.client_secret"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
}, },
CLI: headscale.CLIConfig{ CLI: headscale.CLIConfig{
@ -376,8 +378,6 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg := getHeadscaleConfig() cfg := getHeadscaleConfig()
cfg.OIDC.MatchMap = loadOIDCMatchMap()
app, err := headscale.NewHeadscale(cfg) app, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
@ -534,18 +534,6 @@ func (tokenAuth) RequireTransportSecurity() bool {
return true return true
} }
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
// the match map is valid regex strings.
func loadOIDCMatchMap() map[string]string {
strMap := viper.GetStringMapString("oidc.domain_map")
for oidcMatcher := range strMap {
_ = regexp.MustCompile(oidcMatcher)
}
return strMap
}
func GetFileMode(key string) fs.FileMode { func GetFileMode(key string) fs.FileMode {
modeStr := viper.GetString(key) modeStr := viper.GetString(key)

View file

@ -187,7 +187,9 @@ unix_socket_permission: "0770"
# client_id: "your-oidc-client-id" # client_id: "your-oidc-client-id"
# client_secret: "your-oidc-client-secret" # client_secret: "your-oidc-client-secret"
# #
# # Domain map is used to map incomming users (by their email) to # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# # a namespace. The key can be a string, or regex. # This will transform `first-name.last-name@example.com` to the namespace `first-name.last-name`
# domain_map: # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
# ".*": default-namespace # namespace: `first-name.last-name.example.com`
#
# strip_email_domain: true

4
dns.go
View file

@ -165,11 +165,7 @@ func getMapResponseDNSConfig(
dnsConfig.Domains, dnsConfig.Domains,
fmt.Sprintf( fmt.Sprintf(
"%s.%s", "%s.%s",
strings.ReplaceAll(
machine.Namespace.Name, machine.Namespace.Name,
"@",
".",
), // Replace @ with . for valid domain for machine
baseDomain, baseDomain,
), ),
) )

View file

@ -24,6 +24,11 @@ const (
errMachineAlreadyRegistered = Error("machine already registered") errMachineAlreadyRegistered = Error("machine already registered")
errMachineRouteIsNotAvailable = Error("route is not available on machine") errMachineRouteIsNotAvailable = Error("route is not available on machine")
errMachineAddressesInvalid = Error("failed to parse machine addresses") errMachineAddressesInvalid = Error("failed to parse machine addresses")
errHostnameTooLong = Error("Hostname too long")
)
const (
maxHostnameLength = 255
) )
// Machine is a Headscale client. // Machine is a Headscale client.
@ -595,13 +600,16 @@ func (machine Machine) toNode(
hostname = fmt.Sprintf( hostname = fmt.Sprintf(
"%s.%s.%s", "%s.%s.%s",
machine.Name, machine.Name,
strings.ReplaceAll(
machine.Namespace.Name, machine.Namespace.Name,
"@",
".",
), // Replace @ with . for valid domain for machine
baseDomain, baseDomain,
) )
if len(hostname) > maxHostnameLength {
return nil, fmt.Errorf(
"hostname %q is too long it cannot except 255 ASCII chars: %w",
hostname,
errHostnameTooLong,
)
}
} else { } else {
hostname = machine.Name hostname = machine.Name
} }

View file

@ -2,7 +2,10 @@ package headscale
import ( import (
"errors" "errors"
"fmt"
"regexp"
"strconv" "strconv"
"strings"
"time" "time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -16,8 +19,16 @@ const (
errNamespaceExists = Error("Namespace already exists") errNamespaceExists = Error("Namespace already exists")
errNamespaceNotFound = Error("Namespace not found") errNamespaceNotFound = Error("Namespace not found")
errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
errInvalidNamespaceName = Error("Invalid namespace name")
) )
const (
// value related to RFC 1123 and 952.
labelHostnameLength = 63
)
var invalidCharsInNamespaceRegex = regexp.MustCompile("[^a-z0-9-.]+")
// Namespace is the way Headscale implements the concept of users in Tailscale // Namespace is the way Headscale implements the concept of users in Tailscale
// //
// At the end of the day, users in Tailscale are some kind of 'bubbles' or namespaces // At the end of the day, users in Tailscale are some kind of 'bubbles' or namespaces
@ -30,6 +41,10 @@ type Namespace struct {
// CreateNamespace creates a new Namespace. Returns error if could not be created // CreateNamespace creates a new Namespace. Returns error if could not be created
// or another namespace already exists. // or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
err := CheckNamespaceName(name)
if err != nil {
return nil, err
}
namespace := Namespace{} namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil { if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, errNamespaceExists return nil, errNamespaceExists
@ -84,10 +99,15 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does // RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name. // not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error { func (h *Headscale) RenameNamespace(oldName, newName string) error {
var err error
oldNamespace, err := h.GetNamespace(oldName) oldNamespace, err := h.GetNamespace(oldName)
if err != nil { if err != nil {
return err return err
} }
err = CheckNamespaceName(newName)
if err != nil {
return err
}
_, err = h.GetNamespace(newName) _, err = h.GetNamespace(newName)
if err == nil { if err == nil {
return errNamespaceExists return errNamespaceExists
@ -130,6 +150,10 @@ func (h *Headscale) ListNamespaces() ([]Namespace, error) {
// ListMachinesInNamespace gets all the nodes in a given namespace. // ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
err := CheckNamespaceName(name)
if err != nil {
return nil, err
}
namespace, err := h.GetNamespace(name) namespace, err := h.GetNamespace(name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,6 +169,10 @@ func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
// SetMachineNamespace assigns a Machine to a namespace. // SetMachineNamespace assigns a Machine to a namespace.
func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error { func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
err := CheckNamespaceName(namespaceName)
if err != nil {
return err
}
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return err return err
@ -208,3 +236,55 @@ func (n *Namespace) toProto() *v1.Namespace {
CreatedAt: timestamppb.New(n.CreatedAt), CreatedAt: timestamppb.New(n.CreatedAt),
} }
} }
// NormalizeNamespaceName will replace forbidden chars in namespace
// it can also return an error if the namespace doesn't respect RFC 952 and 1123.
func NormalizeNamespaceName(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInNamespaceRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > labelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
errInvalidNamespaceName,
)
}
}
return name, nil
}
func CheckNamespaceName(name string) error {
if len(name) > labelHostnameLength {
return fmt.Errorf(
"Namespace must not be over 63 chars. %v doesn't comply with this rule: %w",
name,
errInvalidNamespaceName,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"Namespace name should be lowercase. %v doesn't comply with this rule: %w",
name,
errInvalidNamespaceName,
)
}
if invalidCharsInNamespaceRegex.MatchString(name) {
return fmt.Errorf(
"Namespace name should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name,
errInvalidNamespaceName,
)
}
return nil
}

View file

@ -1,6 +1,8 @@
package headscale package headscale
import ( import (
"testing"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
@ -71,23 +73,23 @@ func (s *Suite) TestRenameNamespace(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(namespaces), check.Equals, 1) c.Assert(len(namespaces), check.Equals, 1)
err = app.RenameNamespace("test", "test_renamed") err = app.RenameNamespace("test", "test-renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = app.GetNamespace("test") _, err = app.GetNamespace("test")
c.Assert(err, check.Equals, errNamespaceNotFound) c.Assert(err, check.Equals, errNamespaceNotFound)
_, err = app.GetNamespace("test_renamed") _, err = app.GetNamespace("test-renamed")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
err = app.RenameNamespace("test_does_not_exit", "test") err = app.RenameNamespace("test-does-not-exit", "test")
c.Assert(err, check.Equals, errNamespaceNotFound) c.Assert(err, check.Equals, errNamespaceNotFound)
namespaceTest2, err := app.CreateNamespace("test2") namespaceTest2, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(namespaceTest2.Name, check.Equals, "test2") c.Assert(namespaceTest2.Name, check.Equals, "test2")
err = app.RenameNamespace("test2", "test_renamed") err = app.RenameNamespace("test2", "test-renamed")
c.Assert(err, check.Equals, errNamespaceExists) c.Assert(err, check.Equals, errNamespaceExists)
} }
@ -235,3 +237,143 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
} }
c.Assert(found, check.Equals, true) c.Assert(found, check.Equals, true)
} }
func TestNormalizeNamespaceName(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "namespace name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "namespace with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeNamespaceName(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeNamespaceName() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeNamespaceName() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckNamespaceName(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid: namespace",
args: args{name: "valid-namespace"},
wantErr: false,
},
{
name: "invalid: capitalized namespace",
args: args{name: "Invalid-CapItaLIzed-namespace"},
wantErr: true,
},
{
name: "invalid: email as namespace",
args: args{name: "foo.bar@example.com"},
wantErr: true,
},
{
name: "invalid: chars in namespace name",
args: args{name: "super-namespace+name"},
wantErr: true,
},
{
name: "invalid: too long name for namespace",
args: args{
name: "super-long-namespace-name-that-should-be-a-little-more-than-63-chars",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := CheckNamespaceName(tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("CheckNamespaceName() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

48
oidc.go
View file

@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"regexp"
"strings" "strings"
"time" "time"
@ -282,7 +281,19 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
now := time.Now().UTC() now := time.Now().UTC()
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { namespaceName, err := NormalizeNamespaceName(
claims.Email,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email")
ctx.String(
http.StatusInternalServerError,
"couldn't normalize email",
)
return
}
// register the machine if it's new // register the machine if it's new
if !machine.Registered { if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
@ -317,8 +328,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
h.ipAllocationMutex.Lock()
ips, err := h.getAvailableIPs() ips, err := h.getAvailableIPs()
if err != nil { if err != nil {
log.Error(). log.Error().
@ -340,8 +349,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machine.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
machine.Expiry = &requestedTime machine.Expiry = &requestedTime
h.db.Save(&machine) h.db.Save(&machine)
h.ipAllocationMutex.Unlock()
} }
var content bytes.Buffer var content bytes.Buffer
@ -362,33 +369,4 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes()) ctx.Data(http.StatusOK, "text/html; charset=utf-8", content.Bytes())
return
}
log.Error().
Caller().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace")
ctx.String(
http.StatusBadRequest,
"email from claim could not be mapped to a namespace",
)
}
// getNamespaceFromEmail passes the users email through a list of "matchers"
// and iterates through them until it matches and returns a namespace.
// If no match is found, an empty string will be returned.
// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration.
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
for match, namespace := range h.cfg.OIDC.MatchMap {
regex := regexp.MustCompile(match)
if regex.MatchString(email) {
return namespace, true
}
}
return "", false
} }

View file

@ -1,180 +0,0 @@
package headscale
import (
"sync"
"testing"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
type fields struct {
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
aclPolicy *ACLPolicy
aclRules []tailcfg.FilterRule
lastStateChange sync.Map
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
}
type args struct {
email string
}
tests := []struct {
name string
fields fields
args args
want string
want1 bool
}{
{
name: "match all",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*": "space",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "space",
want1: true,
},
{
name: "match user",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
"specific@user\\.no": "user-namespace",
},
},
},
},
args: args{
email: "specific@user.no",
},
want: "user-namespace",
want1: true,
},
{
name: "match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "example",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "example",
want1: true,
},
{
name: "multi match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "exammple",
".*@gmail\\.com": "gmail",
},
},
},
},
args: args{
email: "someuser@gmail.com",
},
want: "gmail",
want1: true,
},
{
name: "no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
},
},
},
},
args: args{
email: "test@wedontknow.no",
},
want: "",
want1: false,
},
{
name: "multi no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
".*@wedontknow.no": "other",
".*\\.no": "stuffy",
},
},
},
},
args: args{
email: "tasy@nonofthem.com",
},
want: "",
want1: false,
},
}
//nolint
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
app := &Headscale{
cfg: test.fields.cfg,
db: test.fields.db,
dbString: test.fields.dbString,
dbType: test.fields.dbType,
dbDebug: test.fields.dbDebug,
privateKey: test.fields.privateKey,
aclPolicy: test.fields.aclPolicy,
aclRules: test.fields.aclRules,
lastStateChange: test.fields.lastStateChange,
oidcProvider: test.fields.oidcProvider,
oauth2Config: test.fields.oauth2Config,
oidcStateCache: test.fields.oidcStateCache,
}
got, got1 := app.getNamespaceFromEmail(test.args.email)
if got != test.want {
t.Errorf(
"Headscale.getNamespaceFromEmail() got = %v, want %v",
got,
test.want,
)
}
if got1 != test.want1 {
t.Errorf(
"Headscale.getNamespaceFromEmail() got1 = %v, want %v",
got1,
test.want1,
)
}
})
}
}

View file

@ -20,7 +20,7 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
ips, err := app.getAvailableIPs() ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
namespace, err := app.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test-ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
@ -160,7 +160,7 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
c.Assert(len(ips), check.Equals, 1) c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String()) c.Assert(ips[0].String(), check.Equals, expected.String())
namespace, err := app.CreateNamespace("test_ip") namespace, err := app.CreateNamespace("test-ip")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)