headscale/hscontrol/oidc.go
Kristoffer Dalby 218138afee
Redo OIDC configuration (#2020)
expand user, add claims to user

This commit expands the user table with additional fields that
can be retrieved from OIDC providers (and other places) and
uses this data in various tailscale response objects if it is
available.

This is the beginning of implementing
https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit
trying to make OIDC more coherent and maintainable in addition
to giving the user a better experience and integration with a
provider.

remove usernames in magic dns, normalisation of emails

this commit removes the option to have usernames as part of MagicDNS
domains and headscale will now align with Tailscale, where there is a
root domain, and the machine name.

In addition, the various normalisation functions for DNS names have been
made lighter, no longer caring about usernames and special characters that
won't occur.

Emails are no longer normalised as part of the policy processing.

untangle oidc and regcache, use typed cache

This commit stops reusing the registration cache for OIDC
purposes and switches the cache to be typed instead of using `any`,
allowing the removal of a bunch of casting.

try to make reauth/register branches clearer in oidc

Currently there was a function that did a bunch of stuff,
finding the machine key, trying to find the node, reauthing
the node, returning some status, and it was called validate
which was very confusing.

This commit tries to split this into what to do if the node
exists, if it needs to register etc.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-02 14:50:17 +02:00

504 lines
14 KiB
Go

package hscontrol
import (
"bytes"
"context"
"crypto/rand"
_ "embed"
"encoding/hex"
"errors"
"fmt"
"html/template"
"net/http"
"slices"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
"zgo.at/zcache/v2"
)
// randomByteSize is the number of random bytes read when generating the
// state string for an OIDC login attempt.
const randomByteSize = 16
// Sentinel errors used by the OIDC authentication flow.
var (
	errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
	errNoOIDCIDToken           = errors.New("could not extract ID Token for OIDC callback")
	errOIDCAllowedDomains      = errors.New("authenticated principal does not match any allowed domain")
	errOIDCAllowedGroups       = errors.New("authenticated principal is not in any allowed group")
	errOIDCAllowedUsers        = errors.New("authenticated principal does not match any allowed user")
	errOIDCInvalidNodeState    = errors.New("requested node state key expired before authorisation completed")
	errOIDCNodeKeyMissing      = errors.New("could not get node key from cache")
)
// AuthProviderOIDC implements node registration and reauthentication
// through an OIDC provider.
type AuthProviderOIDC struct {
// serverURL is the base URL used to build the register and OAuth2
// redirect URLs.
serverURL string
// cfg holds the OIDC settings (issuer, client credentials, scopes,
// allow-lists and expiry options).
cfg *types.OIDCConfig
// db is used to look up and store users and nodes.
db *db.HSDatabase
// registrationCache maps the OIDC state string to the machine key of the
// node that started the login.
registrationCache *zcache.Cache[string, key.MachinePublic]
// notifier sends state updates after a node has been reauthenticated.
notifier *notifier.Notifier
// ipAlloc hands out addresses for newly registered nodes.
ipAlloc *db.IPAllocator
// oidcProvider supplies the OAuth2 endpoint and the ID token verifier.
oidcProvider *oidc.Provider
// oauth2Config is used to build the authorization URL and to exchange the
// callback code for a token.
oauth2Config *oauth2.Config
}
// NewAuthProviderOIDC creates an AuthProviderOIDC for the issuer configured
// in cfg.
//
// It creates the OIDC provider via oidc.NewProvider, derives the OAuth2
// configuration with "<serverURL>/oidc/callback" as redirect URL (a trailing
// slash on serverURL is trimmed), and sets up a fresh cache that maps OIDC
// state strings to machine keys.
//
// An error is returned when the OIDC provider cannot be created.
func NewAuthProviderOIDC(
	ctx context.Context,
	serverURL string,
	cfg *types.OIDCConfig,
	db *db.HSDatabase,
	notif *notifier.Notifier,
	ipAlloc *db.IPAllocator,
) (*AuthProviderOIDC, error) {
	// NOTE(review): ctx is not used here; the provider is created with
	// context.Background() — confirm whether that is intentional.
	provider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
	if err != nil {
		return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
	}

	oauthCfg := &oauth2.Config{
		ClientID:     cfg.ClientID,
		ClientSecret: cfg.ClientSecret,
		Endpoint:     provider.Endpoint(),
		RedirectURL:  strings.TrimSuffix(serverURL, "/") + "/oidc/callback",
		Scopes:       cfg.Scope,
	}

	stateCache := zcache.New[string, key.MachinePublic](
		registerCacheExpiration,
		registerCacheCleanup,
	)

	return &AuthProviderOIDC{
		serverURL:         serverURL,
		cfg:               cfg,
		db:                db,
		registrationCache: stateCache,
		notifier:          notif,
		ipAlloc:           ipAlloc,
		oidcProvider:      provider,
		oauth2Config:      oauthCfg,
	}, nil
}
// AuthURL returns the registration URL for the given machine key, in the
// form "<serverURL>/register/<mKey>". A trailing slash on the server URL is
// trimmed before the path is appended.
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
	base := strings.TrimSuffix(a.serverURL, "/")

	return base + "/register/" + mKey.String()
}
// determineNodeExpiry picks the expiry time for a node authenticated via
// OIDC: the ID token's expiration when UseExpiryFromToken is set, otherwise
// the current time plus the configured Expiry duration.
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
	if !a.cfg.UseExpiryFromToken {
		return time.Now().Add(a.cfg.Expiry)
	}

	return idTokenExpiration
}
// RegisterHandler starts the OIDC login for a node by redirecting the
// client to the OIDC provider.
//
// The machine key is read from the "mkey" route variable and stored in the
// registration cache under a freshly generated random state string, so that
// the callback can find it again through the OIDC state parameter.
// An unparsable machine key is answered with 400; a failure to read random
// bytes with 500.
func (a *AuthProviderOIDC) RegisterHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	rawKey, found := mux.Vars(req)["mkey"]

	log.Debug().
		Caller().
		Str("machine_key", rawKey).
		Bool("ok", found).
		Msg("Received oidc register call")

	// Guard against XSS style injections: the path value is only used once
	// it has been parsed successfully as a machine key.
	var mKey key.MachinePublic
	if err := mKey.UnmarshalText([]byte(rawKey)); err != nil {
		http.Error(writer, err.Error(), http.StatusBadRequest)
		return
	}

	entropy := make([]byte, randomByteSize)
	if _, err := rand.Read(entropy); err != nil {
		http.Error(writer, "Internal server error", http.StatusInternalServerError)
		return
	}
	state := hex.EncodeToString(entropy)[:32]

	// Remember which machine key belongs to this state for the callback.
	a.registrationCache.Set(state, mKey)

	// Forward any extra parameters from the configuration to the
	// authorization endpoint request.
	opts := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams))
	for name, value := range a.cfg.ExtraParams {
		opts = append(opts, oauth2.SetAuthURLParam(name, value))
	}

	target := a.oauth2Config.AuthCodeURL(state, opts...)
	log.Debug().Msgf("Redirecting to %s for authentication", target)

	http.Redirect(writer, req, target, http.StatusFound)
}
// oidcCallbackTemplateConfig is the data passed to oidcCallbackTemplate when
// rendering the OIDC callback page.
type oidcCallbackTemplateConfig struct {
// User is the user's display name or username.
User string
// Verb describes what happened, e.g. "Authenticated" or "Reauthenticated".
Verb string
}
// oidcCallbackTemplateContent holds the source of the OIDC callback page,
// embedded from assets/oidc_callback_template.html.
//
//go:embed assets/oidc_callback_template.html
var oidcCallbackTemplateContent string
// oidcCallbackTemplate is the parsed callback page template. template.Must
// panics if the embedded source does not parse.
var oidcCallbackTemplate = template.Must(
template.New("oidccallback").Parse(oidcCallbackTemplateContent),
)
// OIDCCallbackHandler handles the callback from the OIDC endpoint.
// Listens in /oidc/callback.
//
// The handler:
//  1. reads the "code" and "state" query parameters (400 if either is empty),
//  2. exchanges the code for a verified ID token (400 on failure) and
//     decodes its claims (500 on failure),
//  3. rejects principals failing the configured domain, group or user
//     allow-lists (401),
//  4. creates or updates the user from the claims (500 on failure),
//  5. resolves the machine key stored under state and either
//     reauthenticates the existing node or registers a new one, then
//     renders the callback page.
//
// When state is not found in the registration cache, the request is answered
// with errOIDCInvalidNodeState and status 500.
//
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into node HostInfo
func (a *AuthProviderOIDC) OIDCCallbackHandler(
	writer http.ResponseWriter,
	req *http.Request,
) {
	code, state, err := extractCodeAndStateParamFromRequest(req)
	if err != nil {
		http.Error(writer, err.Error(), http.StatusBadRequest)
		return
	}

	idToken, err := a.extractIDToken(req.Context(), code)
	if err != nil {
		http.Error(writer, err.Error(), http.StatusBadRequest)
		return
	}
	nodeExpiry := a.determineNodeExpiry(idToken.Expiry)

	var claims types.OIDCClaims
	if err := idToken.Claims(&claims); err != nil {
		http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
		return
	}

	if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
		http.Error(writer, err.Error(), http.StatusUnauthorized)
		return
	}

	if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
		http.Error(writer, err.Error(), http.StatusUnauthorized)
		return
	}

	if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
		http.Error(writer, err.Error(), http.StatusUnauthorized)
		return
	}

	user, err := a.createOrUpdateUserFromClaim(&claims)
	if err != nil {
		http.Error(writer, err.Error(), http.StatusInternalServerError)
		return
	}

	// Retrieve the node and the machine key from the state cache and
	// database.
	// If the node exists, then the node should be reauthenticated,
	// if the node does not exist, and the machine key exists, then
	// this is a new node that should be registered.
	node, mKey := a.getMachineKeyFromState(state)

	var page []byte

	switch {
	case node != nil:
		// Reauthenticate the node if it does exist.
		if err := a.reauthenticateNode(node, nodeExpiry); err != nil {
			http.Error(writer, err.Error(), http.StatusInternalServerError)
			return
		}

		// TODO(kradalby): replace with go-elem
		var content bytes.Buffer
		if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
			User: user.DisplayNameOrUsername(),
			Verb: "Reauthenticated",
		}); err != nil {
			http.Error(writer, fmt.Errorf("rendering OIDC callback template: %w", err).Error(), http.StatusInternalServerError)
			return
		}
		page = content.Bytes()

	case mKey != nil:
		// Register the node if it does not exist.
		if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
			http.Error(writer, err.Error(), http.StatusInternalServerError)
			return
		}

		content, err := renderOIDCCallbackTemplate(user)
		if err != nil {
			http.Error(writer, err.Error(), http.StatusInternalServerError)
			return
		}
		page = content.Bytes()

	default:
		// Neither node nor machine key was found in the state cache meaning
		// that we could not reauth nor register the node.
		// The outer err is always nil at this point, so a sentinel error is
		// reported instead of calling err.Error(), which would panic.
		http.Error(writer, errOIDCInvalidNodeState.Error(), http.StatusInternalServerError)
		return
	}

	writer.Header().Set("Content-Type", "text/html; charset=utf-8")
	writer.WriteHeader(http.StatusOK)
	if _, err := writer.Write(page); err != nil {
		util.LogErr(err, "Failed to write response")
	}
}
// extractCodeAndStateParamFromRequest returns the "code" and "state" query
// parameters of req, in that order. If either one is empty, both results are
// empty and errEmptyOIDCCallbackParams is returned.
func extractCodeAndStateParamFromRequest(
	req *http.Request,
) (string, string, error) {
	query := req.URL.Query()

	code, state := query.Get("code"), query.Get("state")
	if code != "" && state != "" {
		return code, state, nil
	}

	return "", "", errEmptyOIDCCallbackParams
}
// extractIDToken exchanges the callback code for an OAuth2 token, pulls the
// "id_token" field out of it and verifies it against the configured client
// ID.
//
// It returns the verified ID token, a wrapped error when the exchange or the
// verification fails, or errNoOIDCIDToken when the token response carries no
// string "id_token".
func (a *AuthProviderOIDC) extractIDToken(
	ctx context.Context,
	code string,
) (*oidc.IDToken, error) {
	token, err := a.oauth2Config.Exchange(ctx, code)
	if err != nil {
		return nil, fmt.Errorf("could not exchange code for token: %w", err)
	}

	raw, isString := token.Extra("id_token").(string)
	if !isString {
		return nil, errNoOIDCIDToken
	}

	idToken, err := a.oidcProvider.
		Verifier(&oidc.Config{ClientID: a.cfg.ClientID}).
		Verify(ctx, raw)
	if err != nil {
		return nil, fmt.Errorf("failed to verify ID token: %w", err)
	}

	return idToken, nil
}
// validateOIDCAllowedDomains enforces the domain allow-list. With an empty
// list every principal passes. Otherwise the part of claims.Email after the
// last "@" must be one of allowedDomains; an email without "@" or with a
// domain that is not listed yields errOIDCAllowedDomains.
func validateOIDCAllowedDomains(
	allowedDomains []string,
	claims *types.OIDCClaims,
) error {
	if len(allowedDomains) == 0 {
		return nil
	}

	at := strings.LastIndex(claims.Email, "@")
	if at < 0 {
		return errOIDCAllowedDomains
	}

	if domain := claims.Email[at+1:]; !slices.Contains(allowedDomains, domain) {
		return errOIDCAllowedDomains
	}

	return nil
}
// validateOIDCAllowedGroups enforces the group allow-list. With an empty
// list every principal passes. Otherwise at least one entry of allowedGroups
// must appear in claims.Groups, or errOIDCAllowedGroups is returned.
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
	allowedGroups []string,
	claims *types.OIDCClaims,
) error {
	if len(allowedGroups) == 0 {
		return nil
	}

	isMember := func(group string) bool {
		return slices.Contains(claims.Groups, group)
	}
	if slices.ContainsFunc(allowedGroups, isMember) {
		return nil
	}

	return errOIDCAllowedGroups
}
// validateOIDCAllowedUsers enforces the user allow-list. With an empty list
// every principal passes. Otherwise claims.Email must be one of
// allowedUsers, or errOIDCAllowedUsers is returned (and a trace message is
// logged).
func validateOIDCAllowedUsers(
	allowedUsers []string,
	claims *types.OIDCClaims,
) error {
	if len(allowedUsers) == 0 || slices.Contains(allowedUsers, claims.Email) {
		return nil
	}

	log.Trace().Msg("authenticated principal does not match any allowed user")

	return errOIDCAllowedUsers
}
// getMachineKeyFromState looks up the machine key stored in the
// registration cache under state and, if present, the node that belongs to
// that machine key in the database.
//
// Both results are nil when the state is not in the cache. Otherwise the
// machine key is returned together with whatever node the database lookup
// produced.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
	mKey, found := a.registrationCache.Get(state)
	if !found {
		return nil, nil
	}

	// The lookup error is deliberately dropped: a machine key without a
	// node is treated as a new node that still has to be registered.
	// NOTE(review): presumably the returned node is nil on error — confirm
	// against GetNodeByMachineKey.
	existing, _ := a.db.GetNodeByMachineKey(mKey)

	return existing, &mKey
}
// reauthenticateNode stores the new expiry for node in the database and then
// sends two notifications: a self update to the node itself and an expiry
// update to everyone except the node.
//
// Only a failure to store the expiry is reported as an error.
func (a *AuthProviderOIDC) reauthenticateNode(
	node *types.Node,
	expiry time.Time,
) error {
	if err := a.db.NodeSetExpiry(node.ID, expiry); err != nil {
		return err
	}

	// Tell the node itself that it changed.
	selfCtx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
	selfUpdate := types.StateUpdate{
		Type:        types.StateSelfUpdate,
		ChangeNodes: []types.NodeID{node.ID},
	}
	a.notifier.NotifyByNodeID(selfCtx, selfUpdate, node.ID)

	// Tell all others about the new expiry, skipping the node itself.
	peerCtx := types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
	a.notifier.NotifyWithIgnore(peerCtx, types.StateUpdateExpire(node.ID, expiry), node.ID)

	return nil
}
// createOrUpdateUserFromClaim finds the user the claims belong to, applies
// the claims to it and saves it.
//
// The user is looked up by the OIDC identifier (claims.Sub) first, then by
// claims.Username, and if neither lookup yields a user a new empty one is
// used. db.ErrUserNotFound from a lookup is not treated as a failure; any
// other lookup error, and any save error, is returned wrapped.
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
	claims *types.OIDCClaims,
) (*types.User, error) {
	wrap := func(cause error) error {
		return fmt.Errorf("creating or updating user: %w", cause)
	}
	lookupFailed := func(cause error) bool {
		return cause != nil && !errors.Is(cause, db.ErrUserNotFound)
	}

	user, err := a.db.GetUserByOIDCIdentifier(claims.Sub)
	if lookupFailed(err) {
		return nil, wrap(err)
	}

	// This check is for legacy, if the user cannot be found by the OIDC identifier
	// look it up by username. This should only be needed once.
	if user == nil {
		user, err = a.db.GetUserByName(claims.Username)
		if lookupFailed(err) {
			return nil, wrap(err)
		}
	}

	// Still nothing: start from a new empty user.
	if user == nil {
		user = &types.User{}
	}

	user.FromClaim(claims)

	if err := a.db.DB.Save(user).Error; err != nil {
		return nil, wrap(err)
	}

	return user, nil
}
// registerNode obtains the next addresses from the IP allocator and
// registers the node identified by machineKey for user with the given
// expiry, using the OIDC register method.
//
// An allocator error is returned as is; a registration error is wrapped.
func (a *AuthProviderOIDC) registerNode(
	user *types.User,
	machineKey *key.MachinePublic,
	expiry time.Time,
) error {
	ipv4, ipv6, err := a.ipAlloc.Next()
	if err != nil {
		return err
	}

	_, err = a.db.RegisterNodeFromAuthCallback(
		*machineKey,
		types.UserID(user.ID),
		&expiry,
		util.RegisterMethodOIDC,
		ipv4, ipv6,
	)
	if err != nil {
		return fmt.Errorf("could not register node: %w", err)
	}

	return nil
}
// renderOIDCCallbackTemplate renders the OIDC callback page for user with
// the verb "Authenticated" and returns the rendered bytes in a buffer.
// A template execution failure is returned wrapped.
//
// TODO(kradalby):
// Rewrite in elem-go
func renderOIDCCallbackTemplate(
	user *types.User,
) (*bytes.Buffer, error) {
	data := oidcCallbackTemplateConfig{
		User: user.DisplayNameOrUsername(),
		Verb: "Authenticated",
	}

	content := new(bytes.Buffer)
	if err := oidcCallbackTemplate.Execute(content, data); err != nil {
		return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
	}

	return content, nil
}