mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-19 10:20:05 +09:00
Clean up logging and error handling in oidc
We should never expose errors via web, it gives attackers a lot of info (Insert OWASP guide). Also handle error that didnt separate not found gorm issue and other errors.
This commit is contained in:
parent
fac33e46e1
commit
fcd4d94927
1 changed files with 48 additions and 14 deletions
62
oidc.go
62
oidc.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
@ -15,6 +16,7 @@ import (
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -37,7 +39,10 @@ func (h *Headscale) initOIDC() error {
|
||||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Caller().
|
||||||
|
Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -69,8 +74,8 @@ func (h *Headscale) initOIDC() error {
|
||||||
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
// Puts machine key in cache so the callback can retrieve it using the oidc state param
|
||||||
// Listens in /oidc/register/:mKey.
|
// Listens in /oidc/register/:mKey.
|
||||||
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
mKeyStr := ctx.Param("mkey")
|
machineKeyStr := ctx.Param("mkey")
|
||||||
if mKeyStr == "" {
|
if machineKeyStr == "" {
|
||||||
ctx.String(http.StatusBadRequest, "Wrong params")
|
ctx.String(http.StatusBadRequest, "Wrong params")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -78,7 +83,9 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
|
|
||||||
randomBlob := make([]byte, randomByteSize)
|
randomBlob := make([]byte, randomByteSize)
|
||||||
if _, err := rand.Read(randomBlob); err != nil {
|
if _, err := rand.Read(randomBlob); err != nil {
|
||||||
log.Error().Msg("could not read 16 bytes from rand")
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Msg("could not read 16 bytes from rand")
|
||||||
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -87,7 +94,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the machine key into the state cache, so it can be retrieved later
|
// place the machine key into the state cache, so it can be retrieved later
|
||||||
h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
|
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration)
|
||||||
|
|
||||||
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||||
|
@ -130,7 +137,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
|
|
||||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Caller().
|
||||||
|
Msg("failed to verify id token")
|
||||||
|
ctx.String(http.StatusBadRequest, "Failed to verify id token")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -145,27 +156,31 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
// Extract custom claims
|
// Extract custom claims
|
||||||
var claims IDTokenClaims
|
var claims IDTokenClaims
|
||||||
if err = idToken.Claims(&claims); err != nil {
|
if err = idToken.Claims(&claims); err != nil {
|
||||||
|
log.Error().
|
||||||
|
Err(err).
|
||||||
|
Caller().
|
||||||
|
Msg("Failed to decode id token claims")
|
||||||
ctx.String(
|
ctx.String(
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
fmt.Sprintf("Failed to decode id token claims: %s", err),
|
fmt.Sprintf("Failed to decode id token claims"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve machinekey from state cache
|
// retrieve machinekey from state cache
|
||||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state)
|
||||||
|
|
||||||
if !mKeyFound {
|
if !machineKeyFound {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msg("requested machine state key expired before authorisation completed")
|
Msg("requested machine state key expired before authorisation completed")
|
||||||
ctx.String(http.StatusBadRequest, "state has expired")
|
ctx.String(http.StatusBadRequest, "state has expired")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mKeyStr, mKeyOK := mKeyIf.(string)
|
machineKey, machineKeyOK := machineKeyIf.(string)
|
||||||
|
|
||||||
if !mKeyOK {
|
if !machineKeyOK {
|
||||||
log.Error().Msg("could not get machine key from cache")
|
log.Error().Msg("could not get machine key from cache")
|
||||||
ctx.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
|
@ -176,7 +191,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve machine information
|
// retrieve machine information
|
||||||
machine, err := h.GetMachineByMachineKey(mKeyStr)
|
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("machine key not found in database")
|
log.Error().Msg("machine key not found in database")
|
||||||
ctx.String(
|
ctx.String(
|
||||||
|
@ -195,12 +210,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
log.Debug().Msg("Registering new machine after successful callback")
|
log.Debug().Msg("Registering new machine after successful callback")
|
||||||
|
|
||||||
namespace, err := h.GetNamespace(namespaceName)
|
namespace, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
namespace, err = h.CreateNamespace(namespaceName)
|
namespace, err = h.CreateNamespace(namespaceName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Msgf("could not create new namespace '%s'", claims.Email)
|
Err(err).
|
||||||
|
Caller().
|
||||||
|
Msgf("could not create new namespace '%s'", namespaceName)
|
||||||
ctx.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"could not create new namespace",
|
"could not create new namespace",
|
||||||
|
@ -208,10 +225,26 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Str("namespace", namespaceName).
|
||||||
|
Msg("could not find or create namespace")
|
||||||
|
ctx.String(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
"could not find or create namespace",
|
||||||
|
)
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := h.getAvailableIP()
|
ip, err := h.getAvailableIP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("could not get an IP from the pool")
|
||||||
ctx.String(
|
ctx.String(
|
||||||
http.StatusInternalServerError,
|
http.StatusInternalServerError,
|
||||||
"could not get an IP from the pool",
|
"could not get an IP from the pool",
|
||||||
|
@ -242,6 +275,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Error().
|
log.Error().
|
||||||
|
Caller().
|
||||||
Str("email", claims.Email).
|
Str("email", claims.Email).
|
||||||
Str("username", claims.Username).
|
Str("username", claims.Username).
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
|
|
Loading…
Reference in a new issue