mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
TS2021: Use NodeKey for everything, as MachineKey is deprecated in TS2021
This commit is contained in:
parent
b40b4e8d45
commit
e8205e8d5a
5 changed files with 76 additions and 30 deletions
4
api.go
4
api.go
|
@ -546,11 +546,11 @@ func (h *Headscale) handleMachineRegistrationNew(
|
|||
resp.AuthURL = fmt.Sprintf(
|
||||
"%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"),
|
||||
machineKey.String(),
|
||||
NodePublicKeyStripPrefix(registerRequest.NodeKey),
|
||||
)
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
|
||||
}
|
||||
|
||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||
|
|
2
app.go
2
app.go
|
@ -415,7 +415,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
|
|||
router.GET("/register", h.RegisterWebAPI)
|
||||
router.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
router.POST("/machine/:id", h.RegistrationHandler)
|
||||
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
|
||||
router.GET("/oidc/register/:nkey", h.RegisterOIDC)
|
||||
router.GET("/oidc/callback", h.OIDCCallback)
|
||||
router.GET("/apple", h.AppleConfigMessage)
|
||||
router.GET("/apple/:platform", h.ApplePlatformConfig)
|
||||
|
|
34
machine.go
34
machine.go
|
@ -349,7 +349,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
|||
return &m, nil
|
||||
}
|
||||
|
||||
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
|
||||
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByMachineKey(
|
||||
machineKey key.MachinePublic,
|
||||
) (*Machine, error) {
|
||||
|
@ -361,6 +361,19 @@ func (h *Headscale) GetMachineByMachineKey(
|
|||
return &m, nil
|
||||
}
|
||||
|
||||
// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
|
||||
func (h *Headscale) GetMachineByNodeKeys(
|
||||
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
||||
) (*Machine, error) {
|
||||
machine := Machine{}
|
||||
if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?",
|
||||
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
|
||||
|
@ -567,11 +580,14 @@ func (machine Machine) toNode(
|
|||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err = machineKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||
if machine.MachineKey != "" {
|
||||
// MachineKey is only used in the legacy protocol
|
||||
err = machineKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
|
@ -750,11 +766,11 @@ func getTags(
|
|||
}
|
||||
|
||||
func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||
machineKeyStr string,
|
||||
nodeKeyStr string,
|
||||
namespaceName string,
|
||||
registrationMethod string,
|
||||
) (*Machine, error) {
|
||||
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
|
||||
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
|
||||
if registrationMachine, ok := machineInterface.(Machine); ok {
|
||||
namespace, err := h.GetNamespace(namespaceName)
|
||||
if err != nil {
|
||||
|
@ -785,7 +801,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
|
|||
) (*Machine, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Msg("Registering machine")
|
||||
|
||||
log.Trace().
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"gopkg.in/check.v1"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetMachine(c *check.C) {
|
||||
|
@ -65,6 +66,35 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
|
|||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
oldNodeKey := key.NewNode()
|
||||
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
NamespaceID: namespace.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.Save(&machine)
|
||||
|
||||
_, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
namespace, err := app.CreateNamespace("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
|
36
oidc.go
36
oidc.go
|
@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
|
|||
|
||||
// RegisterOIDC redirects to the OIDC provider for authentication
|
||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||
// Listens in /oidc/register/:mKey.
|
||||
// Listens in /oidc/register/:nKey.
|
||||
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||
machineKeyStr := ctx.Param("mkey")
|
||||
if machineKeyStr == "" {
|
||||
nodeKeyStr := ctx.Param("nkey")
|
||||
if nodeKeyStr == "" {
|
||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||
|
||||
return
|
||||
|
@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
|||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine_key", machineKeyStr).
|
||||
Str("node_key", nodeKeyStr).
|
||||
Msg("Received oidc register call")
|
||||
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
|
@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
|||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the machine key into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
|
||||
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
|
@ -217,10 +217,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// retrieve machinekey from state cache
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
// retrieve nodekey from state cache
|
||||
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
|
||||
|
||||
if !machineKeyFound {
|
||||
if !nodeKeyFound {
|
||||
log.Error().
|
||||
Msg("requested machine state key expired before authorisation completed")
|
||||
ctx.String(http.StatusBadRequest, "state has expired")
|
||||
|
@ -228,22 +228,22 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
|
||||
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
err = machineKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
|
||||
var nodeKey key.NodePublic
|
||||
err = nodeKey.UnmarshalText(
|
||||
[]byte(MachinePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msg("could not parse machine public key")
|
||||
Msg("could not parse node public key")
|
||||
ctx.String(http.StatusBadRequest, "could not parse public key")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !machineKeyOK {
|
||||
log.Error().Msg("could not get machine key from cache")
|
||||
if !nodeKeyOK {
|
||||
log.Error().Msg("could not get node key from cache")
|
||||
ctx.String(
|
||||
http.StatusInternalServerError,
|
||||
"could not get machine key from cache",
|
||||
|
@ -256,7 +256,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
// The error is not important, because if it does not
|
||||
// exist, then this is a new machine and we will move
|
||||
// on to registration.
|
||||
machine, _ := h.GetMachineByMachineKey(machineKey)
|
||||
machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
|
||||
|
||||
if machine != nil {
|
||||
log.Trace().
|
||||
|
@ -335,10 +335,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
|
||||
nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
|
||||
|
||||
_, err = h.RegisterMachineFromAuthCallback(
|
||||
machineKeyStr,
|
||||
nodeKeyStr,
|
||||
namespace.Name,
|
||||
RegisterMethodOIDC,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue