Switching MachineKey for NodeKey wherever possible as Node identifier

This commit is contained in:
Juan Font Alonso 2022-03-29 16:54:31 +02:00
parent fc181333e5
commit 5082975289
7 changed files with 156 additions and 57 deletions

42
api.go
View file

@ -202,7 +202,7 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
} }
h.registrationCache.Set( h.registrationCache.Set(
machineKeyStr, NodePublicKeyStripPrefix(req.NodeKey),
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )
@ -477,6 +477,7 @@ func (h *Headscale) handleMachineValidRegistration(
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@ -503,10 +504,10 @@ func (h *Headscale) handleMachineExpired(
if h.cfg.OIDC.Issuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
} }
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
@ -521,6 +522,7 @@ func (h *Headscale) handleMachineExpired(
return return
} }
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
@ -570,13 +572,21 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(), NodePublicKeyStripPrefix(registerRequest.NodeKey),
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
} }
if machineKey.IsZero() {
// TS2021
ctx.JSON(http.StatusOK, resp)
return
}
// The Tailscale legacy protocol requires to encrypt the NaCl box with the MachineKey
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -590,21 +600,15 @@ func (h *Headscale) handleMachineRegistrationNew(
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey( func (h *Headscale) handleAuthKey(
ctx *gin.Context, ctx *gin.Context,
machineKey key.MachinePublic, machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
) { ) {
var machineKeyStr string machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
if machineKey.IsZero() {
// We are handling here a Noise auth key
machineKeyStr = ""
} else {
machineKeyStr = MachinePublicKeyStripPrefix(machineKey)
}
log.Debug(). log.Debug().
Caller(). Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
@ -651,7 +655,7 @@ func (h *Headscale) handleAuthKey(
} }
log.Debug(). log.Debug().
Caller(). Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname). Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
@ -707,14 +711,6 @@ func (h *Headscale) handleAuthKey(
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
// TS2021
if machineKey.IsZero() {
ctx.JSON(http.StatusOK, resp)
return
}
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().

2
app.go
View file

@ -484,7 +484,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router.GET("/register", h.RegisterWebAPI) router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/oidc/register/:nkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage) router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/apple/:platform", h.ApplePlatformConfig)

View file

@ -5,9 +5,10 @@ import (
"context" "context"
"time" "time"
"github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
@ -373,6 +374,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
MachineKey: request.GetKey(), MachineKey: request.GetKey(),
Name: request.GetName(), Name: request.GetName(),
Namespace: *namespace, Namespace: *namespace,
NodeKey: key.NewNode().Public().String(),
Expiry: &time.Time{}, Expiry: &time.Time{},
LastSeen: &time.Time{}, LastSeen: &time.Time{},
@ -382,7 +384,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
} }
api.h.registrationCache.Set( api.h.registrationCache.Set(
request.GetKey(), newMachine.NodeKey,
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )

View file

@ -658,11 +658,11 @@ func (machine *Machine) toProto() *v1.Machine {
} }
func (h *Headscale) RegisterMachineFromAuthCallback( func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string, nodeKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {

View file

@ -42,7 +42,7 @@ func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) {
// If the machine has AuthKey set, handle registration via PreAuthKeys // If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" { if req.Auth.AuthKey != "" {
h.handleAuthKey(ctx, key.MachinePublic{}, req) h.handleNoiseAuthKey(ctx, req)
return return
} }
@ -244,13 +244,13 @@ func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
Bool("readOnly", req.ReadOnly). Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers). Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream). Bool("stream", req.Stream).
Msg("Client map request processed") Msg("Noise client map request processed")
if req.ReadOnly { if req.ReadOnly {
log.Info(). log.Info().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Noise client is starting up. Probably interested in a DERP map")
// log.Info().Str("machine", machine.Name).Bytes("resp", data).Msg("Sending DERP map to client") // log.Info().Str("machine", machine.Name).Bytes("resp", data).Msg("Sending DERP map to client")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
@ -270,7 +270,7 @@ func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
Caller(). Caller().
Str("id", ctx.Param("id")). Str("id", ctx.Param("id")).
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Loading or creating update channel") Msg("Noise loading or creating update channel")
// TODO: could probably remove all that duplication once generics land. // TODO: could probably remove all that duplication once generics land.
closeChanWithLog := func(channel interface{}, name string) { closeChanWithLog := func(channel interface{}, name string) {
@ -303,7 +303,7 @@ func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
log.Info(). log.Info().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Noise client sent endpoint update and is ok with a response without peer list")
ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
@ -326,7 +326,7 @@ func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) {
log.Info(). log.Info().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Client is ready to access the tailnet") Msg("Noise client is ready to access the tailnet")
log.Info(). log.Info().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
@ -423,11 +423,12 @@ func (h *Headscale) handleNoiseNodeExpired(
// The client has registered before, but has expired // The client has registered before, but has expired
log.Debug(). log.Debug().
Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(ctx, key.MachinePublic{}, registerRequest) h.handleNoiseAuthKey(ctx, registerRequest)
return return
} }
@ -444,3 +445,106 @@ func (h *Headscale) handleNoiseNodeExpired(
Inc() Inc()
ctx.JSON(http.StatusOK, resp) ctx.JSON(http.StatusOK, resp)
} }
func (h *Headscale) handleNoiseAuthKey(
ctx *gin.Context,
registerRequest tailcfg.RegisterRequest,
) {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s over Noise", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
ctx.JSON(http.StatusUnauthorized, resp)
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey over Noise")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKeys(registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
h.RefreshMachine(machine, registerRequest.Expiry)
} else {
now := time.Now().UTC()
machineToRegister := Machine{
Name: registerRequest.Hostinfo.Hostname,
NamespaceID: pak.Namespace.ID,
MachineKey: "",
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
ctx.String(
http.StatusInternalServerError,
"could not register machine",
)
return
}
}
h.UsePreAuthKey(pak)
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
ctx.JSON(http.StatusOK, resp)
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey on Noise")
}

36
oidc.go
View file

@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
machineKeyStr := ctx.Param("mkey") nodeKeyStr := ctx.Param("nkey")
if machineKeyStr == "" { if nodeKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machineKeyStr). Str("node_key", nodeKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -114,7 +114,7 @@ var oidcCallbackTemplate = template.Must(
) )
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace // Retrieves the nkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
@ -188,32 +188,32 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
// retrieve machinekey from state cache // retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state) nodeKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested node state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
return return
} }
machineKeyFromCache, machineKeyOK := machineKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
var machineKey key.MachinePublic var nodeKey key.NodePublic
err = machineKey.UnmarshalText( err = nodeKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse node public key")
ctx.String(http.StatusBadRequest, "could not parse public key") ctx.String(http.StatusBadRequest, "could not parse public key")
return return
} }
if !machineKeyOK { if !nodeKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get node key from cache")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine key from cache", "could not get machine key from cache",
@ -226,7 +226,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey) machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
@ -305,10 +305,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyStr := MachinePublicKeyStripPrefix(machineKey) nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
_, err = h.RegisterMachineFromAuthCallback( _, err = h.RegisterMachineFromAuthCallback(
machineKeyStr, nodeKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
) )

View file

@ -388,10 +388,7 @@ func (h *Headscale) PollNetMapStream(
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalbCne(machine)
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").