mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-30 02:43:05 +00:00
OIDC code cleanup and harmonize with regular web auth
This commit is contained in:
parent
46df219ed3
commit
4034fbc6e9
1 changed files with 63 additions and 28 deletions
91
oidc.go
91
oidc.go
|
@ -76,20 +76,52 @@ func (h *Headscale) RegisterOIDC(
|
||||||
) {
|
) {
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
nodeKeyStr, ok := vars["nkey"]
|
nodeKeyStr, ok := vars["nkey"]
|
||||||
if !ok || nodeKeyStr == "" {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Msg("Missing node key in URL")
|
|
||||||
http.Error(writer, "Missing node key in URL", http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Str("node_key", nodeKeyStr).
|
Str("node_key", nodeKeyStr).
|
||||||
Msg("Received oidc register call")
|
Msg("Received oidc register call")
|
||||||
|
|
||||||
|
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||||
|
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, err := writer.Write([]byte("Unauthorized"))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to write response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||||
|
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||||
|
// the template and log an error.
|
||||||
|
var nodeKey key.NodePublic
|
||||||
|
err := nodeKey.UnmarshalText(
|
||||||
|
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if !ok || nodeKeyStr == "" || err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, err := writer.Write([]byte("Wrong params"))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to write response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
randomBlob := make([]byte, randomByteSize)
|
randomBlob := make([]byte, randomByteSize)
|
||||||
if _, err := rand.Read(randomBlob); err != nil {
|
if _, err := rand.Read(randomBlob); err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -103,7 +135,7 @@ func (h *Headscale) RegisterOIDC(
|
||||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the node key into the state cache, so it can be retrieved later
|
// place the node key into the state cache, so it can be retrieved later
|
||||||
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
|
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
|
||||||
|
|
||||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||||
|
@ -405,8 +437,8 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
claims *IDTokenClaims,
|
claims *IDTokenClaims,
|
||||||
) (*key.NodePublic, bool, error) {
|
) (*key.NodePublic, bool, error) {
|
||||||
// retrieve machinekey from state cache
|
// retrieve machinekey from state cache
|
||||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
|
||||||
if !machineKeyFound {
|
if !nodeKeyFound {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msg("requested machine state key expired before authorisation completed")
|
Msg("requested machine state key expired before authorisation completed")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
@ -423,16 +455,34 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
}
|
}
|
||||||
|
|
||||||
var nodeKey key.NodePublic
|
var nodeKey key.NodePublic
|
||||||
nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string)
|
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
|
||||||
|
if !nodeKeyOK {
|
||||||
|
log.Error().
|
||||||
|
Msg("requested machine state key is not a string")
|
||||||
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, err := writer.Write([]byte("state is invalid"))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to write response")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false, errOIDCInvalidMachineState
|
||||||
|
}
|
||||||
|
|
||||||
err := nodeKey.UnmarshalText(
|
err := nodeKey.UnmarshalText(
|
||||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Str("nodeKey", nodeKeyFromCache).
|
||||||
|
Bool("nodeKeyOK", nodeKeyOK).
|
||||||
Msg("could not parse node public key")
|
Msg("could not parse node public key")
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusBadRequest)
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
_, werr := writer.Write([]byte("could not parse public key"))
|
_, werr := writer.Write([]byte("could not parse node public key"))
|
||||||
if werr != nil {
|
if werr != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
|
@ -443,21 +493,6 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !nodeKeyOK {
|
|
||||||
log.Error().Msg("could not get node key from cache")
|
|
||||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
writer.WriteHeader(http.StatusInternalServerError)
|
|
||||||
_, err := writer.Write([]byte("could not get node key from cache"))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false, errOIDCNodeKeyMissing
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieve machine information if it exist
|
// retrieve machine information if it exist
|
||||||
// The error is not important, because if it does not
|
// The error is not important, because if it does not
|
||||||
// exist, then this is a new machine and we will move
|
// exist, then this is a new machine and we will move
|
||||||
|
|
Loading…
Reference in a new issue