From fa641e38b8a62ad665e15370a2b29a48c6486060 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 8 Jan 2025 16:29:37 +0100 Subject: [PATCH] Set CSRF cookies for OIDC (#2328) * set state and nounce in oidc to prevent csrf Fixes #2276 * try to fix new postgres issue Signed-off-by: Kristoffer Dalby --------- Signed-off-by: Kristoffer Dalby --- .github/workflows/test.yml | 6 ++++ hscontrol/oidc.go | 61 +++++++++++++++++++++++++++++++---- integration/auth_oidc_test.go | 54 +++++++++++++++++++++++-------- 3 files changed, 100 insertions(+), 21 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f4659332..610c60f6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,4 +34,10 @@ jobs: - name: Run tests if: steps.changed-files.outputs.files == 'true' + env: + # As of 2025-01-06, these env vars was not automatically + # set anymore which breaks the initdb for postgres on + # some of the database migration tests. + LC_ALL: "en_US.UTF-8" + LC_CTYPE: "en_US.UTF-8" run: nix develop --command -- gotestsum diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 35e3c778..8f3003cb 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -3,9 +3,7 @@ package hscontrol import ( "bytes" "context" - "crypto/rand" _ "embed" - "encoding/hex" "errors" "fmt" "html/template" @@ -157,13 +155,19 @@ func (a *AuthProviderOIDC) RegisterHandler( return } - randomBlob := make([]byte, randomByteSize) - if _, err := rand.Read(randomBlob); err != nil { + // Set the state and nonce cookies to protect against CSRF attacks + state, err := setCSRFCookie(writer, req, "state") + if err != nil { http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - stateStr := hex.EncodeToString(randomBlob)[:32] + // Set the state and nonce cookies to protect against CSRF attacks + nonce, err := setCSRFCookie(writer, req, "nonce") + if err != nil { + http.Error(writer, "Internal server error", http.StatusInternalServerError) + return + } // Initialize registration info with machine key registrationInfo := RegistrationInfo{ @@ -191,11 +195,12 @@ func (a *AuthProviderOIDC) RegisterHandler( for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info - a.registrationCache.Set(stateStr, registrationInfo) + a.registrationCache.Set(state, registrationInfo) - authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...) + authURL := a.oauth2Config.AuthCodeURL(state, extras...) log.Debug().Msgf("Redirecting to %s for authentication", authURL) http.Redirect(writer, req, authURL, http.StatusFound) @@ -228,11 +233,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } + log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback") + cookieState, err := req.Cookie("state") + if err != nil { + http.Error(writer, "state not found", http.StatusBadRequest) + return + } + + if state != cookieState.Value { + http.Error(writer, "state did not match", http.StatusBadRequest) + return + } + idToken, err := a.extractIDToken(req.Context(), code, state) if err != nil { http.Error(writer, err.Error(), http.StatusBadRequest) return } + + nonce, err := req.Cookie("nonce") + if err != nil { + http.Error(writer, "nonce not found", http.StatusBadRequest) + return + } + if idToken.Nonce != nonce.Value { + http.Error(writer, "nonce did not match", http.StatusBadRequest) + return + } + nodeExpiry := a.determineNodeExpiry(idToken.Expiry) var claims types.OIDCClaims @@ -592,3 +620,22 @@ func getUserName( return userName, nil } + +func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) { + val, err := util.GenerateRandomStringURLSafe(64) + if err != nil { + return val, err + } + + c := &http.Cookie{ + Path: "/oidc/callback", + Name: name, + Value: val, + MaxAge: int(time.Hour.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + http.SetCookie(w, c) + + return val, nil +} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index e8b49991..e74eae56 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -10,6 +10,8 @@ import ( "log" "net" "net/http" + "net/http/cookiejar" + "net/http/httptest" "net/netip" "sort" "strconv" @@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc }, nil } +type LoggingRoundTripper struct{} + +func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + noTls := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + } + resp, err := noTls.RoundTrip(req) + if err != nil { + return nil, err + } + + log.Printf("---") + log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) + log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) + + return resp, nil +} + func (s *AuthOIDCScenario) runTailscaleUp( userStr, loginServer string, ) error { @@ -758,35 +778,39 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("running tailscale up for user %s", userStr) if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { - c := client + tsc := client user.joinWaitGroup.Go(func() error { - loginURL, err := c.LoginWithURL(loginServer) + loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { - log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) + log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) } - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) + loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname()) loginURL.Scheme = "http" if len(headscale.GetCert()) > 0 { loginURL.Scheme = "https" } - insecureTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + httptest.NewRecorder() + hc := &http.Client{ + Transport: LoggingRoundTripper{}, + } + hc.Jar, err = cookiejar.New(nil) + if err != nil { + log.Printf("failed to create cookie jar: %s", err) } - log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) + log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String()) - log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} + log.Printf("%s logging in with url", tsc.Hostname()) ctx := context.Background() req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) + resp, err := hc.Do(req) if err != nil { log.Printf( "%s failed to login using url %s: %s", - c.Hostname(), + tsc.Hostname(), loginURL, err, ) @@ -794,8 +818,10 @@ func (s *AuthOIDCScenario) runTailscaleUp( return err } + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status) body, _ := io.ReadAll(resp.Body) log.Printf("body: %s", body) @@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp( _, err = io.ReadAll(resp.Body) if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + log.Printf("%s failed to read response body: %s", tsc.Hostname(), err) return err } - log.Printf("Finished request for %s to join tailnet", c.Hostname()) + log.Printf("Finished request for %s to join tailnet", tsc.Hostname()) return nil })