mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-29 18:33:05 +00:00
Compare commits
15 commits
bf615154fe
...
31f6a2de1b
Author | SHA1 | Date | |
---|---|---|---|
|
31f6a2de1b | ||
|
edf9e25001 | ||
|
c6336adb01 | ||
|
0f432854c6 | ||
|
0710e1aa4c | ||
|
1c767dfb0f | ||
|
d149180ff8 | ||
|
d3c3b76a0d | ||
|
14d678c0de | ||
|
eae63ec562 | ||
|
2581a2d8cf | ||
|
22bcbef778 | ||
|
78374b2d4d | ||
|
270e3a7a49 | ||
|
5fbf3f8327 |
41 changed files with 1891 additions and 365 deletions
2
.github/workflows/test-integration.yaml
vendored
2
.github/workflows/test-integration.yaml
vendored
|
@ -21,6 +21,7 @@ jobs:
|
||||||
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
||||||
- TestOIDCAuthenticationPingAll
|
- TestOIDCAuthenticationPingAll
|
||||||
- TestOIDCExpireNodesBasedOnTokenExpiry
|
- TestOIDCExpireNodesBasedOnTokenExpiry
|
||||||
|
- TestOIDC024UserCreation
|
||||||
- TestAuthWebFlowAuthenticationPingAll
|
- TestAuthWebFlowAuthenticationPingAll
|
||||||
- TestAuthWebFlowLogoutAndRelogin
|
- TestAuthWebFlowLogoutAndRelogin
|
||||||
- TestUserCommand
|
- TestUserCommand
|
||||||
|
@ -38,6 +39,7 @@ jobs:
|
||||||
- TestNodeMoveCommand
|
- TestNodeMoveCommand
|
||||||
- TestPolicyCommand
|
- TestPolicyCommand
|
||||||
- TestPolicyBrokenConfigCommand
|
- TestPolicyBrokenConfigCommand
|
||||||
|
- TestDERPVerifyEndpoint
|
||||||
- TestResolveMagicDNS
|
- TestResolveMagicDNS
|
||||||
- TestValidateResolvConf
|
- TestValidateResolvConf
|
||||||
- TestDERPServerScenario
|
- TestDERPServerScenario
|
||||||
|
|
79
CHANGELOG.md
79
CHANGELOG.md
|
@ -2,16 +2,82 @@
|
||||||
|
|
||||||
## Next
|
## Next
|
||||||
|
|
||||||
|
### Security fix: OIDC changes in Headscale 0.24.0
|
||||||
|
|
||||||
|
_Headscale v0.23.0 and earlier_ identified OIDC users by the "username" part of their email address (when `strip_email_domain: true`, the default) or whole email address (when `strip_email_domain: false`).
|
||||||
|
|
||||||
|
Depending on how Headscale and your Identity Provider (IdP) were configured, only using the `email` claim could allow a malicious user with an IdP account to take over another Headscale user's account, even when `strip_email_domain: false`.
|
||||||
|
|
||||||
|
This would also cause a user to lose access to their Headscale account if they changed their email address.
|
||||||
|
|
||||||
|
_Headscale v0.24.0_ now identifies OIDC users by the `iss` and `sub` claims. [These are guaranteed by the OIDC specification to be stable and unique](https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability), even if a user changes email address. A well-designed IdP will typically set `sub` to an opaque identifier like a UUID or numeric ID, which has no relation to the user's name or email address.
|
||||||
|
|
||||||
|
This issue _only_ affects Headscale installations which authenticate with OIDC.
|
||||||
|
|
||||||
|
Headscale v0.24.0 and later will also automatically update profile fields with OIDC data on login. This means that users can change those details in your IdP, and have it populate to Headscale automatically the next time they log in. However, this may affect the way you reference users in policies.
|
||||||
|
|
||||||
|
#### Migrating existing installations
|
||||||
|
|
||||||
|
Headscale v0.23.0 and earlier never recorded the `iss` and `sub` fields, so all legacy (existing) OIDC accounts from _need to be migrated_ to be properly secured.
|
||||||
|
|
||||||
|
Headscale v0.24.0 has an automatic migration feature, which is enabled by default (`map_legacy_users: true`). **This will be disabled by default in a future version of Headscale – any unmigrated users will get new accounts.**
|
||||||
|
|
||||||
|
Headscale v0.24.0 will ignore any `email` claim if the IdP does not provide an `email_verified` claim set to `true`. [What "verified" actually means is contextually dependent](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) – Headscale uses it as a signal that the contents of the `email` claim is reasonably trustworthy.
|
||||||
|
|
||||||
|
Headscale v0.23.0 and earlier never checked the `email_verified` claim. This means even if an IdP explicitly indicated to Headscale that its `email` claim was untrustworthy, Headscale would have still accepted it.
|
||||||
|
|
||||||
|
##### What does automatic migration do?
|
||||||
|
|
||||||
|
When automatic migration is enabled (`map_legacy_users: true`), Headscale will first match an OIDC account to a Headscale account by `iss` and `sub`, and then fall back to matching OIDC users similarly to how Headscale v0.23.0 did:
|
||||||
|
|
||||||
|
- If `strip_email_domain: true` (the default): the Headscale username matches the "username" part of their email address.
|
||||||
|
- If `strip_email_domain: false`: the Headscale username matches the _whole_ email address.
|
||||||
|
|
||||||
|
On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.**
|
||||||
|
|
||||||
|
Like with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login.
|
||||||
|
|
||||||
|
A _successful_ automated migration should otherwise be transparent to users.
|
||||||
|
|
||||||
|
Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username, but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account.
|
||||||
|
|
||||||
|
Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers.
|
||||||
|
|
||||||
|
Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources.
|
||||||
|
|
||||||
|
Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`.
|
||||||
|
|
||||||
|
##### What happens when automatic migration is disabled?
|
||||||
|
|
||||||
|
When automatic migration is disabled (`map_legacy_users: false`), Headscale will only try to match an OIDC account to a Headscale account by `iss` and `sub`.
|
||||||
|
|
||||||
|
If there is no match, it will get a _new_ Headscale account – even if there was a legacy account which _could_ have matched and migrated.
|
||||||
|
|
||||||
|
We recommend new Headscale users explicitly disable automatic migration – but it should otherwise have no effect if every account has a recorded `iss` and `sub`.
|
||||||
|
|
||||||
|
When automatic migration is disabled, the `strip_email_domain` setting will have no effect.
|
||||||
|
|
||||||
|
Special thanks to @micolous for reviewing, proposing and working with us on these changes.
|
||||||
|
|
||||||
|
#### Other OIDC changes
|
||||||
|
|
||||||
|
Headscale now uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to populate and update user information every time they log in:
|
||||||
|
|
||||||
|
| Headscale profile field | OIDC claim | Notes / examples |
|
||||||
|
| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------- |
|
||||||
|
| email address | `email` | Only used when `"email_verified": true` |
|
||||||
|
| display name | `name` | eg: `Sam Smith` |
|
||||||
|
| username | `preferred_username` | Varies depending on IdP and configuration, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` |
|
||||||
|
| profile picture | `picture` | URL to a profile picture or avatar |
|
||||||
|
|
||||||
|
These should show up nicely in the Tailscale client.
|
||||||
|
|
||||||
|
This will also affect the way you [reference users in policies](https://github.com/juanfont/headscale/pull/2205).
|
||||||
|
|
||||||
### BREAKING
|
### BREAKING
|
||||||
|
|
||||||
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||||
- Having usernames in magic DNS is no longer possible.
|
- Having usernames in magic DNS is no longer possible.
|
||||||
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
|
|
||||||
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
|
|
||||||
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
|
|
||||||
- User has been extended to store username, display name, profile picture url and email.
|
|
||||||
- These fields are forwarded to the client, and shows up nicely in the user switcher.
|
|
||||||
- These fields can be made available via the API/CLI for non-OIDC users in the future.
|
|
||||||
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
||||||
- Clean up old code required by old versions
|
- Clean up old code required by old versions
|
||||||
|
|
||||||
|
@ -23,6 +89,7 @@
|
||||||
- Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198)
|
- Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198)
|
||||||
- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199)
|
- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199)
|
||||||
- Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232)
|
- Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232)
|
||||||
|
- Loosened up `server_url` and `base_domain` check. It was overly strict in some cases.
|
||||||
|
|
||||||
## 0.23.0 (2024-09-18)
|
## 0.23.0 (2024-09-18)
|
||||||
|
|
||||||
|
|
19
Dockerfile.derper
Normal file
19
Dockerfile.derper
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
# For testing purposes only
|
||||||
|
|
||||||
|
FROM golang:alpine AS build-env
|
||||||
|
|
||||||
|
WORKDIR /go/src
|
||||||
|
|
||||||
|
RUN apk add --no-cache git
|
||||||
|
ARG VERSION_BRANCH=main
|
||||||
|
RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1
|
||||||
|
WORKDIR /go/src/tailscale
|
||||||
|
|
||||||
|
ARG TARGETARCH
|
||||||
|
RUN GOARCH=$TARGETARCH go install -v ./cmd/derper
|
||||||
|
|
||||||
|
FROM alpine:3.18
|
||||||
|
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
|
||||||
|
|
||||||
|
COPY --from=build-env /go/bin/* /usr/local/bin/
|
||||||
|
ENTRYPOINT [ "/usr/local/bin/derper" ]
|
|
@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH=""
|
||||||
ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
|
ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
|
||||||
ARG TARGETARCH
|
ARG TARGETARCH
|
||||||
|
|
||||||
RUN GOARCH=$TARGETARCH go install -ldflags="\
|
ARG BUILD_TAGS=""
|
||||||
|
|
||||||
|
RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
|
||||||
-X tailscale.com/version.longStamp=$VERSION_LONG \
|
-X tailscale.com/version.longStamp=$VERSION_LONG \
|
||||||
-X tailscale.com/version.shortStamp=$VERSION_SHORT \
|
-X tailscale.com/version.shortStamp=$VERSION_SHORT \
|
||||||
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \
|
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -64,6 +66,19 @@ func mockOIDC() error {
|
||||||
accessTTL = newTTL
|
accessTTL = newTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userStr := os.Getenv("MOCKOIDC_USERS")
|
||||||
|
if userStr == "" {
|
||||||
|
return fmt.Errorf("MOCKOIDC_USERS not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []mockoidc.MockUser
|
||||||
|
err := json.Unmarshal([]byte(userStr), &users)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unmarshalling users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Interface("users", users).Msg("loading users from JSON")
|
||||||
|
|
||||||
log.Info().Msgf("Access token TTL: %s", accessTTL)
|
log.Info().Msgf("Access token TTL: %s", accessTTL)
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
|
@ -71,7 +86,7 @@ func mockOIDC() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mock, err := getMockOIDC(clientID, clientSecret)
|
mock, err := getMockOIDC(clientID, clientSecret, users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -93,12 +108,18 @@ func mockOIDC() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) {
|
func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) {
|
||||||
keypair, err := mockoidc.NewKeypair(nil)
|
keypair, err := mockoidc.NewKeypair(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userQueue := mockoidc.UserQueue{}
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
userQueue.Push(&user)
|
||||||
|
}
|
||||||
|
|
||||||
mock := mockoidc.MockOIDC{
|
mock := mockoidc.MockOIDC{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
|
@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro
|
||||||
CodeChallengeMethodsSupported: []string{"plain", "S256"},
|
CodeChallengeMethodsSupported: []string{"plain", "S256"},
|
||||||
Keypair: keypair,
|
Keypair: keypair,
|
||||||
SessionStore: mockoidc.NewSessionStore(),
|
SessionStore: mockoidc.NewSessionStore(),
|
||||||
UserQueue: &mockoidc.UserQueue{},
|
UserQueue: &userQueue,
|
||||||
ErrorQueue: &mockoidc.ErrorQueue{},
|
ErrorQueue: &mockoidc.ErrorQueue{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock.AddMiddleware(func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Info().Msgf("Request: %+v", r)
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
if r.Response != nil {
|
||||||
|
log.Info().Msgf("Response: %+v", r.Response)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
return &mock, nil
|
return &mock, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -364,12 +364,18 @@ unix_socket_permission: "0770"
|
||||||
# allowed_users:
|
# allowed_users:
|
||||||
# - alice@example.com
|
# - alice@example.com
|
||||||
#
|
#
|
||||||
# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
|
# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
|
||||||
# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
|
# # by taking the username from the legacy user and matching it with the username
|
||||||
# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
|
# # provided by the OIDC. This is useful when migrating from legacy users to OIDC
|
||||||
# user: `first-name.last-name.example.com`
|
# # to force them using the unique identifier from the OIDC and to give them a
|
||||||
#
|
# # proper display name and picture if available.
|
||||||
# strip_email_domain: true
|
# # Note that this will only work if the username from the legacy user is the same
|
||||||
|
# # and ther is a posibility for account takeover should a username have changed
|
||||||
|
# # with the provider.
|
||||||
|
# # Disabling this feature will cause all new logins to be created as new users.
|
||||||
|
# # Note this option will be removed in the future and should be set to false
|
||||||
|
# # on all new installations, or when all users have logged in with OIDC once.
|
||||||
|
# map_legacy_users: true
|
||||||
|
|
||||||
# Logtail configuration
|
# Logtail configuration
|
||||||
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
|
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
|
|
||||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||||
# update this if you have a mismatch after doing a change to thos files.
|
# update this if you have a mismatch after doing a change to thos files.
|
||||||
vendorHash = "sha256-CMkYTRjmhvTTrB7JbLj0cj9VEyzpG0iUWXkaOagwYTk=";
|
vendorHash = "sha256-Qoqu2k4vvnbRFLmT/v8lI+HCEWqJsHFs8uZRfNmwQpo=";
|
||||||
|
|
||||||
subPackages = ["cmd/headscale"];
|
subPackages = ["cmd/headscale"];
|
||||||
|
|
||||||
|
|
|
@ -457,6 +457,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||||
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
|
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
|
||||||
Methods(http.MethodGet)
|
Methods(http.MethodGet)
|
||||||
|
|
||||||
|
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled {
|
if h.cfg.DERP.ServerEnabled {
|
||||||
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
|
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
|
||||||
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
|
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
|
||||||
|
|
|
@ -474,6 +474,8 @@ func NewHeadscaleDatabase(
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
// Pick up new user fields used for OIDC and to
|
||||||
|
// populate the user with more interesting information.
|
||||||
ID: "202407191627",
|
ID: "202407191627",
|
||||||
Migrate: func(tx *gorm.DB) error {
|
Migrate: func(tx *gorm.DB) error {
|
||||||
err := tx.AutoMigrate(&types.User{})
|
err := tx.AutoMigrate(&types.User{})
|
||||||
|
@ -485,6 +487,21 @@ func NewHeadscaleDatabase(
|
||||||
},
|
},
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
// The unique constraint of Name has been dropped
|
||||||
|
// in favour of a unique together of name and
|
||||||
|
// provider identity.
|
||||||
|
ID: "202408181235",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
err := tx.AutoMigrate(&types.User{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"zgo.at/zcache/v2"
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
@ -120,12 +122,12 @@ func TestMigrations(t *testing.T) {
|
||||||
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
||||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
kratest, err := ListPreAuthKeys(rx, "kratest")
|
kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
testkra, err := ListPreAuthKeys(rx, "testkra")
|
testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -256,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) {
|
||||||
func emptyCache() *zcache.Cache[string, types.Node] {
|
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||||
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConstraints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
run func(*testing.T, *gorm.DB)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-duplicate-username-if-no-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = CreateUser(db, "user1")
|
||||||
|
require.Error(t, err)
|
||||||
|
// assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username")
|
||||||
|
require.Contains(t, err.Error(), "user already exists")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-username-and-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1.1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-cli-then-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-oidc-then-cli",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier.String = "http://test.com/user1"
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db, err := newTestDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.run(t, db.DB)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
|
func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||||
return getNode(rx, user, name)
|
return getNode(rx, uid, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNode finds a Node by name and user and returns the Node struct.
|
// getNode finds a Node by name and user and returns the Node struct.
|
||||||
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
|
func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
||||||
nodes, err := ListNodesByUser(tx, user)
|
nodes, err := ListNodesByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,10 +29,10 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -50,7 +50,7 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
trx := db.DB.Save(node)
|
trx := db.DB.Save(node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -87,7 +87,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -135,7 +135,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(user.Name, "testnode3")
|
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -189,7 +189,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
for _, name := range []string{"test", "admin"} {
|
for _, name := range []string{"test", "admin"} {
|
||||||
user, err := db.CreateUser(name)
|
user, err := db.CreateUser(name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
stor = append(stor, base{user, pak})
|
stor = append(stor, base{user, pak})
|
||||||
}
|
}
|
||||||
|
@ -281,10 +281,10 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -302,7 +302,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
}
|
}
|
||||||
db.DB.Save(node)
|
db.DB.Save(node)
|
||||||
|
|
||||||
nodeFromDB, err := db.getNode("test", "testnode")
|
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(nodeFromDB, check.NotNil)
|
c.Assert(nodeFromDB, check.NotNil)
|
||||||
|
|
||||||
|
@ -312,7 +312,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
nodeFromDB, err = db.getNode("test", "testnode")
|
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||||
|
@ -322,10 +322,10 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -348,7 +348,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
sTags := []string{"tag:test", "tag:foo"}
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
err = db.SetTags(node.ID, sTags)
|
err = db.SetTags(node.ID, sTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
||||||
|
|
||||||
|
@ -356,7 +356,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
err = db.SetTags(node.ID, eTags)
|
err = db.SetTags(node.ID, eTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
node.ForcedTags,
|
node.ForcedTags,
|
||||||
|
@ -367,7 +367,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
// test removing tags
|
// test removing tags
|
||||||
err = db.SetTags(node.ID, []string{})
|
err = db.SetTags(node.ID, []string{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
||||||
}
|
}
|
||||||
|
@ -567,7 +567,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||||
user, err := adb.CreateUser("test")
|
user, err := adb.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -699,10 +699,10 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
|
|
@ -23,29 +23,27 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
// TODO(kradalby): Should be ID, not name
|
uid types.UserID,
|
||||||
userName string,
|
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||||
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
|
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func CreatePreAuthKey(
|
func CreatePreAuthKey(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
// TODO(kradalby): Should be ID, not name
|
uid types.UserID,
|
||||||
userName string,
|
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
user, err := GetUserByUsername(tx, userName)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -89,15 +87,15 @@ func CreatePreAuthKey(
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
return ListPreAuthKeys(rx, userName)
|
return ListPreAuthKeysByUser(rx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
||||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
||||||
user, err := GetUserByUsername(tx, userName)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,14 +11,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
// ID does not exist
|
||||||
|
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
c.Assert(len(key.Key), check.Equals, 48)
|
c.Assert(len(key.Key), check.Equals, 48)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
c.Assert(key.User.ID, check.Equals, user.ID)
|
||||||
|
|
||||||
_, err = db.ListPreAuthKeys("bogus")
|
// ID does not exist
|
||||||
|
_, err = db.ListPreAuthKeys(1000000)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := db.ListPreAuthKeys(user.Name)
|
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert((keys)[0].User.Name, check.Equals, user.Name)
|
c.Assert((keys)[0].User.ID, check.Equals, user.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
|
@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now().Add(-5 * time.Second)
|
now := time.Now().Add(-5 * time.Second)
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -62,7 +63,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -74,7 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test4")
|
user, err := db.CreateUser("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -96,7 +97,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test5")
|
user, err := db.CreateUser("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -118,7 +119,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -130,7 +131,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
|
@ -147,7 +148,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
db.DB.Save(&pak)
|
db.DB.Save(&pak)
|
||||||
|
@ -160,15 +161,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test8")
|
user, err := db.CreateUser("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
|
||||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
tags := []string{"tag:test1", "tag:test2"}
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
listedPaks, err := db.ListPreAuthKeys("test8")
|
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||||
sort.Sort(sort.StringSlice(gotTags))
|
sort.Sort(sort.StringSlice(gotTags))
|
||||||
|
|
|
@ -639,7 +639,7 @@ func EnableAutoApprovedRoutes(
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Str("user", node.User.Name).
|
Uint("user.id", node.User.ID).
|
||||||
Strs("routeApprovers", routeApprovers).
|
Strs("routeApprovers", routeApprovers).
|
||||||
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
||||||
Msg("looking up route for autoapproving")
|
Msg("looking up route for autoapproving")
|
||||||
|
|
|
@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_get_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_get_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
|
|
@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := types.User{}
|
user := types.User{
|
||||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
Name: name,
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
}
|
||||||
user.Name = name
|
|
||||||
if err := tx.Create(&user).Error; err != nil {
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
return nil, fmt.Errorf("creating user: %w", err)
|
return nil, fmt.Errorf("creating user: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -40,21 +38,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return DestroyUser(tx, name)
|
return DestroyUser(tx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are nodes associated with it.
|
// not exist or if there are nodes associated with it.
|
||||||
func DestroyUser(tx *gorm.DB, name string) error {
|
func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||||
user, err := GetUserByUsername(tx, name)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUserNotFound
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := ListNodesByUser(tx, name)
|
nodes, err := ListNodesByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -62,7 +60,7 @@ func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := ListPreAuthKeys(tx, name)
|
keys, err := ListPreAuthKeysByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -80,17 +78,17 @@ func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return RenameUser(tx, oldName, newName)
|
return RenameUser(tx, uid, newName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenameUser renames a User. Returns error if the User does
|
// RenameUser renames a User. Returns error if the User does
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := GetUserByUsername(tx, oldName)
|
oldUser, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,50 +96,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = GetUserByUsername(tx, newName)
|
|
||||||
if err == nil {
|
|
||||||
return ErrUserExists
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrUserNotFound) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUser.Name = newName
|
oldUser.Name = newName
|
||||||
|
|
||||||
if result := tx.Save(&oldUser); result.Error != nil {
|
if err := tx.Save(&oldUser).Error; err != nil {
|
||||||
return result.Error
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
return GetUserByUsername(rx, name)
|
return GetUserByID(rx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) {
|
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
if result := tx.First(&user, "id = ?", uid); errors.Is(
|
||||||
result.Error,
|
|
||||||
gorm.ErrRecordNotFound,
|
|
||||||
) {
|
|
||||||
return nil, ErrUserNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return &user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
|
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
|
||||||
return GetUserByID(rx, id)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
|
|
||||||
user := types.User{}
|
|
||||||
if result := tx.First(&user, "id = ?", id); errors.Is(
|
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -169,54 +142,69 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||||
return ListUsers(rx)
|
return ListUsers(rx, where...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||||
|
if len(where) > 1 {
|
||||||
|
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||||
|
}
|
||||||
|
|
||||||
|
var user *types.User
|
||||||
|
if len(where) == 1 {
|
||||||
|
user = where[0]
|
||||||
|
}
|
||||||
|
|
||||||
users := []types.User{}
|
users := []types.User{}
|
||||||
if err := tx.Find(&users).Error; err != nil {
|
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNodesByUser gets all the nodes in a given user.
|
// GetUserByName returns a user if the provided username is
|
||||||
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
// unique, and otherwise an error.
|
||||||
err := util.CheckForFQDNRules(name)
|
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||||
if err != nil {
|
users, err := hsdb.ListUsers(&types.User{Name: name})
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
user, err := GetUserByUsername(tx, name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != 1 {
|
||||||
|
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNodesByUser gets all the nodes in a given user.
|
||||||
|
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return AssignNodeToUser(tx, node, username)
|
return AssignNodeToUser(tx, node, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignNodeToUser assigns a Node to a user.
|
// AssignNodeToUser assigns a Node to a user.
|
||||||
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error {
|
||||||
err := util.CheckForFQDNRules(username)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
user, err := GetUserByUsername(tx, username)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test")
|
_, err = db.GetUserByID(types.UserID(user.ID))
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
err := db.DestroyUser("test")
|
err := db.DestroyUser(9998)
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
|
@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
user, err = db.CreateUser("test")
|
user, err = db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
trx := db.DB.Save(&node)
|
trx := db.DB.Save(&node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,24 +72,28 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.RenameUser("test", "test-renamed")
|
err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test")
|
users, err = db.ListUsers(&types.User{Name: "test"})
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
c.Assert(len(users), check.Equals, 0)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test-renamed")
|
users, err = db.ListUsers(&types.User{Name: "test-renamed"})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.RenameUser("test-does-not-exit", "test")
|
err = db.RenameUser(99988, "test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
userTest2, err := db.CreateUser("test2")
|
userTest2, err := db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
err = db.RenameUser("test2", "test-renamed")
|
err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed")
|
||||||
c.Assert(err, check.Equals, ErrUserExists)
|
if !strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||||
|
c.Fatalf("expected failure with unique constraint, got: %s", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
|
@ -97,7 +103,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
newUser, err := db.CreateUser("new")
|
newUser, err := db.CreateUser("new")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -111,15 +117,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, "non-existing-user")
|
err = db.AssignNodeToUser(&node, 9584849)
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
|
@ -65,24 +65,34 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameUserRequest,
|
request *v1.RenameUserRequest,
|
||||||
) (*v1.RenameUserResponse, error) {
|
) (*v1.RenameUserResponse, error) {
|
||||||
err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName())
|
oldUser, err := api.h.db.GetUserByName(request.GetOldName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.db.GetUserByName(request.GetNewName())
|
err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.RenameUserResponse{User: user.Proto()}, nil
|
newUser, err := api.h.db.GetUserByName(request.GetNewName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &v1.RenameUserResponse{User: newUser.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) DeleteUser(
|
func (api headscaleV1APIServer) DeleteUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteUserRequest,
|
request *v1.DeleteUserRequest,
|
||||||
) (*v1.DeleteUserResponse, error) {
|
) (*v1.DeleteUserResponse, error) {
|
||||||
err := api.h.db.DestroyUser(request.GetName())
|
user, err := api.h.db.GetUserByName(request.GetName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = api.h.db.DestroyUser(types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,8 +141,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
||||||
request.GetUser(),
|
types.UserID(user.ID),
|
||||||
request.GetReusable(),
|
request.GetReusable(),
|
||||||
request.GetEphemeral(),
|
request.GetEphemeral(),
|
||||||
&expiration,
|
&expiration,
|
||||||
|
@ -168,7 +183,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListPreAuthKeysRequest,
|
request *v1.ListPreAuthKeysRequest,
|
||||||
) (*v1.ListPreAuthKeysResponse, error) {
|
) (*v1.ListPreAuthKeysResponse, error) {
|
||||||
preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser())
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -406,10 +426,20 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListNodesRequest,
|
request *v1.ListNodesRequest,
|
||||||
) (*v1.ListNodesResponse, error) {
|
) (*v1.ListNodesResponse, error) {
|
||||||
|
// TODO(kradalby): it looks like this can be simplified a lot,
|
||||||
|
// the filtering of nodes by user, vs nodes as a whole can
|
||||||
|
// probably be done once.
|
||||||
|
// TODO(kradalby): This should be done in one tx.
|
||||||
|
|
||||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return db.ListNodesByUser(rx, request.GetUser())
|
return db.ListNodesByUser(rx, types.UserID(user.ID))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -465,12 +495,18 @@ func (api headscaleV1APIServer) MoveNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveNodeRequest,
|
request *v1.MoveNodeRequest,
|
||||||
) (*v1.MoveNodeResponse, error) {
|
) (*v1.MoveNodeResponse, error) {
|
||||||
|
// TODO(kradalby): This should be done in one tx.
|
||||||
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.db.AssignNodeToUser(node, request.GetUser())
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -56,6 +57,65 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
|
||||||
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) handleVerifyRequest(
|
||||||
|
req *http.Request,
|
||||||
|
) (bool, error) {
|
||||||
|
body, err := io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot read request body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
||||||
|
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
|
||||||
|
return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes, err := h.db.ListNodes()
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("cannot list nodes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, Derp use verifyClientsURL to verify whether a client is allowed to connect to the DERP server.
|
||||||
|
func (h *Headscale) VerifyHandler(
|
||||||
|
writer http.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
) {
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "/verify").
|
||||||
|
Msg("verify client")
|
||||||
|
|
||||||
|
allow, err := h.handleVerifyRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to verify client")
|
||||||
|
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := tailcfg.DERPAdmitClientResponse{
|
||||||
|
Allow: allow,
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
err = json.NewEncoder(writer).Encode(resp)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to write response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// KeyHandler provides the Headscale pub key
|
// KeyHandler provides the Headscale pub key
|
||||||
// Listens in /key.
|
// Listens in /key.
|
||||||
func (h *Headscale) KeyHandler(
|
func (h *Headscale) KeyHandler(
|
||||||
|
|
|
@ -436,25 +436,42 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||||
) (*types.User, error) {
|
) (*types.User, error) {
|
||||||
var user *types.User
|
var user *types.User
|
||||||
var err error
|
var err error
|
||||||
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
|
user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
|
||||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
||||||
// look it up by username. This should only be needed once.
|
// look it up by username. This should only be needed once.
|
||||||
if user == nil {
|
// This branch will presist for a number of versions after the OIDC migration and
|
||||||
user, err = a.db.GetUserByName(claims.Username)
|
// then be removed following a deprecation.
|
||||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
// TODO(kradalby): Remove when strip_email_domain and migration is removed
|
||||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
// after #2170 is cleaned up.
|
||||||
}
|
if a.cfg.MapLegacyUsers && user == nil {
|
||||||
|
log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username")
|
||||||
|
if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil {
|
||||||
|
log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username")
|
||||||
|
user, err = a.db.GetUserByName(oldUsername)
|
||||||
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
|
return nil, fmt.Errorf("getting user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// if the user is still not found, create a new empty user.
|
// If the user exists, but it already has a provider identifier (OIDC sub), create a new user.
|
||||||
if user == nil {
|
// This is to prevent users that have already been migrated to the new OIDC format
|
||||||
user = &types.User{}
|
// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
|
||||||
|
// account takeover.
|
||||||
|
if user != nil && user.ProviderIdentifier.Valid {
|
||||||
|
log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.")
|
||||||
|
user = &types.User{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if the user is still not found, create a new empty user.
|
||||||
|
if user == nil {
|
||||||
|
user = &types.User{}
|
||||||
|
}
|
||||||
|
|
||||||
user.FromClaim(claims)
|
user.FromClaim(claims)
|
||||||
err = a.db.DB.Save(user).Error
|
err = a.db.DB.Save(user).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -502,3 +519,24 @@ func renderOIDCCallbackTemplate(
|
||||||
|
|
||||||
return &content, nil
|
return &content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// DEPRECATED: DO NOT USE
|
||||||
|
func getUserName(
|
||||||
|
claims *types.OIDCClaims,
|
||||||
|
stripEmaildomain bool,
|
||||||
|
) (string, error) {
|
||||||
|
if !claims.EmailVerified {
|
||||||
|
return "", fmt.Errorf("email not verified")
|
||||||
|
}
|
||||||
|
userName, err := util.NormalizeToFQDNRules(
|
||||||
|
claims.Email,
|
||||||
|
stripEmaildomain,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return userName, nil
|
||||||
|
}
|
||||||
|
|
|
@ -178,7 +178,12 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||||
for srcIndex, src := range acl.Sources {
|
for srcIndex, src := range acl.Sources {
|
||||||
srcs, err := pol.expandSource(src, nodes)
|
srcs, err := pol.expandSource(src, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing policy, acl index: %d->%d: %w",
|
||||||
|
index,
|
||||||
|
srcIndex,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
srcIPs = append(srcIPs, srcs...)
|
srcIPs = append(srcIPs, srcs...)
|
||||||
}
|
}
|
||||||
|
@ -335,12 +340,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
case "check":
|
case "check":
|
||||||
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, parsing check duration, index: %d: %w",
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
action = *checkAction
|
action = *checkAction
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, unknown action %q, index: %d: %w",
|
||||||
|
sshACL.Action,
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
||||||
|
@ -977,10 +991,7 @@ func FilterNodesByACL(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname)
|
|
||||||
|
|
||||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||||
log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname)
|
|
||||||
result = append(result, peer)
|
result = append(result, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,8 +28,9 @@ const (
|
||||||
maxDuration time.Duration = 1<<63 - 1
|
maxDuration time.Duration = 1<<63 - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
var errOidcMutuallyExclusive = errors.New(
|
var (
|
||||||
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
|
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||||
|
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPAllocationStrategy string
|
type IPAllocationStrategy string
|
||||||
|
@ -162,8 +163,10 @@ type OIDCConfig struct {
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
|
StripEmaildomain bool
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
UseExpiryFromToken bool
|
UseExpiryFromToken bool
|
||||||
|
MapLegacyUsers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type DERPConfig struct {
|
type DERPConfig struct {
|
||||||
|
@ -272,9 +275,11 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
||||||
|
|
||||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||||
|
viper.SetDefault("oidc.strip_email_domain", true)
|
||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
viper.SetDefault("oidc.map_legacy_users", true)
|
||||||
|
|
||||||
viper.SetDefault("logtail.enabled", false)
|
viper.SetDefault("logtail.enabled", false)
|
||||||
viper.SetDefault("randomize_client_port", false)
|
viper.SetDefault("randomize_client_port", false)
|
||||||
|
@ -318,14 +323,18 @@ func validateServerConfig() error {
|
||||||
depr.warn("dns_config.use_username_in_magic_dns")
|
depr.warn("dns_config.use_username_in_magic_dns")
|
||||||
depr.warn("dns.use_username_in_magic_dns")
|
depr.warn("dns.use_username_in_magic_dns")
|
||||||
|
|
||||||
depr.fatal("oidc.strip_email_domain")
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// depr.fatal("oidc.strip_email_domain")
|
||||||
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
||||||
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
||||||
|
|
||||||
depr.Log()
|
depr.Log()
|
||||||
|
|
||||||
for _, removed := range []string{
|
for _, removed := range []string{
|
||||||
"oidc.strip_email_domain",
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// "oidc.strip_email_domain",
|
||||||
"dns_config.use_username_in_musername_in_magic_dns",
|
"dns_config.use_username_in_musername_in_magic_dns",
|
||||||
} {
|
} {
|
||||||
if viper.IsSet(removed) {
|
if viper.IsSet(removed) {
|
||||||
|
@ -827,11 +836,10 @@ func LoadServerConfig() (*Config, error) {
|
||||||
// - DERP run on their own domains
|
// - DERP run on their own domains
|
||||||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||||
if dnsConfig.BaseDomain != "" &&
|
if dnsConfig.BaseDomain != "" {
|
||||||
strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
|
||||||
return nil, errors.New(
|
return nil, err
|
||||||
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
}
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
|
@ -897,6 +905,10 @@ func LoadServerConfig() (*Config, error) {
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||||
|
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
|
||||||
},
|
},
|
||||||
|
|
||||||
LogTail: logTailConfig,
|
LogTail: logTailConfig,
|
||||||
|
@ -924,6 +936,37 @@ func LoadServerConfig() (*Config, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BaseDomain cannot be a suffix of the server URL.
|
||||||
|
// This is because Tailscale takes over the domain in BaseDomain,
|
||||||
|
// causing the headscale server and DERP to be unreachable.
|
||||||
|
// For Tailscale upstream, the following is true:
|
||||||
|
// - DERP run on their own domains.
|
||||||
|
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com.
|
||||||
|
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net).
|
||||||
|
func isSafeServerURL(serverURL, baseDomain string) error {
|
||||||
|
server, err := url.Parse(serverURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverDomainParts := strings.Split(server.Host, ".")
|
||||||
|
baseDomainParts := strings.Split(baseDomain, ".")
|
||||||
|
|
||||||
|
if len(serverDomainParts) <= len(baseDomainParts) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := len(serverDomainParts)
|
||||||
|
b := len(baseDomainParts)
|
||||||
|
for i := range len(baseDomainParts) {
|
||||||
|
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errServerURLSuffix
|
||||||
|
}
|
||||||
|
|
||||||
type deprecator struct {
|
type deprecator struct {
|
||||||
warns set.Set[string]
|
warns set.Set[string]
|
||||||
fatals set.Set[string]
|
fatals set.Set[string]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -139,7 +140,7 @@ func TestReadConfig(t *testing.T) {
|
||||||
return LoadServerConfig()
|
return LoadServerConfig()
|
||||||
},
|
},
|
||||||
want: nil,
|
want: nil,
|
||||||
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
wantErr: errServerURLSuffix.Error(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "base-domain-not-in-server-url",
|
name: "base-domain-not-in-server-url",
|
||||||
|
@ -333,3 +334,64 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||||
err = LoadConfig(tmpDir, false)
|
err = LoadConfig(tmpDir, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OK
|
||||||
|
// server_url: headscale.com, base: clients.headscale.com
|
||||||
|
// server_url: headscale.com, base: headscale.net
|
||||||
|
//
|
||||||
|
// NOT OK
|
||||||
|
// server_url: server.headscale.com, base: headscale.com.
|
||||||
|
func TestSafeServerURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
serverURL, baseDomain,
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
serverURL: "https://example.com",
|
||||||
|
baseDomain: "example.org",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://headscale.com",
|
||||||
|
baseDomain: "headscale.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://headscale.com",
|
||||||
|
baseDomain: "clients.headscale.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://headscale.com",
|
||||||
|
baseDomain: "clients.subdomain.headscale.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://headscale.kristoffer.com",
|
||||||
|
baseDomain: "mybase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://server.headscale.com",
|
||||||
|
baseDomain: "headscale.com",
|
||||||
|
wantErr: errServerURLSuffix.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "https://server.subdomain.headscale.com",
|
||||||
|
baseDomain: "headscale.com",
|
||||||
|
wantErr: errServerURLSuffix.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
serverURL: "http://foo\x00",
|
||||||
|
wantErr: `parse "http://foo\x00": net/url: invalid control character in URL`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain)
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
err := isSafeServerURL(tt.serverURL, tt.baseDomain)
|
||||||
|
if tt.wantErr != "" {
|
||||||
|
assert.EqualError(t, err, tt.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -223,6 +223,16 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
|
||||||
|
for _, node := range nodes {
|
||||||
|
if node.NodeKey == nodeKey {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (node *Node) Proto() *v1.Node {
|
func (node *Node) Proto() *v1.Node {
|
||||||
nodeProto := &v1.Node{
|
nodeProto := &v1.Node{
|
||||||
Id: uint64(node.ID),
|
Id: uint64(node.ID),
|
||||||
|
|
|
@ -26,7 +26,7 @@ type PreAuthKey struct {
|
||||||
|
|
||||||
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||||
protoKey := v1.PreAuthKey{
|
protoKey := v1.PreAuthKey{
|
||||||
User: key.User.Name,
|
User: key.User.Username(),
|
||||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Ephemeral: key.Ephemeral,
|
Ephemeral: key.Ephemeral,
|
||||||
|
|
|
@ -8,7 +8,7 @@ prefixes:
|
||||||
database:
|
database:
|
||||||
type: sqlite3
|
type: sqlite3
|
||||||
|
|
||||||
server_url: "https://derp.no"
|
server_url: "https://server.derp.no"
|
||||||
|
|
||||||
dns:
|
dns:
|
||||||
magic_dns: true
|
magic_dns: true
|
||||||
|
|
|
@ -2,6 +2,7 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"database/sql"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
|
@ -19,10 +20,14 @@ type UserID uint64
|
||||||
// that contain our machines.
|
// that contain our machines.
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
// The index `idx_name_provider_identifier` is to enforce uniqueness
|
||||||
|
// between Name and ProviderIdentifier. This ensures that
|
||||||
|
// you can have multiple users with the same name in OIDC,
|
||||||
|
// but not if you only run with CLI users.
|
||||||
|
|
||||||
// Username for the user, is used if email is empty
|
// Username for the user, is used if email is empty
|
||||||
// Should not be used, please use Username().
|
// Should not be used, please use Username().
|
||||||
Name string `gorm:"unique"`
|
Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"`
|
||||||
|
|
||||||
// Typically the full name of the user
|
// Typically the full name of the user
|
||||||
DisplayName string
|
DisplayName string
|
||||||
|
@ -34,7 +39,7 @@ type User struct {
|
||||||
// Unique identifier of the user from OIDC,
|
// Unique identifier of the user from OIDC,
|
||||||
// comes from `sub` claim in the OIDC token
|
// comes from `sub` claim in the OIDC token
|
||||||
// and is used to lookup the user.
|
// and is used to lookup the user.
|
||||||
ProviderIdentifier string `gorm:"index"`
|
ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"`
|
||||||
|
|
||||||
// Provider is the origin of the user account,
|
// Provider is the origin of the user account,
|
||||||
// same as RegistrationMethod, without authkey.
|
// same as RegistrationMethod, without authkey.
|
||||||
|
@ -51,7 +56,14 @@ type User struct {
|
||||||
// should be used throughout headscale, in information returned to the
|
// should be used throughout headscale, in information returned to the
|
||||||
// user and the Policy engine.
|
// user and the Policy engine.
|
||||||
func (u *User) Username() string {
|
func (u *User) Username() string {
|
||||||
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10))
|
||||||
|
|
||||||
|
// TODO(kradalby): Wire up all of this for the future
|
||||||
|
// if !strings.Contains(username, "@") {
|
||||||
|
// username = username + "@"
|
||||||
|
// }
|
||||||
|
|
||||||
|
return username
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||||
|
@ -107,7 +119,7 @@ func (u *User) Proto() *v1.User {
|
||||||
CreatedAt: timestamppb.New(u.CreatedAt),
|
CreatedAt: timestamppb.New(u.CreatedAt),
|
||||||
DisplayName: u.DisplayName,
|
DisplayName: u.DisplayName,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
ProviderId: u.ProviderIdentifier,
|
ProviderId: u.ProviderIdentifier.String,
|
||||||
Provider: u.Provider,
|
Provider: u.Provider,
|
||||||
ProfilePicUrl: u.ProfilePicURL,
|
ProfilePicUrl: u.ProfilePicURL,
|
||||||
}
|
}
|
||||||
|
@ -116,6 +128,7 @@ func (u *User) Proto() *v1.User {
|
||||||
type OIDCClaims struct {
|
type OIDCClaims struct {
|
||||||
// Sub is the user's unique identifier at the provider.
|
// Sub is the user's unique identifier at the provider.
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
|
||||||
// Name is the user's full name.
|
// Name is the user's full name.
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
@ -126,12 +139,18 @@ type OIDCClaims struct {
|
||||||
Username string `json:"preferred_username,omitempty"`
|
Username string `json:"preferred_username,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *OIDCClaims) Identifier() string {
|
||||||
|
return c.Iss + "/" + c.Sub
|
||||||
|
}
|
||||||
|
|
||||||
// FromClaim overrides a User from OIDC claims.
|
// FromClaim overrides a User from OIDC claims.
|
||||||
// All fields will be updated, except for the ID.
|
// All fields will be updated, except for the ID.
|
||||||
func (u *User) FromClaim(claims *OIDCClaims) {
|
func (u *User) FromClaim(claims *OIDCClaims) {
|
||||||
u.ProviderIdentifier = claims.Sub
|
u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
|
||||||
u.DisplayName = claims.Name
|
u.DisplayName = claims.Name
|
||||||
u.Email = claims.Email
|
if claims.EmailVerified {
|
||||||
|
u.Email = claims.Email
|
||||||
|
}
|
||||||
u.Name = claims.Username
|
u.Name = claims.Username
|
||||||
u.ProfilePicURL = claims.ProfilePictureURL
|
u.ProfilePicURL = claims.ProfilePictureURL
|
||||||
u.Provider = util.RegisterMethodOIDC
|
u.Provider = util.RegisterMethodOIDC
|
||||||
|
|
|
@ -182,3 +182,33 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||||
|
|
||||||
return fqdns
|
return fqdns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// DEPRECATED: DO NOT USE
|
||||||
|
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||||
|
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||||
|
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||||
|
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
name = strings.ReplaceAll(name, "'", "")
|
||||||
|
atIdx := strings.Index(name, "@")
|
||||||
|
if stripEmailDomain && atIdx > 0 {
|
||||||
|
name = name[:atIdx]
|
||||||
|
} else {
|
||||||
|
name = strings.ReplaceAll(name, "@", ".")
|
||||||
|
}
|
||||||
|
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||||
|
|
||||||
|
for _, elt := range strings.Split(name, ".") {
|
||||||
|
if len(elt) > LabelHostnameLength {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"label %v is more than 63 chars: %w",
|
||||||
|
elt,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
|
@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located
|
||||||
|
|
||||||
## Running integration tests locally
|
## Running integration tests locally
|
||||||
|
|
||||||
The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner:
|
The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner:
|
||||||
|
|
||||||
```
|
```
|
||||||
act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml
|
act pull_request -W .github/workflows/test-integration.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, the `docker run` command in each GitHub workflow file can be used.
|
Alternatively, the `docker run` command in each GitHub workflow file can be used.
|
||||||
|
|
|
@ -3,6 +3,7 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -10,14 +11,19 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
|
"github.com/oauth2-proxy/mockoidc"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
"github.com/ory/dockertest/v3/docker"
|
"github.com/ory/dockertest/v3/docker"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
@ -48,20 +54,34 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
scenario := AuthOIDCScenario{
|
scenario := AuthOIDCScenario{
|
||||||
Scenario: baseScenario,
|
Scenario: baseScenario,
|
||||||
}
|
}
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
// defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
// Logins to MockOIDC is served by a queue with a strict order,
|
||||||
|
// if we use more than one node per user, the order of the logins
|
||||||
|
// will not be deterministic and the test will fail.
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": len(MustTestVersions),
|
"user1": 1,
|
||||||
|
"user2": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL)
|
mockusers := []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
oidcMockUser("user2", false),
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
oidcMap := map[string]string{
|
oidcMap := map[string]string{
|
||||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
@ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
|
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
var listUsers []v1.User
|
||||||
|
err = executeAndUnmarshal(headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"users",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&listUsers,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
want := []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "", // Unverified
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].Id < listUsers[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test is really flaky.
|
// This test is really flaky.
|
||||||
|
@ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": 3,
|
"user1": 1,
|
||||||
|
"user2": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL)
|
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
oidcMockUser("user2", false),
|
||||||
|
})
|
||||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
oidcMap := map[string]string{
|
oidcMap := map[string]string{
|
||||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
@ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
assertTailscaleNodesLogout(t, allClients)
|
assertTailscaleNodesLogout(t, allClients)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby):
|
||||||
|
// - Test that creates a new user when one exists when migration is turned off
|
||||||
|
// - Test that takes over a user when one exists when migration is turned on
|
||||||
|
// - But email is not verified
|
||||||
|
// - stripped email domain on/off
|
||||||
|
func TestOIDC024UserCreation(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config map[string]string
|
||||||
|
emailVerified bool
|
||||||
|
cliUsers []string
|
||||||
|
oidcUsers []string
|
||||||
|
want func(iss string) []v1.User
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-migration-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-migration-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-strip-domains-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-strip-domains-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-no-strip-domains-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
// Hmm I think we will have to overwrite the initial name here
|
||||||
|
// createuser with "user1.headscale.net", but oidc with "user1"
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-no-strip-domains-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1.headscale.net",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2.headscale.net",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
scenario := AuthOIDCScenario{
|
||||||
|
Scenario: baseScenario,
|
||||||
|
}
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
spec := map[string]int{}
|
||||||
|
for _, user := range tt.cliUsers {
|
||||||
|
spec[user] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var mockusers []mockoidc.MockUser
|
||||||
|
for _, user := range tt.oidcUsers {
|
||||||
|
mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified))
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range tt.config {
|
||||||
|
oidcMap[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
spec,
|
||||||
|
hsic.WithTestName("oidcmigration"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithHostnameAsServerURL(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
// Ensure that the nodes have logged in, this is what
|
||||||
|
// triggers user creation via OIDC.
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
want := tt.want(oidcConfig.Issuer)
|
||||||
|
|
||||||
|
var listUsers []v1.User
|
||||||
|
err = executeAndUnmarshal(headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"users",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&listUsers,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].Id < listUsers[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Errorf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
users map[string]int,
|
users map[string]int,
|
||||||
opts ...hsic.Option,
|
opts ...hsic.Option,
|
||||||
|
@ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
}
|
}
|
||||||
|
|
||||||
for userName, clientCount := range users {
|
for userName, clientCount := range users {
|
||||||
|
if clientCount != 1 {
|
||||||
|
// OIDC scenario only supports one client per user.
|
||||||
|
// This is because the MockOIDC server can only serve login
|
||||||
|
// requests based on a queue it has been given on startup.
|
||||||
|
// We currently only populates it with one login request per user.
|
||||||
|
return fmt.Errorf("client count must be 1 for OIDC scenario.")
|
||||||
|
}
|
||||||
log.Printf("creating user %s with %d clients", userName, clientCount)
|
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||||
err = s.CreateUser(userName)
|
err = s.CreateUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) {
|
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
|
||||||
port, err := dockertestutil.RandomFreeHostPort()
|
port, err := dockertestutil.RandomFreeHostPort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not find an open port: %s", err)
|
log.Fatalf("could not find an open port: %s", err)
|
||||||
|
@ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
||||||
|
|
||||||
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
||||||
|
|
||||||
|
usersJSON, err := json.Marshal(users)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
mockOidcOptions := &dockertest.RunOptions{
|
mockOidcOptions := &dockertest.RunOptions{
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
Cmd: []string{"headscale", "mockoidc"},
|
Cmd: []string{"headscale", "mockoidc"},
|
||||||
|
@ -219,6 +596,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
||||||
"MOCKOIDC_CLIENT_ID=superclient",
|
"MOCKOIDC_CLIENT_ID=superclient",
|
||||||
"MOCKOIDC_CLIENT_SECRET=supersecret",
|
"MOCKOIDC_CLIENT_SECRET=supersecret",
|
||||||
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
|
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
|
||||||
|
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,45 +688,40 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||||
|
|
||||||
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
||||||
|
|
||||||
if err := s.pool.Retry(func() error {
|
log.Printf("%s logging in with url", c.Hostname())
|
||||||
log.Printf("%s logging in with url", c.Hostname())
|
httpClient := &http.Client{Transport: insecureTransport}
|
||||||
httpClient := &http.Client{Transport: insecureTransport}
|
ctx := context.Background()
|
||||||
ctx := context.Background()
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
resp, err := httpClient.Do(req)
|
||||||
resp, err := httpClient.Do(req)
|
if err != nil {
|
||||||
if err != nil {
|
log.Printf(
|
||||||
log.Printf(
|
"%s failed to login using url %s: %s",
|
||||||
"%s failed to login using url %s: %s",
|
c.Hostname(),
|
||||||
c.Hostname(),
|
loginURL,
|
||||||
loginURL,
|
err,
|
||||||
err,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
log.Printf("body: %s", body)
|
||||||
|
|
||||||
return errStatusCodeNotOK
|
return errStatusCodeNotOK
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
_, err = io.ReadAll(resp.Body)
|
_, err = io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("%s failed to read response body: %s", c.Hostname(), err)
|
log.Printf("%s failed to read response body: %s", c.Hostname(), err)
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Finished request for %s to join tailnet", c.Hostname())
|
log.Printf("Finished request for %s to join tailnet", c.Hostname())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
|
||||||
assert.Equal(t, "NeedsLogin", status.BackendState)
|
assert.Equal(t, "NeedsLogin", status.BackendState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
|
||||||
|
return mockoidc.MockUser{
|
||||||
|
Subject: username,
|
||||||
|
PreferredUsername: username,
|
||||||
|
Email: fmt.Sprintf("%s@headscale.net", username),
|
||||||
|
EmailVerified: emailVerified,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -212,7 +212,9 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"})
|
tags := listedPreAuthKeys[index].GetAclTags()
|
||||||
|
sort.Strings(tags)
|
||||||
|
assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test key expiry
|
// Test key expiry
|
||||||
|
|
96
integration/derp_verify_endpoint_test.go
Normal file
96
integration/derp_verify_endpoint_test.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/integration/dsic"
|
||||||
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDERPVerifyEndpoint(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
// Generate random hostname for the headscale instance
|
||||||
|
hash, err := util.GenerateRandomStringDNSSafe(6)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
testName := "derpverify"
|
||||||
|
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
||||||
|
|
||||||
|
headscalePort := 8080
|
||||||
|
|
||||||
|
// Create cert for headscale
|
||||||
|
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
spec := map[string]int{
|
||||||
|
"user1": len(MustTestVersions),
|
||||||
|
}
|
||||||
|
|
||||||
|
derper, err := scenario.CreateDERPServer("head",
|
||||||
|
dsic.WithCACert(certHeadscale),
|
||||||
|
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
derpMap := tailcfg.DERPMap{
|
||||||
|
Regions: map[int]*tailcfg.DERPRegion{
|
||||||
|
900: {
|
||||||
|
RegionID: 900,
|
||||||
|
RegionCode: "test-derpverify",
|
||||||
|
RegionName: "TestDerpVerify",
|
||||||
|
Nodes: []*tailcfg.DERPNode{
|
||||||
|
{
|
||||||
|
Name: "TestDerpVerify",
|
||||||
|
RegionID: 900,
|
||||||
|
HostName: derper.GetHostname(),
|
||||||
|
STUNPort: derper.GetSTUNPort(),
|
||||||
|
STUNOnly: false,
|
||||||
|
DERPPort: derper.GetDERPPort(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())},
|
||||||
|
hsic.WithHostname(hostname),
|
||||||
|
hsic.WithPort(headscalePort),
|
||||||
|
hsic.WithCustomTLS(certHeadscale, keyHeadscale),
|
||||||
|
hsic.WithHostnameAsServerURL(),
|
||||||
|
hsic.WithDERPConfig(derpMap))
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
|
assertNoErrListClients(t, err)
|
||||||
|
|
||||||
|
for _, client := range allClients {
|
||||||
|
report, err := client.DebugDERPRegion("test-derpverify")
|
||||||
|
assertNoErr(t, err)
|
||||||
|
successful := false
|
||||||
|
for _, line := range report.Info {
|
||||||
|
if strings.Contains(line, "Successfully established a DERP connection with node") {
|
||||||
|
successful = true
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !successful {
|
||||||
|
stJSON, err := json.Marshal(report)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -74,7 +74,7 @@ func ExecuteCommand(
|
||||||
select {
|
select {
|
||||||
case res := <-resultChan:
|
case res := <-resultChan:
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
return stdout.String(), stderr.String(), res.err
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.exitCode != 0 {
|
if res.exitCode != 0 {
|
||||||
|
@ -83,12 +83,12 @@ func ExecuteCommand(
|
||||||
// log.Println("stdout: ", stdout.String())
|
// log.Println("stdout: ", stdout.String())
|
||||||
// log.Println("stderr: ", stderr.String())
|
// log.Println("stderr: ", stderr.String())
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), ErrDockertestCommandFailed
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), nil
|
return stdout.String(), stderr.String(), nil
|
||||||
case <-time.After(execConfig.timeout):
|
case <-time.After(execConfig.timeout):
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), ErrDockertestCommandTimeout
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
321
integration/dsic/dsic.go
Normal file
321
integration/dsic/dsic.go
Normal file
|
@ -0,0 +1,321 @@
|
||||||
|
package dsic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
|
"github.com/ory/dockertest/v3"
|
||||||
|
"github.com/ory/dockertest/v3/docker"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dsicHashLength = 6
|
||||||
|
dockerContextPath = "../."
|
||||||
|
caCertRoot = "/usr/local/share/ca-certificates"
|
||||||
|
DERPerCertRoot = "/usr/local/share/derper-certs"
|
||||||
|
dockerExecuteTimeout = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK")
|
||||||
|
|
||||||
|
// DERPServerInContainer represents DERP Server in Container (DSIC).
|
||||||
|
type DERPServerInContainer struct {
|
||||||
|
version string
|
||||||
|
hostname string
|
||||||
|
|
||||||
|
pool *dockertest.Pool
|
||||||
|
container *dockertest.Resource
|
||||||
|
network *dockertest.Network
|
||||||
|
|
||||||
|
stunPort int
|
||||||
|
derpPort int
|
||||||
|
caCerts [][]byte
|
||||||
|
tlsCert []byte
|
||||||
|
tlsKey []byte
|
||||||
|
withExtraHosts []string
|
||||||
|
withVerifyClientURL string
|
||||||
|
workdir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option represent optional settings that can be given to a
|
||||||
|
// DERPer instance.
|
||||||
|
type Option = func(c *DERPServerInContainer)
|
||||||
|
|
||||||
|
// WithCACert adds it to the trusted surtificate of the Tailscale container.
|
||||||
|
func WithCACert(cert []byte) Option {
|
||||||
|
return func(dsic *DERPServerInContainer) {
|
||||||
|
dsic.caCerts = append(dsic.caCerts, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOrCreateNetwork sets the Docker container network to use with
|
||||||
|
// the DERPer instance, if the parameter is nil, a new network,
|
||||||
|
// isolating the DERPer, will be created. If a network is
|
||||||
|
// passed, the DERPer instance will join the given network.
|
||||||
|
func WithOrCreateNetwork(network *dockertest.Network) Option {
|
||||||
|
return func(tsic *DERPServerInContainer) {
|
||||||
|
if network != nil {
|
||||||
|
tsic.network = network
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
network, err := dockertestutil.GetFirstOrCreateNetwork(
|
||||||
|
tsic.pool,
|
||||||
|
tsic.hostname+"-network",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create network: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tsic.network = network
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDockerWorkdir allows the docker working directory to be set.
|
||||||
|
func WithDockerWorkdir(dir string) Option {
|
||||||
|
return func(tsic *DERPServerInContainer) {
|
||||||
|
tsic.workdir = dir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithVerifyClientURL sets the URL to verify the client.
|
||||||
|
func WithVerifyClientURL(url string) Option {
|
||||||
|
return func(tsic *DERPServerInContainer) {
|
||||||
|
tsic.withVerifyClientURL = url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExtraHosts adds extra hosts to the container.
|
||||||
|
func WithExtraHosts(hosts []string) Option {
|
||||||
|
return func(tsic *DERPServerInContainer) {
|
||||||
|
tsic.withExtraHosts = hosts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new TailscaleInContainer instance.
|
||||||
|
func New(
|
||||||
|
pool *dockertest.Pool,
|
||||||
|
version string,
|
||||||
|
network *dockertest.Network,
|
||||||
|
opts ...Option,
|
||||||
|
) (*DERPServerInContainer, error) {
|
||||||
|
hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
|
||||||
|
tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
|
||||||
|
}
|
||||||
|
dsic := &DERPServerInContainer{
|
||||||
|
version: version,
|
||||||
|
hostname: hostname,
|
||||||
|
pool: pool,
|
||||||
|
network: network,
|
||||||
|
tlsCert: tlsCert,
|
||||||
|
tlsKey: tlsKey,
|
||||||
|
stunPort: 3478, //nolint
|
||||||
|
derpPort: 443, //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(dsic)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmdArgs strings.Builder
|
||||||
|
fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname)
|
||||||
|
fmt.Fprintf(&cmdArgs, " --certmode=manual")
|
||||||
|
fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot)
|
||||||
|
fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
|
||||||
|
fmt.Fprintf(&cmdArgs, " --stun=true")
|
||||||
|
fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
|
||||||
|
if dsic.withVerifyClientURL != "" {
|
||||||
|
fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
runOptions := &dockertest.RunOptions{
|
||||||
|
Name: hostname,
|
||||||
|
Networks: []*dockertest.Network{dsic.network},
|
||||||
|
ExtraHosts: dsic.withExtraHosts,
|
||||||
|
// we currently need to give us some time to inject the certificate further down.
|
||||||
|
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},
|
||||||
|
ExposedPorts: []string{
|
||||||
|
"80/tcp",
|
||||||
|
fmt.Sprintf("%d/tcp", dsic.derpPort),
|
||||||
|
fmt.Sprintf("%d/udp", dsic.stunPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if dsic.workdir != "" {
|
||||||
|
runOptions.WorkingDir = dsic.workdir
|
||||||
|
}
|
||||||
|
|
||||||
|
// dockertest isnt very good at handling containers that has already
|
||||||
|
// been created, this is an attempt to make sure this container isnt
|
||||||
|
// present.
|
||||||
|
err = pool.RemoveContainerByName(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var container *dockertest.Resource
|
||||||
|
buildOptions := &dockertest.BuildOptions{
|
||||||
|
Dockerfile: "Dockerfile.derper",
|
||||||
|
ContextDir: dockerContextPath,
|
||||||
|
BuildArgs: []docker.BuildArg{},
|
||||||
|
}
|
||||||
|
switch version {
|
||||||
|
case "head":
|
||||||
|
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
|
||||||
|
Name: "VERSION_BRANCH",
|
||||||
|
Value: "main",
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
|
||||||
|
Name: "VERSION_BRANCH",
|
||||||
|
Value: "v" + version,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
container, err = pool.BuildAndRunWithBuildOptions(
|
||||||
|
buildOptions,
|
||||||
|
runOptions,
|
||||||
|
dockertestutil.DockerRestartPolicy,
|
||||||
|
dockertestutil.DockerAllowLocalIPv6,
|
||||||
|
dockertestutil.DockerAllowNetworkAdministration,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(
|
||||||
|
"%s could not start tailscale DERPer container (version: %s): %w",
|
||||||
|
hostname,
|
||||||
|
version,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
log.Printf("Created %s container\n", hostname)
|
||||||
|
|
||||||
|
dsic.container = container
|
||||||
|
|
||||||
|
for i, cert := range dsic.caCerts {
|
||||||
|
err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(dsic.tlsCert) != 0 {
|
||||||
|
err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(dsic.tlsKey) != 0 {
|
||||||
|
err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dsic, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown stops and cleans up the DERPer container.
|
||||||
|
func (t *DERPServerInContainer) Shutdown() error {
|
||||||
|
err := t.SaveLog("/tmp/control")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf(
|
||||||
|
"Failed to save log from %s: %s",
|
||||||
|
t.hostname,
|
||||||
|
fmt.Errorf("failed to save log: %w", err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.pool.Purge(t.container)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCert returns the TLS certificate of the DERPer instance.
|
||||||
|
func (t *DERPServerInContainer) GetCert() []byte {
|
||||||
|
return t.tlsCert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hostname returns the hostname of the DERPer instance.
|
||||||
|
func (t *DERPServerInContainer) Hostname() string {
|
||||||
|
return t.hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the running DERPer version of the instance.
|
||||||
|
func (t *DERPServerInContainer) Version() string {
|
||||||
|
return t.version
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the Docker container ID of the DERPServerInContainer
|
||||||
|
// instance.
|
||||||
|
func (t *DERPServerInContainer) ID() string {
|
||||||
|
return t.container.Container.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *DERPServerInContainer) GetHostname() string {
|
||||||
|
return t.hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSTUNPort returns the STUN port of the DERPer instance.
|
||||||
|
func (t *DERPServerInContainer) GetSTUNPort() int {
|
||||||
|
return t.stunPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDERPPort returns the DERP port of the DERPer instance.
|
||||||
|
func (t *DERPServerInContainer) GetDERPPort() int {
|
||||||
|
return t.derpPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForRunning blocks until the DERPer instance is ready to be used.
|
||||||
|
func (t *DERPServerInContainer) WaitForRunning() error {
|
||||||
|
url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/"
|
||||||
|
log.Printf("waiting for DERPer to be ready at %s", url)
|
||||||
|
|
||||||
|
insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint
|
||||||
|
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint
|
||||||
|
client := &http.Client{Transport: insecureTransport}
|
||||||
|
|
||||||
|
return t.pool.Retry(func() error {
|
||||||
|
resp, err := client.Get(url) //nolint
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("headscale is not ready: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return errDERPerStatusCodeNotOk
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectToNetwork connects the DERPer instance to a network.
|
||||||
|
func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error {
|
||||||
|
return t.container.ConnectToNetwork(network)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteFile save file inside the container.
|
||||||
|
func (t *DERPServerInContainer) WriteFile(path string, data []byte) error {
|
||||||
|
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveLog saves the current stdout log of the container to a path
|
||||||
|
// on the host system.
|
||||||
|
func (t *DERPServerInContainer) SaveLog(path string) error {
|
||||||
|
_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
|
@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
|
||||||
spec := map[string]ClientsSpec{
|
spec := map[string]ClientsSpec{
|
||||||
"user1": {
|
"user1": {
|
||||||
Plain: 0,
|
Plain: 0,
|
||||||
WebsocketDERP: len(MustTestVersions),
|
WebsocketDERP: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
|
||||||
|
|
||||||
if clientCount.WebsocketDERP > 0 {
|
if clientCount.WebsocketDERP > 0 {
|
||||||
// Containers that use DERP-over-WebSocket
|
// Containers that use DERP-over-WebSocket
|
||||||
|
// Note that these clients *must* be built
|
||||||
|
// from source, which is currently
|
||||||
|
// only done for HEAD.
|
||||||
err = s.CreateTailscaleIsolatedNodesInUser(
|
err = s.CreateTailscaleIsolatedNodesInUser(
|
||||||
hash,
|
hash,
|
||||||
userName,
|
userName,
|
||||||
"all",
|
tsic.VersionHead,
|
||||||
clientCount.WebsocketDERP,
|
clientCount.WebsocketDERP,
|
||||||
tsic.WithWebsocketDERP(true),
|
tsic.WithWebsocketDERP(true),
|
||||||
)
|
)
|
||||||
|
@ -307,7 +310,7 @@ func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
|
||||||
cert := hsServer.GetCert()
|
cert := hsServer.GetCert()
|
||||||
|
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
tsic.WithHeadscaleTLS(cert),
|
tsic.WithCACert(cert),
|
||||||
)
|
)
|
||||||
|
|
||||||
user.createWaitGroup.Go(func() error {
|
user.createWaitGroup.Go(func() error {
|
||||||
|
|
|
@ -1,19 +1,12 @@
|
||||||
package hsic
|
package hsic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -32,11 +25,14 @@ import (
|
||||||
"github.com/juanfont/headscale/integration/integrationutil"
|
"github.com/juanfont/headscale/integration/integrationutil"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
"github.com/ory/dockertest/v3/docker"
|
"github.com/ory/dockertest/v3/docker"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
hsicHashLength = 6
|
hsicHashLength = 6
|
||||||
dockerContextPath = "../."
|
dockerContextPath = "../."
|
||||||
|
caCertRoot = "/usr/local/share/ca-certificates"
|
||||||
aclPolicyPath = "/etc/headscale/acl.hujson"
|
aclPolicyPath = "/etc/headscale/acl.hujson"
|
||||||
tlsCertPath = "/etc/headscale/tls.cert"
|
tlsCertPath = "/etc/headscale/tls.cert"
|
||||||
tlsKeyPath = "/etc/headscale/tls.key"
|
tlsKeyPath = "/etc/headscale/tls.key"
|
||||||
|
@ -64,6 +60,7 @@ type HeadscaleInContainer struct {
|
||||||
// optional config
|
// optional config
|
||||||
port int
|
port int
|
||||||
extraPorts []string
|
extraPorts []string
|
||||||
|
caCerts [][]byte
|
||||||
hostPortBindings map[string][]string
|
hostPortBindings map[string][]string
|
||||||
aclPolicy *policy.ACLPolicy
|
aclPolicy *policy.ACLPolicy
|
||||||
env map[string]string
|
env map[string]string
|
||||||
|
@ -88,18 +85,29 @@ func WithACLPolicy(acl *policy.ACLPolicy) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithCACert adds it to the trusted surtificate of the container.
|
||||||
|
func WithCACert(cert []byte) Option {
|
||||||
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
hsic.caCerts = append(hsic.caCerts, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTLS creates certificates and enables HTTPS.
|
// WithTLS creates certificates and enables HTTPS.
|
||||||
func WithTLS() Option {
|
func WithTLS() Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
cert, key, err := createCertificate(hsic.hostname)
|
cert, key, err := integrationutil.CreateCertificate(hsic.hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create certificates for headscale test: %s", err)
|
log.Fatalf("failed to create certificates for headscale test: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Move somewhere appropriate
|
hsic.tlsCert = cert
|
||||||
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
|
hsic.tlsKey = key
|
||||||
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCustomTLS uses the given certificates for the Headscale instance.
|
||||||
|
func WithCustomTLS(cert, key []byte) Option {
|
||||||
|
return func(hsic *HeadscaleInContainer) {
|
||||||
hsic.tlsCert = cert
|
hsic.tlsCert = cert
|
||||||
hsic.tlsKey = key
|
hsic.tlsKey = key
|
||||||
}
|
}
|
||||||
|
@ -146,6 +154,13 @@ func WithTestName(testName string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithHostname sets the hostname of the Headscale instance.
|
||||||
|
func WithHostname(hostname string) Option {
|
||||||
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
hsic.hostname = hostname
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithHostnameAsServerURL sets the Headscale ServerURL based on
|
// WithHostnameAsServerURL sets the Headscale ServerURL based on
|
||||||
// the Hostname.
|
// the Hostname.
|
||||||
func WithHostnameAsServerURL() Option {
|
func WithHostnameAsServerURL() Option {
|
||||||
|
@ -203,6 +218,34 @@ func WithEmbeddedDERPServerOnly() Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDERPConfig configures Headscale use a custom
|
||||||
|
// DERP server only.
|
||||||
|
func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
|
||||||
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
contents, err := yaml.Marshal(derpMap)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to marshal DERP map: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml"
|
||||||
|
hsic.filesInContainer = append(hsic.filesInContainer,
|
||||||
|
fileInContainer{
|
||||||
|
path: "/etc/headscale/derp.yml",
|
||||||
|
contents: contents,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Disable global DERP server and embedded DERP server
|
||||||
|
hsic.env["HEADSCALE_DERP_URLS"] = ""
|
||||||
|
hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false"
|
||||||
|
|
||||||
|
// Envknob for enabling DERP debug logs
|
||||||
|
hsic.env["DERP_DEBUG_LOGS"] = "true"
|
||||||
|
hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTuning allows changing the tuning settings easily.
|
// WithTuning allows changing the tuning settings easily.
|
||||||
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
|
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
@ -300,6 +343,10 @@ func New(
|
||||||
"HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1",
|
"HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1",
|
||||||
"HEADSCALE_DEBUG_DUMP_CONFIG=1",
|
"HEADSCALE_DEBUG_DUMP_CONFIG=1",
|
||||||
}
|
}
|
||||||
|
if hsic.hasTLS() {
|
||||||
|
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
|
||||||
|
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
|
||||||
|
}
|
||||||
for key, value := range hsic.env {
|
for key, value := range hsic.env {
|
||||||
env = append(env, fmt.Sprintf("%s=%s", key, value))
|
env = append(env, fmt.Sprintf("%s=%s", key, value))
|
||||||
}
|
}
|
||||||
|
@ -313,7 +360,7 @@ func New(
|
||||||
// Cmd: []string{"headscale", "serve"},
|
// Cmd: []string{"headscale", "serve"},
|
||||||
// TODO(kradalby): Get rid of this hack, we currently need to give us some
|
// TODO(kradalby): Get rid of this hack, we currently need to give us some
|
||||||
// to inject the headscale configuration further down.
|
// to inject the headscale configuration further down.
|
||||||
Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"},
|
Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; update-ca-certificates ; headscale serve ; /bin/sleep 30"},
|
||||||
Env: env,
|
Env: env,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -351,6 +398,14 @@ func New(
|
||||||
|
|
||||||
hsic.container = container
|
hsic.container = container
|
||||||
|
|
||||||
|
// Write the CA certificates to the container
|
||||||
|
for i, cert := range hsic.caCerts {
|
||||||
|
err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML()))
|
err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write headscale config to container: %w", err)
|
return nil, fmt.Errorf("failed to write headscale config to container: %w", err)
|
||||||
|
@ -749,86 +804,3 @@ func (t *HeadscaleInContainer) SendInterrupt() error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint
|
|
||||||
func createCertificate(hostname string) ([]byte, []byte, error) {
|
|
||||||
// From:
|
|
||||||
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
|
|
||||||
|
|
||||||
ca := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(2019),
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Headscale testing INC"},
|
|
||||||
Country: []string{"NL"},
|
|
||||||
Locality: []string{"Leiden"},
|
|
||||||
},
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(60 * time.Hour),
|
|
||||||
IsCA: true,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
||||||
x509.ExtKeyUsageClientAuth,
|
|
||||||
x509.ExtKeyUsageServerAuth,
|
|
||||||
},
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1658),
|
|
||||||
Subject: pkix.Name{
|
|
||||||
CommonName: hostname,
|
|
||||||
Organization: []string{"Headscale testing INC"},
|
|
||||||
Country: []string{"NL"},
|
|
||||||
Locality: []string{"Leiden"},
|
|
||||||
},
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(60 * time.Minute),
|
|
||||||
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
||||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
||||||
DNSNames: []string{hostname},
|
|
||||||
}
|
|
||||||
|
|
||||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
certBytes, err := x509.CreateCertificate(
|
|
||||||
rand.Reader,
|
|
||||||
cert,
|
|
||||||
ca,
|
|
||||||
&certPrivKey.PublicKey,
|
|
||||||
caPrivKey,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
certPEM := new(bytes.Buffer)
|
|
||||||
|
|
||||||
err = pem.Encode(certPEM, &pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: certBytes,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
certPrivKeyPEM := new(bytes.Buffer)
|
|
||||||
|
|
||||||
err = pem.Encode(certPrivKeyPEM, &pem.Block{
|
|
||||||
Type: "RSA PRIVATE KEY",
|
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,9 +3,16 @@ package integrationutil
|
||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math/big"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
|
@ -93,3 +100,86 @@ func FetchPathFromContainer(
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint
|
||||||
|
func CreateCertificate(hostname string) ([]byte, []byte, error) {
|
||||||
|
// From:
|
||||||
|
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
|
||||||
|
|
||||||
|
ca := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(2019),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Headscale testing INC"},
|
||||||
|
Country: []string{"NL"},
|
||||||
|
Locality: []string{"Leiden"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(60 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
x509.ExtKeyUsageServerAuth,
|
||||||
|
},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1658),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: hostname,
|
||||||
|
Organization: []string{"Headscale testing INC"},
|
||||||
|
Country: []string{"NL"},
|
||||||
|
Locality: []string{"Leiden"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(60 * time.Minute),
|
||||||
|
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
DNSNames: []string{hostname},
|
||||||
|
}
|
||||||
|
|
||||||
|
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certBytes, err := x509.CreateCertificate(
|
||||||
|
rand.Reader,
|
||||||
|
cert,
|
||||||
|
ca,
|
||||||
|
&certPrivKey.PublicKey,
|
||||||
|
caPrivKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM := new(bytes.Buffer)
|
||||||
|
|
||||||
|
err = pem.Encode(certPEM, &pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: certBytes,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
certPrivKeyPEM := new(bytes.Buffer)
|
||||||
|
|
||||||
|
err = pem.Encode(certPrivKeyPEM, &pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
|
"github.com/juanfont/headscale/integration/dsic"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
|
@ -140,6 +141,7 @@ type Scenario struct {
|
||||||
// TODO(kradalby): support multiple headcales for later, currently only
|
// TODO(kradalby): support multiple headcales for later, currently only
|
||||||
// use one.
|
// use one.
|
||||||
controlServers *xsync.MapOf[string, ControlServer]
|
controlServers *xsync.MapOf[string, ControlServer]
|
||||||
|
derpServers []*dsic.DERPServerInContainer
|
||||||
|
|
||||||
users map[string]*User
|
users map[string]*User
|
||||||
|
|
||||||
|
@ -224,6 +226,13 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, derp := range s.derpServers {
|
||||||
|
err := derp.Shutdown()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to tear down derp server: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.pool.RemoveNetwork(s.network); err != nil {
|
if err := s.pool.RemoveNetwork(s.network); err != nil {
|
||||||
log.Printf("failed to remove network: %s", err)
|
log.Printf("failed to remove network: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -352,7 +361,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
|
||||||
hostname := headscale.GetHostname()
|
hostname := headscale.GetHostname()
|
||||||
|
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
tsic.WithHeadscaleTLS(cert),
|
tsic.WithCACert(cert),
|
||||||
tsic.WithHeadscaleName(hostname),
|
tsic.WithHeadscaleName(hostname),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -651,3 +660,20 @@ func (s *Scenario) WaitForTailscaleLogout() error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateDERPServer creates a new DERP server in a container.
|
||||||
|
func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
|
||||||
|
derp, err := dsic.New(s.pool, version, s.network, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create DERP server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = derp.WaitForRunning()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to reach DERP server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.derpServers = append(s.derpServers, derp)
|
||||||
|
|
||||||
|
return derp, nil
|
||||||
|
}
|
||||||
|
|
|
@ -30,6 +30,7 @@ type TailscaleClient interface {
|
||||||
FQDN() (string, error)
|
FQDN() (string, error)
|
||||||
Status(...bool) (*ipnstate.Status, error)
|
Status(...bool) (*ipnstate.Status, error)
|
||||||
Netmap() (*netmap.NetworkMap, error)
|
Netmap() (*netmap.NetworkMap, error)
|
||||||
|
DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
|
||||||
Netcheck() (*netcheck.Report, error)
|
Netcheck() (*netcheck.Report, error)
|
||||||
WaitForNeedsLogin() error
|
WaitForNeedsLogin() error
|
||||||
WaitForRunning() error
|
WaitForRunning() error
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -32,7 +33,7 @@ const (
|
||||||
defaultPingTimeout = 300 * time.Millisecond
|
defaultPingTimeout = 300 * time.Millisecond
|
||||||
defaultPingCount = 10
|
defaultPingCount = 10
|
||||||
dockerContextPath = "../."
|
dockerContextPath = "../."
|
||||||
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
caCertRoot = "/usr/local/share/ca-certificates"
|
||||||
dockerExecuteTimeout = 60 * time.Second
|
dockerExecuteTimeout = 60 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,6 +45,11 @@ var (
|
||||||
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
|
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
|
||||||
errTailscaleNotConnected = errors.New("tailscale not connected")
|
errTailscaleNotConnected = errors.New("tailscale not connected")
|
||||||
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
|
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
|
||||||
|
errInvalidClientConfig = errors.New("verifiably invalid client config requested")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
VersionHead = "head"
|
||||||
)
|
)
|
||||||
|
|
||||||
func errTailscaleStatus(hostname string, err error) error {
|
func errTailscaleStatus(hostname string, err error) error {
|
||||||
|
@ -65,7 +71,7 @@ type TailscaleInContainer struct {
|
||||||
fqdn string
|
fqdn string
|
||||||
|
|
||||||
// optional config
|
// optional config
|
||||||
headscaleCert []byte
|
caCerts [][]byte
|
||||||
headscaleHostname string
|
headscaleHostname string
|
||||||
withWebsocketDERP bool
|
withWebsocketDERP bool
|
||||||
withSSH bool
|
withSSH bool
|
||||||
|
@ -74,17 +80,23 @@ type TailscaleInContainer struct {
|
||||||
withExtraHosts []string
|
withExtraHosts []string
|
||||||
workdir string
|
workdir string
|
||||||
netfilter string
|
netfilter string
|
||||||
|
|
||||||
|
// build options, solely for HEAD
|
||||||
|
buildConfig TailscaleInContainerBuildConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type TailscaleInContainerBuildConfig struct {
|
||||||
|
tags []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option represent optional settings that can be given to a
|
// Option represent optional settings that can be given to a
|
||||||
// Tailscale instance.
|
// Tailscale instance.
|
||||||
type Option = func(c *TailscaleInContainer)
|
type Option = func(c *TailscaleInContainer)
|
||||||
|
|
||||||
// WithHeadscaleTLS takes the certificate of the Headscale instance
|
// WithCACert adds it to the trusted surtificate of the Tailscale container.
|
||||||
// and adds it to the trusted surtificate of the Tailscale container.
|
func WithCACert(cert []byte) Option {
|
||||||
func WithHeadscaleTLS(cert []byte) Option {
|
|
||||||
return func(tsic *TailscaleInContainer) {
|
return func(tsic *TailscaleInContainer) {
|
||||||
tsic.headscaleCert = cert
|
tsic.caCerts = append(tsic.caCerts, cert)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,7 +125,7 @@ func WithOrCreateNetwork(network *dockertest.Network) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithHeadscaleName set the name of the headscale instance,
|
// WithHeadscaleName set the name of the headscale instance,
|
||||||
// mostly useful in combination with TLS and WithHeadscaleTLS.
|
// mostly useful in combination with TLS and WithCACert.
|
||||||
func WithHeadscaleName(hsName string) Option {
|
func WithHeadscaleName(hsName string) Option {
|
||||||
return func(tsic *TailscaleInContainer) {
|
return func(tsic *TailscaleInContainer) {
|
||||||
tsic.headscaleHostname = hsName
|
tsic.headscaleHostname = hsName
|
||||||
|
@ -175,6 +187,22 @@ func WithNetfilter(state string) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithBuildTag adds an additional value to the `-tags=` parameter
|
||||||
|
// of the Go compiler, allowing callers to customize the Tailscale client build.
|
||||||
|
// This option is only meaningful when invoked on **HEAD** versions of the client.
|
||||||
|
// Attempts to use it with any other version is a bug in the calling code.
|
||||||
|
func WithBuildTag(tag string) Option {
|
||||||
|
return func(tsic *TailscaleInContainer) {
|
||||||
|
if tsic.version != VersionHead {
|
||||||
|
panic(errInvalidClientConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
tsic.buildConfig.tags = append(
|
||||||
|
tsic.buildConfig.tags, tag,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// New returns a new TailscaleInContainer instance.
|
// New returns a new TailscaleInContainer instance.
|
||||||
func New(
|
func New(
|
||||||
pool *dockertest.Pool,
|
pool *dockertest.Pool,
|
||||||
|
@ -219,18 +247,20 @@ func New(
|
||||||
}
|
}
|
||||||
|
|
||||||
if tsic.withWebsocketDERP {
|
if tsic.withWebsocketDERP {
|
||||||
|
if version != VersionHead {
|
||||||
|
return tsic, errInvalidClientConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
WithBuildTag("ts_debug_websockets")(tsic)
|
||||||
|
|
||||||
tailscaleOptions.Env = append(
|
tailscaleOptions.Env = append(
|
||||||
tailscaleOptions.Env,
|
tailscaleOptions.Env,
|
||||||
fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
|
fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tsic.headscaleHostname != "" {
|
tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts,
|
||||||
tailscaleOptions.ExtraHosts = []string{
|
"host.docker.internal:host-gateway")
|
||||||
"host.docker.internal:host-gateway",
|
|
||||||
fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tsic.workdir != "" {
|
if tsic.workdir != "" {
|
||||||
tailscaleOptions.WorkingDir = tsic.workdir
|
tailscaleOptions.WorkingDir = tsic.workdir
|
||||||
|
@ -245,14 +275,36 @@ func New(
|
||||||
}
|
}
|
||||||
|
|
||||||
var container *dockertest.Resource
|
var container *dockertest.Resource
|
||||||
|
|
||||||
|
if version != VersionHead {
|
||||||
|
// build options are not meaningful with pre-existing images,
|
||||||
|
// let's not lead anyone astray by pretending otherwise.
|
||||||
|
defaultBuildConfig := TailscaleInContainerBuildConfig{}
|
||||||
|
hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
|
||||||
|
if hasBuildConfig {
|
||||||
|
return tsic, errInvalidClientConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch version {
|
switch version {
|
||||||
case "head":
|
case VersionHead:
|
||||||
buildOptions := &dockertest.BuildOptions{
|
buildOptions := &dockertest.BuildOptions{
|
||||||
Dockerfile: "Dockerfile.tailscale-HEAD",
|
Dockerfile: "Dockerfile.tailscale-HEAD",
|
||||||
ContextDir: dockerContextPath,
|
ContextDir: dockerContextPath,
|
||||||
BuildArgs: []docker.BuildArg{},
|
BuildArgs: []docker.BuildArg{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buildTags := strings.Join(tsic.buildConfig.tags, ",")
|
||||||
|
if len(buildTags) > 0 {
|
||||||
|
buildOptions.BuildArgs = append(
|
||||||
|
buildOptions.BuildArgs,
|
||||||
|
docker.BuildArg{
|
||||||
|
Name: "BUILD_TAGS",
|
||||||
|
Value: buildTags,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
container, err = pool.BuildAndRunWithBuildOptions(
|
container, err = pool.BuildAndRunWithBuildOptions(
|
||||||
buildOptions,
|
buildOptions,
|
||||||
tailscaleOptions,
|
tailscaleOptions,
|
||||||
|
@ -294,8 +346,8 @@ func New(
|
||||||
|
|
||||||
tsic.container = container
|
tsic.container = container
|
||||||
|
|
||||||
if tsic.hasTLS() {
|
for i, cert := range tsic.caCerts {
|
||||||
err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
|
err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -304,10 +356,6 @@ func New(
|
||||||
return tsic, nil
|
return tsic, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TailscaleInContainer) hasTLS() bool {
|
|
||||||
return len(t.headscaleCert) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown stops and cleans up the Tailscale container.
|
// Shutdown stops and cleans up the Tailscale container.
|
||||||
func (t *TailscaleInContainer) Shutdown() error {
|
func (t *TailscaleInContainer) Shutdown() error {
|
||||||
err := t.SaveLog("/tmp/control")
|
err := t.SaveLog("/tmp/control")
|
||||||
|
@ -682,6 +730,34 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) {
|
||||||
|
if !util.TailscaleVersionNewerOrEqual("1.34", t.version) {
|
||||||
|
panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version)
|
||||||
|
}
|
||||||
|
|
||||||
|
command := []string{
|
||||||
|
"tailscale",
|
||||||
|
"debug",
|
||||||
|
"derp",
|
||||||
|
region,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, stderr, err := t.Execute(command)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("stderr: %s\n", stderr) // nolint
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var report ipnstate.DebugDERPRegionReport
|
||||||
|
err = json.Unmarshal([]byte(result), &report)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &report, err
|
||||||
|
}
|
||||||
|
|
||||||
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
|
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
|
||||||
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
|
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
|
||||||
command := []string{
|
command := []string{
|
||||||
|
|
Loading…
Reference in a new issue