diff --git a/hscontrol/app.go b/hscontrol/app.go index 3349392b..7e5850a3 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -96,7 +96,7 @@ type Headscale struct { mapper *mapper.Mapper nodeNotifier *notifier.Notifier - registrationCache *zcache.Cache[string, types.Node] + registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] authProvider AuthProvider @@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - registrationCache := zcache.New[string, types.Node]( + registrationCache := zcache.New[types.RegistrationID, types.RegisterNode]( registerCacheExpiration, registerCacheCleanup, ) @@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet) + router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet) if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index b4923ccb..491594c3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strings" "time" "github.com/juanfont/headscale/hscontrol/db" @@ -20,16 +22,18 @@ import ( type AuthProvider interface { RegisterHandler(http.ResponseWriter, *http.Request) - AuthURL(key.MachinePublic) string + AuthURL(types.RegistrationID) string } func logAuthFunc( registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, + registrationId types.RegistrationID, ) (func(string), func(string), func(error, string)) { return func(msg string) { log.Info(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -41,6 +45,7 @@ func logAuthFunc( func(msg string) { log.Trace(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -52,6 +57,7 @@ func logAuthFunc( func(err error, msg string) { log.Error(). Caller(). + Str("registration_id", registrationId.String()). Str("machine_key", machineKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()). @@ -63,6 +69,40 @@ func logAuthFunc( } } +func (h *Headscale) waitForFollowup( + req *http.Request, + regReq tailcfg.RegisterRequest, + logTrace func(string), +) { + logTrace("register request is a followup") + fu, err := url.Parse(regReq.Followup) + if err != nil { + logTrace("failed to parse followup URL") + return + } + + followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + if err != nil { + logTrace("followup URL does not contains a valid registration ID") + return + } + + logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg)) + + if reg, ok := h.registrationCache.Get(followupReg); ok { + logTrace("Node is waiting for interactive login") + + select { + case <-req.Context().Done(): + logTrace("node went away before it was registered") + return + case <-reg.Registered: + logTrace("node has successfully registered") + return + } + } +} + // handleRegister is the logic for registering a client. func (h *Headscale) handleRegister( writer http.ResponseWriter, @@ -70,7 +110,18 @@ func (h *Headscale) handleRegister( regReq tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { - logInfo, logTrace, _ := logAuthFunc(regReq, machineKey) + registrationId, err := types.NewRegistrationID() + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to generate registration ID") + http.Error(writer, "Internal server error", http.StatusInternalServerError) + + return + } + + logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId) now := time.Now().UTC() logTrace("handleRegister called, looking up machine in DB") node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey) @@ -84,27 +135,9 @@ func (h *Headscale) handleRegister( } // Check if the node is waiting for interactive login. - // - // TODO(juan): We could use this field to improve our protocol implementation, - // and hold the request until the client closes it, or the interactive - // login is completed (i.e., the user registers the node). - // This is not implemented yet, as it is no strictly required. The only side-effect - // is that the client will hammer headscale with requests until it gets a - // successful RegisterResponse. if regReq.Followup != "" { - logTrace("register request is a followup") - if _, ok := h.registrationCache.Get(machineKey.String()); ok { - logTrace("Node is waiting for interactive login") - - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - h.handleNewNode(writer, regReq, machineKey) - - return - } - } + h.waitForFollowup(req, regReq, logTrace) + return } logInfo("Node not found in database, creating new") @@ -113,25 +146,28 @@ func (h *Headscale) handleRegister( // that we rely on a method that calls back some how (OpenID or CLI) // We create the node and then keep it around until a callback // happens - newNode := types.Node{ - MachineKey: machineKey, - Hostname: regReq.Hostinfo.Hostname, - NodeKey: regReq.NodeKey, - LastSeen: &now, - Expiry: &time.Time{}, + newNode := types.RegisterNode{ + Node: types.Node{ + MachineKey: machineKey, + Hostname: regReq.Hostinfo.Hostname, + NodeKey: regReq.NodeKey, + LastSeen: &now, + Expiry: &time.Time{}, + }, + Registered: make(chan struct{}), } if !regReq.Expiry.IsZero() { logTrace("Non-zero expiry time requested") - newNode.Expiry = ®Req.Expiry + newNode.Node.Expiry = ®Req.Expiry } h.registrationCache.Set( - machineKey.String(), + registrationId, newNode, ) - h.handleNewNode(writer, regReq, machineKey) + h.handleNewNode(writer, regReq, registrationId) return } @@ -206,15 +242,12 @@ func (h *Headscale) handleRegister( } if regReq.Followup != "" { - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - } + h.waitForFollowup(req, regReq, logTrace) + return } // The node has expired or it is logged out - h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey) + h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId) // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use node.Expiry = &time.Time{} @@ -444,16 +477,16 @@ func (h *Headscale) handleAuthKey( func (h *Headscale) handleNewNode( writer http.ResponseWriter, registerRequest tailcfg.RegisterRequest, - machineKey key.MachinePublic, + registrationId types.RegistrationID, ) { - logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) + logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId) resp := tailcfg.RegisterResponse{} // The node registration is new, redirect the client to the registration URL - logTrace("The node seems to be new, sending auth url") + logTrace("The node is new, sending auth url") - resp.AuthURL = h.authProvider.AuthURL(machineKey) + resp.AuthURL = h.authProvider.AuthURL(registrationId) respBody, err := json.Marshal(resp) if err != nil { @@ -660,6 +693,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( regReq tailcfg.RegisterRequest, node types.Node, machineKey key.MachinePublic, + registrationId types.RegistrationID, ) { resp := tailcfg.RegisterResponse{} @@ -673,12 +707,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( log.Trace(). Caller(). Str("node", node.Hostname). - Str("machine_key", machineKey.ShortString()). + Str("registration_id", registrationId.String()). Str("node_key", regReq.NodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()). Msg("Node registration has expired or logged out. Sending a auth url to register") - resp.AuthURL = h.authProvider.AuthURL(machineKey) + resp.AuthURL = h.authProvider.AuthURL(registrationId) respBody, err := json.Marshal(resp) if err != nil { @@ -703,7 +737,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( log.Trace(). Caller(). - Str("machine_key", machineKey.ShortString()). + Str("registration_id", registrationId.String()). Str("node_key", regReq.NodeKey.ShortString()). Str("node_key_old", regReq.OldNodeKey.ShortString()). Str("node", node.Hostname). diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 0d9120c2..41d77051 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -41,7 +41,7 @@ type KV struct { type HSDatabase struct { DB *gorm.DB cfg *types.DatabaseConfig - regCache *zcache.Cache[string, types.Node] + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] baseDomain string } @@ -51,7 +51,7 @@ type HSDatabase struct { func NewHeadscaleDatabase( cfg types.DatabaseConfig, baseDomain string, - regCache *zcache.Cache[string, types.Node], + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], ) (*HSDatabase, error) { dbConn, err := openDB(cfg) if err != nil { diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index c3d9a835..213f05d7 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -256,8 +256,8 @@ func testCopyOfDatabase(src string) (string, error) { return dst, err } -func emptyCache() *zcache.Cache[string, types.Node] { - return zcache.New[string, types.Node](time.Minute, time.Hour) +func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { + return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) } // requireConstraintFailed checks if the error is a constraint failure with diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ce9c90e9..b718b9c6 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -320,7 +320,7 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { } func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( - mkey key.MachinePublic, + registrationID types.RegistrationID, userID types.UserID, nodeExpiry *time.Time, registrationMethod string, @@ -328,7 +328,7 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( ipv6 *netip.Addr, ) (*types.Node, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) { - if node, ok := hsdb.regCache.Get(mkey.String()); ok { + if reg, ok := hsdb.regCache.Get(registrationID); ok { user, err := GetUserByID(tx, userID) if err != nil { return nil, fmt.Errorf( @@ -338,36 +338,40 @@ func (hsdb *HSDatabase) RegisterNodeFromAuthCallback( } log.Debug(). - Str("machine_key", mkey.ShortString()). + Str("registration_id", registrationID.String()). Str("username", user.Username()). Str("registrationMethod", registrationMethod). Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)). Msg("Registering node from API/CLI or auth callback") + // TODO(kradalby): This looks quite wrong? why ID 0? + // Why not always? // Registration of expired node with different user - if node.ID != 0 && - node.UserID != user.ID { + if reg.Node.ID != 0 && + reg.Node.UserID != user.ID { return nil, ErrDifferentRegisteredUser } - node.UserID = user.ID - node.User = *user - node.RegisterMethod = registrationMethod + reg.Node.UserID = user.ID + reg.Node.User = *user + reg.Node.RegisterMethod = registrationMethod if nodeExpiry != nil { - node.Expiry = nodeExpiry + reg.Node.Expiry = nodeExpiry } node, err := RegisterNode( tx, - node, + reg.Node, ipv4, ipv6, ) if err == nil { - hsdb.regCache.Delete(mkey.String()) + hsdb.regCache.Delete(registrationID) } + // Signal to waiting clients that the machine has been registered. + close(reg.Registered) return node, err } diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index b7c7e50e..1d4ecbb1 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -227,11 +227,10 @@ func (api headscaleV1APIServer) RegisterNode( ) (*v1.RegisterNodeResponse, error) { log.Trace(). Str("user", request.GetUser()). - Str("machine_key", request.GetKey()). + Str("registration_id", request.GetKey()). Msg("Registering node") - var mkey key.MachinePublic - err := mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } @@ -247,7 +246,7 @@ func (api headscaleV1APIServer) RegisterNode( } node, err := api.h.db.RegisterNodeFromAuthCallback( - mkey, + registrationId, types.UserID(user.ID), nil, util.RegisterMethodCLI, @@ -839,36 +838,37 @@ func (api headscaleV1APIServer) DebugCreateNode( Hostname: "DebugTestNode", } - var mkey key.MachinePublic - err = mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } nodeKey := key.NewNode() - newNode := types.Node{ - MachineKey: mkey, - NodeKey: nodeKey.Public(), - Hostname: request.GetName(), - User: *user, + newNode := types.RegisterNode{ + Node: types.Node{ + NodeKey: nodeKey.Public(), + Hostname: request.GetName(), + User: *user, - Expiry: &time.Time{}, - LastSeen: &time.Time{}, + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - Hostinfo: &hostinfo, + Hostinfo: &hostinfo, + }, + Registered: make(chan struct{}), } log.Debug(). - Str("machine_key", mkey.ShortString()). + Str("registration_id", registrationId.String()). Msg("adding debug machine via CLI, appending to registration cache") api.h.registrationCache.Set( - mkey.String(), + registrationId, newNode, ) - return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil + return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil } func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 3858df93..3b7829a9 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -8,13 +8,13 @@ import ( "net/http" "strconv" "strings" - "time" "github.com/chasefleming/elem-go" "github.com/chasefleming/elem-go/attrs" "github.com/chasefleming/elem-go/styles" "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/templates" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -32,8 +32,6 @@ const ( // See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go NoiseCapabilityVersion = 39 - // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. - registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 ) @@ -239,11 +237,11 @@ func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { } } -func (a *AuthProviderWeb) AuthURL(mKey key.MachinePublic) string { +func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - mKey.String()) + registrationId.String()) } // RegisterWebAPI shows a simple message in the browser to point to the CLI diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 4470ba41..73a64b91 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -21,7 +21,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "tailscale.com/types/key" "zgo.at/zcache/v2" ) @@ -49,8 +48,8 @@ var ( // RegistrationInfo contains both machine key and verifier information for OIDC validation. type RegistrationInfo struct { - MachineKey key.MachinePublic - Verifier *string + RegistrationID types.RegistrationID + Verifier *string } type AuthProviderOIDC struct { @@ -112,11 +111,11 @@ func NewAuthProviderOIDC( }, nil } -func (a *AuthProviderOIDC) AuthURL(mKey key.MachinePublic) string { +func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { return fmt.Sprintf( "%s/register/%s", strings.TrimSuffix(a.serverURL, "/"), - mKey.String()) + registrationID.String()) } func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { @@ -129,32 +128,29 @@ func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time // RegisterOIDC redirects to the OIDC provider for authentication // Puts NodeKey in cache so the callback can retrieve it using the oidc state param -// Listens in /register/:mKey. +// Listens in /register/:registration_id. func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - - log.Debug(). - Caller(). - Str("machine_key", machineKeyStr). - Bool("ok", ok). - Msg("Received oidc register call") + registrationIdStr, ok := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var machineKey key.MachinePublic - err := machineKey.UnmarshalText( - []byte(machineKeyStr), - ) + registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - http.Error(writer, err.Error(), http.StatusBadRequest) + http.Error(writer, "invalid registration ID", http.StatusBadRequest) return } + log.Debug(). + Caller(). + Str("registration_id", registrationId.String()). + Bool("ok", ok). + Msg("Received oidc register call") + // Set the state and nonce cookies to protect against CSRF attacks state, err := setCSRFCookie(writer, req, "state") if err != nil { @@ -171,7 +167,7 @@ func (a *AuthProviderOIDC) RegisterHandler( // Initialize registration info with machine key registrationInfo := RegistrationInfo{ - MachineKey: machineKey, + RegistrationID: registrationId, } extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) @@ -326,8 +322,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } // Register the node if it does not exist. - if mKey != nil { - if err := a.registerNode(user, mKey, nodeExpiry); err != nil { + if registrationId != nil { + if err := a.registerNode(user, *registrationId, nodeExpiry); err != nil { http.Error(writer, err.Error(), http.StatusInternalServerError) return } @@ -456,22 +452,14 @@ func validateOIDCAllowedUsers( return nil } -// getMachineKeyFromState retrieves the machine key from the state -// cache. If the machine key is found, it will try retrieve the -// node information from the database. -func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) { +// getRegistrationIDFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { regInfo, ok := a.registrationCache.Get(state) if !ok { - return nil, nil + return nil } - // retrieve node information if it exist - // The error is not important, because if it does not - // exist, then this is a new node and we will move - // on to registration. - node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey) - - return node, ®Info.MachineKey + return ®Info.RegistrationID } // reauthenticateNode updates the node expiry in the database @@ -558,7 +546,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( func (a *AuthProviderOIDC) registerNode( user *types.User, - machineKey *key.MachinePublic, + registrationID types.RegistrationID, expiry time.Time, ) error { ipv4, ipv6, err := a.ipAlloc.Next() @@ -567,7 +555,7 @@ func (a *AuthProviderOIDC) registerNode( } if _, err := a.db.RegisterNodeFromAuthCallback( - *machineKey, + registrationID, types.UserID(user.ID), &expiry, util.RegisterMethodOIDC,