diff --git a/CHANGELOG.md b/CHANGELOG.md index 8505e5a5..7b63a234 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) - Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) +- Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725) ## 0.16.0 (2022-07-25) diff --git a/api.go b/api.go index 21b85be4..561545b8 100644 --- a/api.go +++ b/api.go @@ -21,6 +21,8 @@ import ( ) const ( + // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. + registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 RegisterMethodAuthKey = "authkey" RegisterMethodOIDC = "oidc" @@ -107,13 +109,17 @@ var registerWebAPITemplate = template.Must( `)) // RegisterWebAPI shows a simple message in the browser to point to the CLI -// Listens in /register. +// Listens in /register/:nkey. +// +// This is not part of the Tailscale control API, as we could send whatever URL +// in the RegisterResponse.AuthURL field. func (h *Headscale) RegisterWebAPI( writer http.ResponseWriter, req *http.Request, ) { - machineKeyStr := req.URL.Query().Get("key") - if machineKeyStr == "" { + vars := mux.Vars(req) + nodeKeyStr, ok := vars["nkey"] + if !ok || nodeKeyStr == "" { writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("Wrong params")) @@ -129,7 +135,7 @@ func (h *Headscale) RegisterWebAPI( var content bytes.Buffer if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ - Key: machineKeyStr, + Key: nodeKeyStr, }); err != nil { log.Error(). Str("func", "RegisterWebAPI"). @@ -206,8 +212,6 @@ func (h *Headscale) RegistrationHandler( now := time.Now().UTC() machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { - log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine") - machineKeyStr := MachinePublicKeyStripPrefix(machineKey) // If the machine has AuthKey set, handle registration via PreAuthKeys @@ -217,6 +221,44 @@ func (h *Headscale) RegistrationHandler( return } + // Check if the node is waiting for interactive login. + // + // TODO(juan): We could use this field to improve our protocol implementation, + // and hold the request until the client closes it, or the interactive + // login is completed (i.e., the user registers the machine). + // This is not implemented yet, as it is no strictly required. The only side-effect + // is that the client will hammer headscale with requests until it gets a + // successful RegisterResponse. + if registerRequest.Followup != "" { + if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { + log.Debug(). + Caller(). + Str("machine", registerRequest.Hostinfo.Hostname). + Str("node_key", registerRequest.NodeKey.ShortString()). + Str("node_key_old", registerRequest.OldNodeKey.ShortString()). + Str("follow_up", registerRequest.Followup). + Msg("Machine is waiting for interactive login") + + ticker := time.NewTicker(registrationHoldoff) + select { + case <-req.Context().Done(): + return + case <-ticker.C: + h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) + + return + } + } + } + + log.Info(). + Caller(). + Str("machine", registerRequest.Hostinfo.Hostname). + Str("node_key", registerRequest.NodeKey.ShortString()). + Str("node_key_old", registerRequest.OldNodeKey.ShortString()). + Str("follow_up", registerRequest.Followup). + Msg("New machine not yet in the database") + givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). @@ -251,7 +293,7 @@ func (h *Headscale) RegistrationHandler( } h.registrationCache.Set( - machineKeyStr, + newMachine.NodeKey, newMachine, registerCacheExpiration, ) @@ -652,7 +694,7 @@ func (h *Headscale) handleMachineRegistrationNew( // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("machine", registerRequest.Hostinfo.Hostname). - Msg("The node is sending us a new NodeKey, sending auth url") + Msg("The node seems to be new, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", @@ -660,8 +702,8 @@ func (h *Headscale) handleMachineRegistrationNew( machineKey.String(), ) } else { - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) + resp.AuthURL = fmt.Sprintf("%s/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) } respBody, err := encode(resp, &machineKey, h.privateKey) diff --git a/app.go b/app.go index 84ca86c0..3e001203 100644 --- a/app.go +++ b/app.go @@ -417,21 +417,17 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) - router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler). - Methods(http.MethodPost) + router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet) + router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) - router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) + router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet) router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). - Methods(http.MethodGet) + router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet) router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig). - Methods(http.MethodGet) + router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) - router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1). - Methods(http.MethodGet) + router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) if h.cfg.DERP.ServerEnabled { router.HandleFunc("/derp", h.DERPHandler) diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index c2b1e950..a4f2a693 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Error getting machine key from flag: %s", err), + fmt.Sprintf("Error getting node key from flag: %s", err), output, ) diff --git a/grpcv1.go b/grpcv1.go index 452ac21d..e3db5dd4 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine( ) (*v1.RegisterMachineResponse, error) { log.Trace(). Str("namespace", request.GetNamespace()). - Str("machine_key", request.GetKey()). + Str("node_key", request.GetKey()). Msg("Registering machine") machine, err := api.h.RegisterMachineFromAuthCallback( diff --git a/machine.go b/machine.go index 22be0da1..aebfbcef 100644 --- a/machine.go +++ b/machine.go @@ -350,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { return &m, nil } -// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. +// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (h *Headscale) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*Machine, error) { @@ -362,6 +362,19 @@ func (h *Headscale) GetMachineByMachineKey( return &m, nil } +// GetMachineByNodeKey finds a Machine by its current NodeKey. +func (h *Headscale) GetMachineByNodeKey( + nodeKey key.NodePublic, +) (*Machine, error) { + machine := Machine{} + if result := h.db.Preload("Namespace").First(&machine, "node_key = ?", + NodePublicKeyStripPrefix(nodeKey)); result.Error != nil { + return nil, result.Error + } + + return &machine, nil +} + // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { @@ -762,11 +775,11 @@ func getTags( } func (h *Headscale) RegisterMachineFromAuthCallback( - machineKeyStr string, + nodeKeyStr string, namespaceName string, registrationMethod string, ) (*Machine, error) { - if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { + if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { namespace, err := h.GetNamespace(namespaceName) if err != nil { diff --git a/oidc.go b/oidc.go index 4c1bf5a7..63762716 100644 --- a/oidc.go +++ b/oidc.go @@ -27,7 +27,7 @@ const ( errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") - errOIDCMachineKeyMissing = Error("could not get machine key from cache") + errOIDCNodeKeyMissing = Error("could not get node key from cache") ) type IDTokenClaims struct { @@ -68,26 +68,26 @@ func (h *Headscale) initOIDC() error { } // RegisterOIDC redirects to the OIDC provider for authentication -// Puts machine key in cache so the callback can retrieve it using the oidc state param -// Listens in /oidc/register/:mKey. +// Puts NodeKey in cache so the callback can retrieve it using the oidc state param +// Listens in /oidc/register/:nKey. func (h *Headscale) RegisterOIDC( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - if !ok || machineKeyStr == "" { + nodeKeyStr, ok := vars["nkey"] + if !ok || nodeKeyStr == "" { log.Error(). Caller(). - Msg("Missing machine key in URL") - http.Error(writer, "Missing machine key in URL", http.StatusBadRequest) + Msg("Missing node key in URL") + http.Error(writer, "Missing node key in URL", http.StatusBadRequest) return } log.Trace(). Caller(). - Str("machine_key", machineKeyStr). + Str("node_key", nodeKeyStr). Msg("Received oidc register call") randomBlob := make([]byte, randomByteSize) @@ -102,8 +102,8 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] - // place the machine key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) + // place the node key into the state cache, so it can be retrieved later + h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -135,7 +135,7 @@ var oidcCallbackTemplate = template.Must( ) // OIDCCallback handles the callback from the OIDC endpoint -// Retrieves the mkey from the state cache and adds the machine to the users email namespace +// Retrieves the nkey from the state cache and adds the machine to the users email namespace // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: Add groups information from OIDC tokens into machine HostInfo // Listens in /oidc/callback. @@ -178,7 +178,7 @@ func (h *Headscale) OIDCCallback( return } - machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) + nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) if err != nil || machineExists { return } @@ -196,7 +196,7 @@ func (h *Headscale) OIDCCallback( return } - if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { + if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil { return } @@ -401,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback( writer http.ResponseWriter, state string, claims *IDTokenClaims, -) (*key.MachinePublic, bool, error) { +) (*key.NodePublic, bool, error) { // retrieve machinekey from state cache machineKeyIf, machineKeyFound := h.registrationCache.Get(state) if !machineKeyFound { @@ -420,14 +420,14 @@ func (h *Headscale) validateMachineForOIDCCallback( return nil, false, errOIDCInvalidMachineState } - var machineKey key.MachinePublic - machineKeyFromCache, machineKeyOK := machineKeyIf.(string) - err := machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), + var nodeKey key.NodePublic + nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) + err := nodeKey.UnmarshalText( + []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), ) if err != nil { log.Error(). - Msg("could not parse machine public key") + Msg("could not parse node public key") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, werr := writer.Write([]byte("could not parse public key")) @@ -441,11 +441,11 @@ func (h *Headscale) validateMachineForOIDCCallback( return nil, false, err } - if !machineKeyOK { - log.Error().Msg("could not get machine key from cache") + if !nodeKeyOK { + log.Error().Msg("could not get node key from cache") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not get machine key from cache")) + _, err := writer.Write([]byte("could not get node key from cache")) if err != nil { log.Error(). Caller(). @@ -453,14 +453,14 @@ func (h *Headscale) validateMachineForOIDCCallback( Msg("Failed to write response") } - return nil, false, errOIDCMachineKeyMissing + return nil, false, errOIDCNodeKeyMissing } // retrieve machine information if it exist // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByMachineKey(machineKey) + machine, _ := h.GetMachineByNodeKey(nodeKey) if machine != nil { log.Trace(). @@ -520,7 +520,7 @@ func (h *Headscale) validateMachineForOIDCCallback( return nil, true, nil } - return &machineKey, false, nil + return &nodeKey, false, nil } func getNamespaceName( @@ -600,12 +600,12 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( func (h *Headscale) registerMachineForOIDCCallback( writer http.ResponseWriter, namespace *Namespace, - machineKey *key.MachinePublic, + nodeKey *key.NodePublic, ) error { - machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) + nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey) if _, err := h.RegisterMachineFromAuthCallback( - machineKeyStr, + nodeKeyStr, namespace.Name, RegisterMethodOIDC, ); err != nil {