Make sure nodes can reauthenticate

This commit fixes an issue where nodes were not able to reauthenticate.
This commit is contained in:
Kristoffer Dalby 2021-11-22 19:32:11 +00:00
parent 200c10e48c
commit e600ead3e9
3 changed files with 83 additions and 30 deletions

16
api.go
View file

@ -375,13 +375,13 @@ func (h *Headscale) handleMachineExpired(
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
Inc() Inc()
ctx.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
return return
} }
machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name). machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc() Inc()
ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
} }
@ -503,6 +503,14 @@ func (h *Headscale) handleAuthKey(
return return
} }
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(&machine, reqisterRequest.Expiry)
} else {
log.Debug(). log.Debug().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
@ -524,6 +532,7 @@ func (h *Headscale) handleAuthKey(
Str("ip", ip.String()). Str("ip", ip.String()).
Msgf("Assigning %s to %s", ip, machine.Name) Msgf("Assigning %s to %s", ip, machine.Name)
machine.Expiry = &reqisterRequest.Expiry
machine.AuthKeyID = uint(pak.ID) machine.AuthKeyID = uint(pak.ID)
machine.IPAddress = ip.String() machine.IPAddress = ip.String()
machine.NamespaceID = pak.NamespaceID machine.NamespaceID = pak.NamespaceID
@ -533,6 +542,7 @@ func (h *Headscale) handleAuthKey(
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodAuthKey machine.RegisterMethod = RegisterMethodAuthKey
h.db.Save(&machine) h.db.Save(&machine)
}
pak.Used = true pak.Used = true
h.db.Save(&pak) h.db.Save(&pak)
@ -558,6 +568,6 @@ func (h *Headscale) handleAuthKey(
log.Info(). log.Info().
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", machine.Name). Str("machine", machine.Name).
Str("ip", ip.String()). Str("ip", machine.IPAddress).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }

View file

@ -270,6 +270,15 @@ func (h *Headscale) ExpireMachine(machine *Machine) {
h.db.Save(machine) h.db.Save(machine)
} }
// RefreshMachine takes a Machine struct and sets the expire field to now.
func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
now := time.Now()
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
h.db.Save(machine)
}
// DeleteMachine softs deletes a Machine from the database. // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(machine *Machine) error { func (h *Headscale) DeleteMachine(machine *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(machine) err := h.RemoveSharedMachineFromAllNamespaces(machine)
@ -644,6 +653,7 @@ func (h *Headscale) RegisterMachine(
machine.NamespaceID = namespace.ID machine.NamespaceID = namespace.ID
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodCLI machine.RegisterMethod = RegisterMethodCLI
machine.Expiry = &requestedTime
h.db.Save(&machine) h.db.Save(&machine)
log.Trace(). log.Trace().

35
oidc.go
View file

@ -81,6 +81,11 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
return return
} }
log.Trace().
Caller().
Str("machine_key", machineKeyStr).
Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error(). log.Error().
@ -124,7 +129,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken) log.Trace().
Caller().
Str("code", code).
Str("state", state).
Msg("Got oidc callback")
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK { if !rawIDTokenOK {
@ -202,6 +211,29 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
if machine.isRegistered() {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, requestedTime)
ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
<p>
Reuthenticated as %s, you can now close this window.
</p>
</body>
</html>
`, claims.Email)))
return
}
now := time.Now().UTC() now := time.Now().UTC()
if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
@ -258,6 +290,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodOIDC machine.RegisterMethod = RegisterMethodOIDC
machine.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
machine.Expiry = &requestedTime
h.db.Save(&machine) h.db.Save(&machine)
} }