use registration ID in reg flow, support followup

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2025-01-13 11:54:35 +01:00
parent fbd0d5e7ce
commit b4e7394563
No known key found for this signature in database
8 changed files with 143 additions and 119 deletions

View file

@ -96,7 +96,7 @@ type Headscale struct {
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
registrationCache *zcache.Cache[string, types.Node]
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
authProvider AuthProvider
@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
}
registrationCache := zcache.New[string, types.Node](
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
registerCacheExpiration,
registerCacheCleanup,
)
@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)

View file

@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/db"
@ -20,16 +22,18 @@ import (
type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(key.MachinePublic) string
AuthURL(types.RegistrationID) string
}
func logAuthFunc(
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) (func(string), func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -41,6 +45,7 @@ func logAuthFunc(
func(msg string) {
log.Trace().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -52,6 +57,7 @@ func logAuthFunc(
func(err error, msg string) {
log.Error().
Caller().
Str("registration_id", registrationId.String()).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@ -63,6 +69,40 @@ func logAuthFunc(
}
}
// waitForFollowup holds a followup register request open until the pending
// interactive registration it refers to completes, or the client goes away.
//
// The registration ID is read from the path of regReq.Followup, which is
// expected to end in "/register/<registration id>" (the shape produced by the
// auth providers' AuthURL). If the URL cannot be parsed, does not carry a
// valid registration ID, or the ID is not present in the registration cache,
// the reason is reported through logTrace and the function returns
// immediately. Otherwise it blocks until either req's context is done or a
// receive on the cached entry's Registered channel succeeds.
//
// This function only waits; it does not write a response to the client.
func (h *Headscale) waitForFollowup(
	req *http.Request,
	regReq tailcfg.RegisterRequest,
	logTrace func(string),
) {
	const registerPathPrefix = "/register/"

	logTrace("register request is a followup")

	fu, err := url.Parse(regReq.Followup)
	if err != nil {
		logTrace("failed to parse followup URL")
		return
	}

	// Take whatever follows the last "/register/" in the path. Removing every
	// occurrence of the route (as strings.ReplaceAll would) glues any leading
	// path component onto the ID, so a followup URL such as
	// "/prefix/register/<id>" would never yield a valid registration ID.
	// A path without the route is passed through unchanged.
	idStr := fu.Path
	if i := strings.LastIndex(idStr, registerPathPrefix); i >= 0 {
		idStr = idStr[i+len(registerPathPrefix):]
	}

	followupReg, err := types.RegistrationIDFromString(idStr)
	if err != nil {
		logTrace("followup URL does not contains a valid registration ID")
		return
	}

	logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))

	reg, ok := h.registrationCache.Get(followupReg)
	if !ok {
		// Nothing is pending under this ID, so there is nothing to wait for.
		logTrace("registration ID from followup URL not found in cache")
		return
	}

	logTrace("Node is waiting for interactive login")

	// Hold the request until the registration is signalled as done or the
	// client disconnects, whichever happens first.
	select {
	case <-req.Context().Done():
		logTrace("node went away before it was registered")
	case <-reg.Registered:
		logTrace("node has successfully registered")
	}
}
// handleRegister is the logic for registering a client.
func (h *Headscale) handleRegister(
writer http.ResponseWriter,
@ -70,7 +110,18 @@ func (h *Headscale) handleRegister(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
registrationId, err := types.NewRegistrationID()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to generate registration ID")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
now := time.Now().UTC()
logTrace("handleRegister called, looking up machine in DB")
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
@ -84,27 +135,9 @@ func (h *Headscale) handleRegister(
}
// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if regReq.Followup != "" {
logTrace("register request is a followup")
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
logTrace("Node is waiting for interactive login")
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewNode(writer, regReq, machineKey)
return
}
}
h.waitForFollowup(req, regReq, logTrace)
return
}
logInfo("Node not found in database, creating new")
@ -113,25 +146,28 @@ func (h *Headscale) handleRegister(
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the node and then keep it around until a callback
// happens
newNode := types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
newNode := types.RegisterNode{
Node: types.Node{
MachineKey: machineKey,
Hostname: regReq.Hostinfo.Hostname,
NodeKey: regReq.NodeKey,
LastSeen: &now,
Expiry: &time.Time{},
},
Registered: make(chan struct{}),
}
if !regReq.Expiry.IsZero() {
logTrace("Non-zero expiry time requested")
newNode.Expiry = &regReq.Expiry
newNode.Node.Expiry = &regReq.Expiry
}
h.registrationCache.Set(
machineKey.String(),
registrationId,
newNode,
)
h.handleNewNode(writer, regReq, machineKey)
h.handleNewNode(writer, regReq, registrationId)
return
}
@ -206,15 +242,12 @@ func (h *Headscale) handleRegister(
}
if regReq.Followup != "" {
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
}
h.waitForFollowup(req, regReq, logTrace)
return
}
// The node has expired or it is logged out
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
node.Expiry = &time.Time{}
@ -444,16 +477,16 @@ func (h *Headscale) handleAuthKey(
func (h *Headscale) handleNewNode(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)
resp := tailcfg.RegisterResponse{}
// The node registration is new, redirect the client to the registration URL
logTrace("The node seems to be new, sending auth url")
logTrace("The node is new, sending auth url")
resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)
respBody, err := json.Marshal(resp)
if err != nil {
@ -660,6 +693,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
regReq tailcfg.RegisterRequest,
node types.Node,
machineKey key.MachinePublic,
registrationId types.RegistrationID,
) {
resp := tailcfg.RegisterResponse{}
@ -673,12 +707,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Msg("Node registration has expired or logged out. Sending a auth url to register")
resp.AuthURL = h.authProvider.AuthURL(machineKey)
resp.AuthURL = h.authProvider.AuthURL(registrationId)
respBody, err := json.Marshal(resp)
if err != nil {
@ -703,7 +737,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
log.Trace().
Caller().
Str("machine_key", machineKey.ShortString()).
Str("registration_id", registrationId.String()).
Str("node_key", regReq.NodeKey.ShortString()).
Str("node_key_old", regReq.OldNodeKey.ShortString()).
Str("node", node.Hostname).

View file

@ -41,7 +41,7 @@ type KV struct {
type HSDatabase struct {
DB *gorm.DB
cfg *types.DatabaseConfig
regCache *zcache.Cache[string, types.Node]
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
baseDomain string
}
@ -51,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase(
cfg types.DatabaseConfig,
baseDomain string,
regCache *zcache.Cache[string, types.Node],
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
) (*HSDatabase, error) {
dbConn, err := openDB(cfg)
if err != nil {

View file

@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) {
return dst, err
}
func emptyCache() *zcache.Cache[string, types.Node] {
return zcache.New[string, types.Node](time.Minute, time.Hour)
func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] {
return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour)
}
// requireConstraintFailed checks if the error is a constraint failure with

View file

@ -320,7 +320,7 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
}
func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
mkey key.MachinePublic,
registrationID types.RegistrationID,
userID types.UserID,
nodeExpiry *time.Time,
registrationMethod string,
@ -328,7 +328,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
ipv6 *netip.Addr,
) (*types.Node, error) {
return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if node, ok := hsdb.regCache.Get(mkey.String()); ok {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
user, err := GetUserByID(tx, userID)
if err != nil {
return nil, fmt.Errorf(
@ -338,36 +338,40 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
}
log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationID.String()).
Str("username", user.Username()).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
Msg("Registering node from API/CLI or auth callback")
// TODO(kradalby): This looks quite wrong? why ID 0?
// Why not always?
// Registration of expired node with different user
if node.ID != 0 &&
node.UserID != user.ID {
if reg.Node.ID != 0 &&
reg.Node.UserID != user.ID {
return nil, ErrDifferentRegisteredUser
}
node.UserID = user.ID
node.User = *user
node.RegisterMethod = registrationMethod
reg.Node.UserID = user.ID
reg.Node.User = *user
reg.Node.RegisterMethod = registrationMethod
if nodeExpiry != nil {
node.Expiry = nodeExpiry
reg.Node.Expiry = nodeExpiry
}
node, err := RegisterNode(
tx,
node,
reg.Node,
ipv4, ipv6,
)
if err == nil {
hsdb.regCache.Delete(mkey.String())
hsdb.regCache.Delete(registrationID)
}
// Signal to waiting clients that the machine has been registered.
close(reg.Registered)
return node, err
}

View file

@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode(
) (*v1.RegisterNodeResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("machine_key", request.GetKey()).
Str("registration_id", request.GetKey()).
Msg("Registering node")
var mkey key.MachinePublic
err := mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
@ -247,7 +246,7 @@ func (api headscaleV1APIServer) RegisterNode(
}
node, err := api.h.db.RegisterNodeFromAuthCallback(
mkey,
registrationId,
types.UserID(user.ID),
nil,
util.RegisterMethodCLI,
@ -839,36 +838,37 @@ func (api headscaleV1APIServer) DebugCreateNode(
Hostname: "DebugTestNode",
}
var mkey key.MachinePublic
err = mkey.UnmarshalText([]byte(request.GetKey()))
registrationId, err := types.RegistrationIDFromString(request.GetKey())
if err != nil {
return nil, err
}
nodeKey := key.NewNode()
newNode := types.Node{
MachineKey: mkey,
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,
newNode := types.RegisterNode{
Node: types.Node{
NodeKey: nodeKey.Public(),
Hostname: request.GetName(),
User: *user,
Expiry: &time.Time{},
LastSeen: &time.Time{},
Expiry: &time.Time{},
LastSeen: &time.Time{},
Hostinfo: &hostinfo,
Hostinfo: &hostinfo,
},
Registered: make(chan struct{}),
}
log.Debug().
Str("machine_key", mkey.ShortString()).
Str("registration_id", registrationId.String()).
Msg("adding debug machine via CLI, appending to registration cache")
api.h.registrationCache.Set(
mkey.String(),
registrationId,
newNode,
)
return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil
return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil
}
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

View file

@ -8,13 +8,13 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/chasefleming/elem-go"
"github.com/chasefleming/elem-go/attrs"
"github.com/chasefleming/elem-go/styles"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -32,8 +32,6 @@ const (
// See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 39
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4
)
@ -239,11 +237,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb {
}
}
func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string {
func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
registrationId.String())
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI

View file

@ -21,7 +21,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
"zgo.at/zcache/v2"
)
@ -49,8 +48,8 @@ var (
// RegistrationInfo contains both registration ID and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier *string
RegistrationID types.RegistrationID
Verifier *string
}
type AuthProviderOIDC struct {
@ -112,11 +111,11 @@ func NewAuthProviderOIDC(
}, nil
}
func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string {
func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string {
return fmt.Sprintf(
"%s/register/%s",
strings.TrimSuffix(a.serverURL, "/"),
mKey.String())
registrationID.String())
}
func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time {
@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /register/:mKey.
// Listens in /register/:registration_id.
func (a *AuthProviderOIDC) RegisterHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
log.Debug().
Caller().
Str("machine_key", machineKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")
registrationIdStr, ok := vars["registration_id"]
// We need to make sure we don't open up for XSS-style injections: if the parameter
// that is passed in the URL cannot be parsed/validated as a registration ID, reject
// the request with a bad-request error instead of processing it any further.
var machineKey key.MachinePublic
err := machineKey.UnmarshalText(
[]byte(machineKeyStr),
)
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
return
}
log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Bool("ok", ok).
Msg("Received oidc register call")
// Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state")
if err != nil {
@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
// Initialize registration info with the registration ID
registrationInfo := RegistrationInfo{
MachineKey: machineKey,
RegistrationID: registrationId,
}
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
@ -326,8 +322,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
// Register the node if it does not exist.
if mKey != nil {
if err := a.registerNode(user, mKey, nodeExpiry); err != nil {
if registrationId != nil {
if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
@ -456,22 +452,14 @@ func validateOIDCAllowedUsers(
return nil
}
// getMachineKeyFromState retrieves the machine key from the state
// cache. If the machine key is found, it will try retrieve the
// node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, nil
return nil
}
// retrieve node information if it exist
// The error is not important, because if it does not
// exist, then this is a new node and we will move
// on to registration.
node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
return node, &regInfo.MachineKey
return &regInfo.RegistrationID
}
// reauthenticateNode updates the node expiry in the database
@ -558,7 +546,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
func (a *AuthProviderOIDC) registerNode(
user *types.User,
machineKey *key.MachinePublic,
registrationID types.RegistrationID,
expiry time.Time,
) error {
ipv4, ipv6, err := a.ipAlloc.Next()
@ -567,7 +555,7 @@ func (a *AuthProviderOIDC) registerNode(
}
if _, err := a.db.RegisterNodeFromAuthCallback(
*machineKey,
registrationID,
types.UserID(user.ID),
&expiry,
util.RegisterMethodOIDC,