diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 73ee4686..f1eacf15 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -13,6 +13,7 @@ import ( "net/netip" "sort" "strconv" + "strings" "testing" "time" @@ -34,9 +35,13 @@ const ( dockerContextPath = "../." hsicOIDCMockHashLength = 6 defaultAccessTTL = 10 * time.Minute + nodeStateRunning = "Running" ) -var errStatusCodeNotOK = errors.New("status code not OK") +var ( + errStatusCodeNotOK = errors.New("status code not OK") + ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario") +) type AuthOIDCScenario struct { *Scenario @@ -617,12 +622,128 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { for _, client := range allClients { status, err := client.Status() assertNoErr(t, err) - if status.BackendState != "Running" { + if status.BackendState != nodeStateRunning { t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState) } } } +type tamperVerifierTransport struct { + base http.RoundTripper +} + +func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) { + log.Printf("RoundTrip: %s %s", req.Method, req.URL.String()) + + // For POST requests, tamper with form data + if req.Method == http.MethodPost { + log.Printf("Processing POST request") + err := req.ParseForm() + if err != nil { + log.Printf("Error parsing form: %v", err) + return nil, err + } + if verifier := req.Form.Get("code_challenge"); verifier != "" { + log.Printf("Found POST verifier: %s", verifier) + // Tamper with the verifier + req.Form.Set("code_challenge", verifier+"_tampered") + log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge")) + // Update request body with modified form + req.Body = io.NopCloser(strings.NewReader(req.Form.Encode())) + req.ContentLength = int64(len(req.Form.Encode())) + } else { + log.Printf("No code_challenge found in POST form data") + } + } + + // For GET requests, tamper with URL query parameters + if req.Method == http.MethodGet { + log.Printf("Processing GET request") + q := req.URL.Query() + if verifier := q.Get("code_challenge"); verifier != "" { + log.Printf("Found GET verifier: %s", verifier) + q.Set("code_challenge", verifier+"_tampered") + req.URL.RawQuery = q.Encode() + log.Printf("Modified URL to: %s", req.URL.String()) + } else { + log.Printf("No code_challenge found in GET query params") + } + } + + // Forward the request with the tampered verifier + resp, err := t.base.RoundTrip(req) + if err != nil { + log.Printf("RoundTrip error: %v", err) + + return nil, err + } + log.Printf("Response status: %s", resp.Status) + + return resp, err +} + +func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + // Single user with one node for testing PKCE flow + spec := map[string]int{ + "user1": 1, + } + + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + } + + // Create a transport that modifies the PKCE verifier in transit + baseTransport := &http.Transport{ + // #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + tamperTransport := &tamperVerifierTransport{ + base: baseTransport, + } + + err = scenario.CreateHeadscaleEnvWithHTTPModifier( + spec, + func(cli *http.Client) { + cli.Transport = tamperTransport + }, + hsic.WithTestName("oidcauthpkce"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + ) + if err == nil { + t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded") + } else { + log.Printf("auth got error: %s", err) + } +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option, @@ -643,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( // This is because the MockOIDC server can only serve login // requests based on a queue it has been given on startup. // We currently only populates it with one login request per user. - return fmt.Errorf("client count must be 1 for OIDC scenario.") + return ErrOIDCClientCount } log.Printf("creating user %s with %d clients", userName, clientCount) err = s.CreateUser(userName) @@ -665,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( return nil } +func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier( + users map[string]int, + httpModifier func(*http.Client), + opts ...hsic.Option, +) error { + headscale, err := s.Headscale(opts...) + if err != nil { + return err + } + + err = headscale.WaitForRunning() + if err != nil { + return err + } + + for userName, clientCount := range users { + if clientCount != 1 { + // OIDC scenario only supports one client per user. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populates it with one login request per user. + return ErrOIDCClientCount + } + log.Printf("creating user %s with %d clients", userName, clientCount) + err = s.CreateUser(userName) + if err != nil { + return err + } + + err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) + if err != nil { + return err + } + + err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier) + if err != nil { + return err + } + } + + return nil +} + func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { port, err := dockertestutil.RandomFreeHostPort() if err != nil { @@ -774,7 +938,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) } - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) + loginURL.Host = headscale.GetIP() + ":8080" loginURL.Scheme = "http" if len(headscale.GetCert()) > 0 { @@ -782,6 +946,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( } insecureTransport := &http.Transport{ + // #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint } @@ -848,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp( return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) } +func (s *AuthOIDCScenario) runTailscaleUpWithModifier( + userStr string, + loginServer string, + httpClientModifier func(*http.Client), +) error { + headscale, err := s.Headscale() + if err != nil { + return err + } + + log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { + for _, client := range user.Clients { + c := client + err := func() error { + status, err := c.Status() + if err != nil { + log.Printf("%s failed to get status: %s", c.Hostname(), err) + return err + } + + if status.BackendState == nodeStateRunning { + log.Printf("%s is already running", c.Hostname()) + return nil + } + + log.Printf("%s running tailscale up", c.Hostname()) + + loginURL, err := c.LoginWithURL(loginServer) + if err != nil { + log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) + return err + } + + loginURL.Host = headscale.GetIP() + ":8080" + loginURL.Scheme = "http" + + if len(headscale.GetCert()) > 0 { + loginURL.Scheme = "https" + } + + insecureTransport := &http.Transport{ + // #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + } + + log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) + + log.Printf("%s logging in with url", c.Hostname()) + httpClient := &http.Client{Transport: insecureTransport} + + // Allow the test to modify the HTTP client + if httpClientModifier != nil { + httpClientModifier(httpClient) + } + + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf( + "%s failed to login using url %s: %s", + c.Hostname(), + loginURL, + err, + ) + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + body, _ := io.ReadAll(resp.Body) + log.Printf("body: %s", body) + + return errStatusCodeNotOK + } + + return nil + }() + if err != nil { + return err + } + } + + return nil + } + + return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) +} + func (s *AuthOIDCScenario) Shutdown() { err := s.pool.Purge(s.mockOIDC) if err != nil {