diff --git a/oidc.go b/oidc.go index 60d531e5..f0af600e 100644 --- a/oidc.go +++ b/oidc.go @@ -148,12 +148,12 @@ func (h *Headscale) OIDCCallback( return } - rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state) + rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state) if err != nil { return } - idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) + idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken) if err != nil { return } @@ -240,10 +240,11 @@ func validateOIDCCallbackParams( } func (h *Headscale) getIDTokenForOIDCCallback( + ctx context.Context, writer http.ResponseWriter, code, state string, ) (string, error) { - oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) + oauth2Token, err := h.oauth2Config.Exchange(ctx, code) if err != nil { log.Error(). Err(err). @@ -287,11 +288,12 @@ func (h *Headscale) getIDTokenForOIDCCallback( } func (h *Headscale) verifyIDTokenForOIDCCallback( + ctx context.Context, writer http.ResponseWriter, rawIDToken string, ) (*oidc.IDToken, error) { verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) - idToken, err := verifier.Verify(context.Background(), rawIDToken) + idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { log.Error(). Err(err).