diff --git a/api.go b/api.go index 9586a234..98676b7e 100644 --- a/api.go +++ b/api.go @@ -546,11 +546,11 @@ func (h *Headscale) handleMachineRegistrationNew( resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.String(), + NodePublicKeyStripPrefix(registerRequest.NodeKey), ) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) + strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) } respBody, err := encode(resp, &machineKey, h.privateKey) diff --git a/app.go b/app.go index 8de40634..4ff641ca 100644 --- a/app.go +++ b/app.go @@ -415,7 +415,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { router.GET("/register", h.RegisterWebAPI) router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id", h.RegistrationHandler) - router.GET("/oidc/register/:mkey", h.RegisterOIDC) + router.GET("/oidc/register/:nkey", h.RegisterOIDC) router.GET("/oidc/callback", h.OIDCCallback) router.GET("/apple", h.AppleConfigMessage) router.GET("/apple/:platform", h.ApplePlatformConfig) diff --git a/machine.go b/machine.go index c2276459..2673efdf 100644 --- a/machine.go +++ b/machine.go @@ -349,7 +349,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { return &m, nil } -// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. +// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (h *Headscale) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*Machine, error) { @@ -361,6 +361,19 @@ func (h *Headscale) GetMachineByMachineKey( return &m, nil } +// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct. +func (h *Headscale) GetMachineByNodeKeys( + nodeKey key.NodePublic, oldNodeKey key.NodePublic, +) (*Machine, error) { + machine := Machine{} + if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?", + NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil { + return nil, result.Error + } + + return &machine, nil +} + // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { @@ -567,11 +580,14 @@ func (machine Machine) toNode( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse machine public key: %w", err) + if machine.MachineKey != "" { + // MachineKey is only used in the legacy protocol + err = machineKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse machine public key: %w", err) + } } var discoKey key.DiscoPublic @@ -750,11 +766,11 @@ func getTags( } func (h *Headscale) RegisterMachineFromAuthCallback( - machineKeyStr string, + nodeKeyStr string, namespaceName string, registrationMethod string, ) (*Machine, error) { - if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { + if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok { if registrationMachine, ok := machineInterface.(Machine); ok { namespace, err := h.GetNamespace(namespaceName) if err != nil { @@ -785,7 +801,7 @@ func (h *Headscale) RegisterMachine(machine Machine, ) (*Machine, error) { log.Trace(). Caller(). - Str("machine_key", machine.MachineKey). + Str("node_key", machine.NodeKey). Msg("Registering machine") log.Trace(). diff --git a/machine_test.go b/machine_test.go index 48ccb153..98184a64 100644 --- a/machine_test.go +++ b/machine_test.go @@ -11,6 +11,7 @@ import ( "gopkg.in/check.v1" "inet.af/netaddr" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) func (s *Suite) TestGetMachine(c *check.C) { @@ -65,6 +66,35 @@ func (s *Suite) TestGetMachineByID(c *check.C) { c.Assert(err, check.IsNil) } +func (s *Suite) TestGetMachineByNodeKeys(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachineByID(0) + c.Assert(err, check.NotNil) + + nodeKey := key.NewNode() + oldNodeKey := key.NewNode() + + machine := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()), + DiscoKey: "faa", + Hostname: "testmachine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + } + app.db.Save(&machine) + + _, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public()) + c.Assert(err, check.IsNil) +} + func (s *Suite) TestDeleteMachine(c *check.C) { namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) diff --git a/oidc.go b/oidc.go index 38a1eb36..862966ad 100644 --- a/oidc.go +++ b/oidc.go @@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error { // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param -// Listens in /oidc/register/:mKey. +// Listens in /oidc/register/:nKey. func (h *Headscale) RegisterOIDC(ctx *gin.Context) { - machineKeyStr := ctx.Param("mkey") - if machineKeyStr == "" { + nodeKeyStr := ctx.Param("nkey") + if nodeKeyStr == "" { ctx.String(http.StatusBadRequest, "Wrong params") return @@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { log.Trace(). Caller(). - Str("machine_key", machineKeyStr). + Str("node_key", nodeKeyStr). Msg("Received oidc register call") randomBlob := make([]byte, randomByteSize) @@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) { stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) + h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -217,10 +217,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.registrationCache.Get(state) + // retrieve nodekey from state cache + nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) - if !machineKeyFound { + if !nodeKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") ctx.String(http.StatusBadRequest, "state has expired") @@ -228,22 +228,22 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - machineKeyFromCache, machineKeyOK := machineKeyIf.(string) + nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) - var machineKey key.MachinePublic - err = machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), + var nodeKey key.NodePublic + err = nodeKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(nodeKeyFromCache)), ) if err != nil { log.Error(). - Msg("could not parse machine public key") + Msg("could not parse node public key") ctx.String(http.StatusBadRequest, "could not parse public key") return } - if !machineKeyOK { - log.Error().Msg("could not get machine key from cache") + if !nodeKeyOK { + log.Error().Msg("could not get node key from cache") ctx.String( http.StatusInternalServerError, "could not get machine key from cache", @@ -256,7 +256,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { // The error is not important, because if it does not // exist, then this is a new machine and we will move // on to registration. - machine, _ := h.GetMachineByMachineKey(machineKey) + machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{}) if machine != nil { log.Trace(). @@ -335,10 +335,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + nodeKeyStr := NodePublicKeyStripPrefix(nodeKey) _, err = h.RegisterMachineFromAuthCallback( - machineKeyStr, + nodeKeyStr, namespace.Name, RegisterMethodOIDC, )