mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-26 08:53:05 +00:00
Compare commits
29 commits
99b2ab843a
...
f77abeda63
Author | SHA1 | Date | |
---|---|---|---|
|
f77abeda63 | ||
|
02c76bda99 | ||
|
af969f602c | ||
|
6422cdf576 | ||
|
6253fc9e72 | ||
|
662dfbf423 | ||
|
03fd7f31b4 | ||
|
4f46d6513b | ||
|
9f6c8ab62e | ||
|
2c1ad6d11a | ||
|
fffd23602b | ||
|
3a2589f1a9 | ||
|
f6276ab9d2 | ||
|
7d9b430ec2 | ||
|
3780c9fd69 | ||
|
281025bb16 | ||
|
5e7c3153b9 | ||
|
7ba0c3d515 | ||
|
4b58dc6eb4 | ||
|
4dd12a2f97 | ||
|
2fe65624c0 | ||
|
35b669fe59 | ||
|
dc07779143 | ||
|
d72663a4d0 | ||
|
0a82d3f17a | ||
|
78214699ad | ||
|
64bb56352f | ||
|
dc17b4d378 | ||
|
a6b19e85db |
46 changed files with 2151 additions and 795 deletions
1
.github/workflows/test-integration.yaml
vendored
1
.github/workflows/test-integration.yaml
vendored
|
@ -21,6 +21,7 @@ jobs:
|
||||||
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
- TestPolicyUpdateWhileRunningWithCLIInDatabase
|
||||||
- TestOIDCAuthenticationPingAll
|
- TestOIDCAuthenticationPingAll
|
||||||
- TestOIDCExpireNodesBasedOnTokenExpiry
|
- TestOIDCExpireNodesBasedOnTokenExpiry
|
||||||
|
- TestOIDC024UserCreation
|
||||||
- TestAuthWebFlowAuthenticationPingAll
|
- TestAuthWebFlowAuthenticationPingAll
|
||||||
- TestAuthWebFlowLogoutAndRelogin
|
- TestAuthWebFlowLogoutAndRelogin
|
||||||
- TestUserCommand
|
- TestUserCommand
|
||||||
|
|
|
@ -27,6 +27,7 @@ linters:
|
||||||
- nolintlint
|
- nolintlint
|
||||||
- musttag # causes issues with imported libs
|
- musttag # causes issues with imported libs
|
||||||
- depguard
|
- depguard
|
||||||
|
- exportloopref
|
||||||
|
|
||||||
# We should strive to enable these:
|
# We should strive to enable these:
|
||||||
- wrapcheck
|
- wrapcheck
|
||||||
|
@ -56,9 +57,14 @@ linters-settings:
|
||||||
- ok
|
- ok
|
||||||
- c
|
- c
|
||||||
- tt
|
- tt
|
||||||
|
- tx
|
||||||
|
- rx
|
||||||
|
|
||||||
gocritic:
|
gocritic:
|
||||||
disabled-checks:
|
disabled-checks:
|
||||||
- appendAssign
|
- appendAssign
|
||||||
# TODO(kradalby): Remove this
|
# TODO(kradalby): Remove this
|
||||||
- ifElseChain
|
- ifElseChain
|
||||||
|
|
||||||
|
nlreturn:
|
||||||
|
block-size: 4
|
||||||
|
|
78
CHANGELOG.md
78
CHANGELOG.md
|
@ -2,16 +2,82 @@
|
||||||
|
|
||||||
## Next
|
## Next
|
||||||
|
|
||||||
|
### Security fix: OIDC changes in Headscale 0.24.0
|
||||||
|
|
||||||
|
_Headscale v0.23.0 and earlier_ identified OIDC users by the "username" part of their email address (when `strip_email_domain: true`, the default) or whole email address (when `strip_email_domain: false`).
|
||||||
|
|
||||||
|
Depending on how Headscale and your Identity Provider (IdP) were configured, only using the `email` claim could allow a malicious user with an IdP account to take over another Headscale user's account, even when `strip_email_domain: false`.
|
||||||
|
|
||||||
|
This would also cause a user to lose access to their Headscale account if they changed their email address.
|
||||||
|
|
||||||
|
_Headscale v0.24.0_ now identifies OIDC users by the `iss` and `sub` claims. [These are guaranteed by the OIDC specification to be stable and unique](https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability), even if a user changes email address. A well-designed IdP will typically set `sub` to an opaque identifier like a UUID or numeric ID, which has no relation to the user's name or email address.
|
||||||
|
|
||||||
|
This issue _only_ affects Headscale installations which authenticate with OIDC.
|
||||||
|
|
||||||
|
Headscale v0.24.0 and later will also automatically update profile fields with OIDC data on login. This means that users can change those details in your IdP, and have it populate to Headscale automatically the next time they log in. However, this may affect the way you reference users in policies.
|
||||||
|
|
||||||
|
#### Migrating existing installations
|
||||||
|
|
||||||
|
Headscale v0.23.0 and earlier never recorded the `iss` and `sub` fields, so all legacy (existing) OIDC accounts from _need to be migrated_ to be properly secured.
|
||||||
|
|
||||||
|
Headscale v0.24.0 has an automatic migration feature, which is enabled by default (`map_legacy_users: true`). **This will be disabled by default in a future version of Headscale – any unmigrated users will get new accounts.**
|
||||||
|
|
||||||
|
Headscale v0.24.0 will ignore any `email` claim if the IdP does not provide an `email_verified` claim set to `true`. [What "verified" actually means is contextually dependent](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) – Headscale uses it as a signal that the contents of the `email` claim is reasonably trustworthy.
|
||||||
|
|
||||||
|
Headscale v0.23.0 and earlier never checked the `email_verified` claim. This means even if an IdP explicitly indicated to Headscale that its `email` claim was untrustworthy, Headscale would have still accepted it.
|
||||||
|
|
||||||
|
##### What does automatic migration do?
|
||||||
|
|
||||||
|
When automatic migration is enabled (`map_legacy_users: true`), Headscale will first match an OIDC account to a Headscale account by `iss` and `sub`, and then fall back to matching OIDC users similarly to how Headscale v0.23.0 did:
|
||||||
|
|
||||||
|
- If `strip_email_domain: true` (the default): the Headscale username matches the "username" part of their email address.
|
||||||
|
- If `strip_email_domain: false`: the Headscale username matches the _whole_ email address.
|
||||||
|
|
||||||
|
On migration, Headscale will change the account's username to their `preferred_username`. **This could break any ACLs or policies which are configured to match by username.**
|
||||||
|
|
||||||
|
Like with Headscale v0.23.0 and earlier, this migration only works for users who haven't changed their email address since their last Headscale login.
|
||||||
|
|
||||||
|
A _successful_ automated migration should otherwise be transparent to users.
|
||||||
|
|
||||||
|
Once a Headscale account has been migrated, it will be _unavailable_ to be matched by the legacy process. An OIDC login with a matching username, but _non-matching_ `iss` and `sub` will instead get a _new_ Headscale account.
|
||||||
|
|
||||||
|
Because of the way OIDC works, Headscale's automated migration process can _only_ work when a user tries to log in after the update. Mass updates would require Headscale implement a protocol like SCIM, which is **extremely** complicated and not available in all identity providers.
|
||||||
|
|
||||||
|
Administrators could also attempt to migrate users manually by editing the database, using their own mapping rules with known-good data sources.
|
||||||
|
|
||||||
|
Legacy account migration should have no effect on new installations where all users have a recorded `sub` and `iss`.
|
||||||
|
|
||||||
|
##### What happens when automatic migration is disabled?
|
||||||
|
|
||||||
|
When automatic migration is disabled (`map_legacy_users: false`), Headscale will only try to match an OIDC account to a Headscale account by `iss` and `sub`.
|
||||||
|
|
||||||
|
If there is no match, it will get a _new_ Headscale account – even if there was a legacy account which _could_ have matched and migrated.
|
||||||
|
|
||||||
|
We recommend new Headscale users explicitly disable automatic migration – but it should otherwise have no effect if every account has a recorded `iss` and `sub`.
|
||||||
|
|
||||||
|
When automatic migration is disabled, the `strip_email_domain` setting will have no effect.
|
||||||
|
|
||||||
|
Special thanks to @micolous for reviewing, proposing and working with us on these changes.
|
||||||
|
|
||||||
|
#### Other OIDC changes
|
||||||
|
|
||||||
|
Headscale now uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to populate and update user information every time they log in:
|
||||||
|
|
||||||
|
| Headscale profile field | OIDC claim | Notes / examples |
|
||||||
|
| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------- |
|
||||||
|
| email address | `email` | Only used when `"email_verified": true` |
|
||||||
|
| display name | `name` | eg: `Sam Smith` |
|
||||||
|
| username | `preferred_username` | Varies depending on IdP and configuration, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` |
|
||||||
|
| profile picture | `picture` | URL to a profile picture or avatar |
|
||||||
|
|
||||||
|
These should show up nicely in the Tailscale client.
|
||||||
|
|
||||||
|
This will also affect the way you [reference users in policies](https://github.com/juanfont/headscale/pull/2205).
|
||||||
|
|
||||||
### BREAKING
|
### BREAKING
|
||||||
|
|
||||||
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020)
|
||||||
- Having usernames in magic DNS is no longer possible.
|
- Having usernames in magic DNS is no longer possible.
|
||||||
- Redo OpenID Connect configuration [#2020](https://github.com/juanfont/headscale/pull/2020)
|
|
||||||
- `strip_email_domain` has been removed, domain is _always_ part of the username for OIDC.
|
|
||||||
- Users are now identified by `sub` claim in the ID token instead of username, allowing the username, name and email to be updated.
|
|
||||||
- User has been extended to store username, display name, profile picture url and email.
|
|
||||||
- These fields are forwarded to the client, and shows up nicely in the user switcher.
|
|
||||||
- These fields can be made available via the API/CLI for non-OIDC users in the future.
|
|
||||||
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149)
|
||||||
- Clean up old code required by old versions
|
- Clean up old code required by old versions
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package cli
|
package cli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
@ -64,6 +66,19 @@ func mockOIDC() error {
|
||||||
accessTTL = newTTL
|
accessTTL = newTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userStr := os.Getenv("MOCKOIDC_USERS")
|
||||||
|
if userStr == "" {
|
||||||
|
return fmt.Errorf("MOCKOIDC_USERS not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []mockoidc.MockUser
|
||||||
|
err := json.Unmarshal([]byte(userStr), &users)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unmarshalling users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Interface("users", users).Msg("loading users from JSON")
|
||||||
|
|
||||||
log.Info().Msgf("Access token TTL: %s", accessTTL)
|
log.Info().Msgf("Access token TTL: %s", accessTTL)
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
|
@ -71,7 +86,7 @@ func mockOIDC() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mock, err := getMockOIDC(clientID, clientSecret)
|
mock, err := getMockOIDC(clientID, clientSecret, users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -93,12 +108,18 @@ func mockOIDC() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) {
|
func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) {
|
||||||
keypair, err := mockoidc.NewKeypair(nil)
|
keypair, err := mockoidc.NewKeypair(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userQueue := mockoidc.UserQueue{}
|
||||||
|
|
||||||
|
for _, user := range users {
|
||||||
|
userQueue.Push(&user)
|
||||||
|
}
|
||||||
|
|
||||||
mock := mockoidc.MockOIDC{
|
mock := mockoidc.MockOIDC{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
|
@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro
|
||||||
CodeChallengeMethodsSupported: []string{"plain", "S256"},
|
CodeChallengeMethodsSupported: []string{"plain", "S256"},
|
||||||
Keypair: keypair,
|
Keypair: keypair,
|
||||||
SessionStore: mockoidc.NewSessionStore(),
|
SessionStore: mockoidc.NewSessionStore(),
|
||||||
UserQueue: &mockoidc.UserQueue{},
|
UserQueue: &userQueue,
|
||||||
ErrorQueue: &mockoidc.ErrorQueue{},
|
ErrorQueue: &mockoidc.ErrorQueue{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock.AddMiddleware(func(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Info().Msgf("Request: %+v", r)
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
if r.Response != nil {
|
||||||
|
log.Info().Msgf("Response: %+v", r.Response)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
return &mock, nil
|
return &mock, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -168,6 +168,11 @@ database:
|
||||||
# https://www.sqlite.org/wal.html
|
# https://www.sqlite.org/wal.html
|
||||||
write_ahead_log: true
|
write_ahead_log: true
|
||||||
|
|
||||||
|
# Maximum number of WAL file frames before the WAL file is automatically checkpointed.
|
||||||
|
# https://www.sqlite.org/c3ref/wal_autocheckpoint.html
|
||||||
|
# Set to 0 to disable automatic checkpointing.
|
||||||
|
wal_autocheckpoint: 1000
|
||||||
|
|
||||||
# # Postgres config
|
# # Postgres config
|
||||||
# Please note that using Postgres is highly discouraged as it is only supported for legacy reasons.
|
# Please note that using Postgres is highly discouraged as it is only supported for legacy reasons.
|
||||||
# See database.type for more information.
|
# See database.type for more information.
|
||||||
|
@ -364,12 +369,18 @@ unix_socket_permission: "0770"
|
||||||
# allowed_users:
|
# allowed_users:
|
||||||
# - alice@example.com
|
# - alice@example.com
|
||||||
#
|
#
|
||||||
# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
|
# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
|
||||||
# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
|
# # by taking the username from the legacy user and matching it with the username
|
||||||
# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following
|
# # provided by the OIDC. This is useful when migrating from legacy users to OIDC
|
||||||
# user: `first-name.last-name.example.com`
|
# # to force them using the unique identifier from the OIDC and to give them a
|
||||||
#
|
# # proper display name and picture if available.
|
||||||
# strip_email_domain: true
|
# # Note that this will only work if the username from the legacy user is the same
|
||||||
|
# # and ther is a posibility for account takeover should a username have changed
|
||||||
|
# # with the provider.
|
||||||
|
# # Disabling this feature will cause all new logins to be created as new users.
|
||||||
|
# # Note this option will be removed in the future and should be set to false
|
||||||
|
# # on all new installations, or when all users have logged in with OIDC once.
|
||||||
|
# map_legacy_users: true
|
||||||
|
|
||||||
# Logtail configuration
|
# Logtail configuration
|
||||||
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
|
# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
cairosvg~=2.7.1
|
mkdocs-include-markdown-plugin~=7.1
|
||||||
mkdocs-include-markdown-plugin~=6.2.2
|
mkdocs-macros-plugin~=1.3
|
||||||
mkdocs-macros-plugin~=1.2.0
|
mkdocs-material[imaging]~=9.5
|
||||||
mkdocs-material~=9.5.18
|
mkdocs-minify-plugin~=0.7
|
||||||
mkdocs-minify-plugin~=0.7.1
|
mkdocs-redirects~=1.2
|
||||||
mkdocs-redirects~=1.2.1
|
|
||||||
pillow~=10.1.0
|
|
||||||
|
|
|
@ -20,11 +20,11 @@
|
||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1731763621,
|
"lastModified": 1731890469,
|
||||||
"narHash": "sha256-ddcX4lQL0X05AYkrkV2LMFgGdRvgap7Ho8kgon3iWZk=",
|
"narHash": "sha256-D1FNZ70NmQEwNxpSSdTXCSklBH1z2isPR84J6DQrJGs=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "c69a9bffbecde46b4b939465422ddc59493d3e4d",
|
"rev": "5083ec887760adfe12af64830a66807423a859a7",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
|
|
||||||
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
# When updating go.mod or go.sum, a new sha will need to be calculated,
|
||||||
# update this if you have a mismatch after doing a change to thos files.
|
# update this if you have a mismatch after doing a change to thos files.
|
||||||
vendorHash = "sha256-CMkYTRjmhvTTrB7JbLj0cj9VEyzpG0iUWXkaOagwYTk=";
|
vendorHash = "sha256-4VNiHUblvtcl9UetwiL6ZeVYb0h2e9zhYVsirhAkvOg=";
|
||||||
|
|
||||||
subPackages = ["cmd/headscale"];
|
subPackages = ["cmd/headscale"];
|
||||||
|
|
||||||
|
@ -102,6 +102,7 @@
|
||||||
ko
|
ko
|
||||||
yq-go
|
yq-go
|
||||||
ripgrep
|
ripgrep
|
||||||
|
postgresql
|
||||||
|
|
||||||
# 'dot' is needed for pprof graphs
|
# 'dot' is needed for pprof graphs
|
||||||
# go tool pprof -http=: <source>
|
# go tool pprof -http=: <source>
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -49,6 +49,7 @@ require (
|
||||||
gorm.io/gorm v1.25.11
|
gorm.io/gorm v1.25.11
|
||||||
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7
|
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7
|
||||||
zgo.at/zcache/v2 v2.1.0
|
zgo.at/zcache/v2 v2.1.0
|
||||||
|
zombiezen.com/go/postgrestest v1.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
@ -134,6 +135,7 @@ require (
|
||||||
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
|
github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a // indirect
|
||||||
github.com/kr/pretty v0.3.1 // indirect
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/lib/pq v1.10.9 // indirect
|
||||||
github.com/lithammer/fuzzysearch v1.1.8 // indirect
|
github.com/lithammer/fuzzysearch v1.1.8 // indirect
|
||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -311,6 +311,7 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
|
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
|
||||||
|
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4=
|
github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4=
|
||||||
|
@ -731,3 +732,5 @@ tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7 h1:nfRWV6ECxwNvvXKtbqSVs
|
||||||
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
|
tailscale.com v1.75.0-pre.0.20240926101731-7d1160ddaab7/go.mod h1:xKxYf3B3PuezFlRaMT+VhuVu8XTFUTLy+VCzLPMJVmg=
|
||||||
zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs=
|
zgo.at/zcache/v2 v2.1.0 h1:USo+ubK+R4vtjw4viGzTe/zjXyPw6R7SK/RL3epBBxs=
|
||||||
zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
zgo.at/zcache/v2 v2.1.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
||||||
|
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
||||||
|
zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ=
|
||||||
|
|
|
@ -1031,14 +1031,18 @@ func (h *Headscale) loadACLPolicy() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||||
}
|
}
|
||||||
|
users, err := h.db.ListUsers()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = pol.CompileFilterRules(nodes)
|
_, err = pol.CompileFilterRules(users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("verifying policy rules: %w", err)
|
return fmt.Errorf("verifying policy rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if len(nodes) > 0 {
|
||||||
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
|
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("verifying SSH rules: %w", err)
|
return fmt.Errorf("verifying SSH rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -474,6 +474,8 @@ func NewHeadscaleDatabase(
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
// Pick up new user fields used for OIDC and to
|
||||||
|
// populate the user with more interesting information.
|
||||||
ID: "202407191627",
|
ID: "202407191627",
|
||||||
Migrate: func(tx *gorm.DB) error {
|
Migrate: func(tx *gorm.DB) error {
|
||||||
err := tx.AutoMigrate(&types.User{})
|
err := tx.AutoMigrate(&types.User{})
|
||||||
|
@ -485,6 +487,40 @@ func NewHeadscaleDatabase(
|
||||||
},
|
},
|
||||||
Rollback: func(db *gorm.DB) error { return nil },
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
// The unique constraint of Name has been dropped
|
||||||
|
// in favour of a unique together of name and
|
||||||
|
// provider identity.
|
||||||
|
ID: "202408181235",
|
||||||
|
Migrate: func(tx *gorm.DB) error {
|
||||||
|
err := tx.AutoMigrate(&types.User{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up indexes and unique constraints outside of GORM, it does not support
|
||||||
|
// conditional unique constraints.
|
||||||
|
// This ensures the following:
|
||||||
|
// - A user name and provider_identifier is unique
|
||||||
|
// - A provider_identifier is unique
|
||||||
|
// - A user name is unique if there is no provider_identifier is not set
|
||||||
|
for _, idx := range []string{
|
||||||
|
"DROP INDEX IF EXISTS idx_provider_identifier",
|
||||||
|
"DROP INDEX IF EXISTS idx_name_provider_identifier",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL;",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_provider_identifier ON users (name,provider_identifier);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL;",
|
||||||
|
} {
|
||||||
|
err = tx.Exec(idx).Error
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating username index: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Rollback: func(db *gorm.DB) error { return nil },
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -543,10 +579,10 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Sqlite.WriteAheadLog {
|
if cfg.Sqlite.WriteAheadLog {
|
||||||
if err := db.Exec(`
|
if err := db.Exec(fmt.Sprintf(`
|
||||||
PRAGMA journal_mode=WAL;
|
PRAGMA journal_mode=WAL;
|
||||||
PRAGMA wal_autocheckpoint=0;
|
PRAGMA wal_autocheckpoint=%d;
|
||||||
`).Error; err != nil {
|
`, cfg.Sqlite.WALAutoCheckPoint)).Error; err != nil {
|
||||||
return nil, fmt.Errorf("setting WAL mode: %w", err)
|
return nil, fmt.Errorf("setting WAL mode: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -8,6 +9,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -16,6 +18,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"zgo.at/zcache/v2"
|
"zgo.at/zcache/v2"
|
||||||
)
|
)
|
||||||
|
@ -44,7 +47,7 @@ func TestMigrations(t *testing.T) {
|
||||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||||
return GetRoutes(rx)
|
return GetRoutes(rx)
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, routes, 10)
|
assert.Len(t, routes, 10)
|
||||||
want := types.Routes{
|
want := types.Routes{
|
||||||
|
@ -70,7 +73,7 @@ func TestMigrations(t *testing.T) {
|
||||||
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
routes, err := Read(h.DB, func(rx *gorm.DB) (types.Routes, error) {
|
||||||
return GetRoutes(rx)
|
return GetRoutes(rx)
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, routes, 4)
|
assert.Len(t, routes, 4)
|
||||||
want := types.Routes{
|
want := types.Routes{
|
||||||
|
@ -120,19 +123,19 @@ func TestMigrations(t *testing.T) {
|
||||||
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite",
|
||||||
wantFunc: func(t *testing.T, h *HSDatabase) {
|
wantFunc: func(t *testing.T, h *HSDatabase) {
|
||||||
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
kratest, err := ListPreAuthKeys(rx, "kratest")
|
kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
testkra, err := ListPreAuthKeys(rx, "testkra")
|
testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return append(kratest, testkra...), nil
|
return append(kratest, testkra...), nil
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, keys, 5)
|
assert.Len(t, keys, 5)
|
||||||
want := []types.PreAuthKey{
|
want := []types.PreAuthKey{
|
||||||
|
@ -177,7 +180,7 @@ func TestMigrations(t *testing.T) {
|
||||||
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return ListNodes(rx)
|
return ListNodes(rx)
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
|
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
|
||||||
|
@ -256,3 +259,120 @@ func testCopyOfDatabase(src string) (string, error) {
|
||||||
func emptyCache() *zcache.Cache[string, types.Node] {
|
func emptyCache() *zcache.Cache[string, types.Node] {
|
||||||
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
return zcache.New[string, types.Node](time.Minute, time.Hour)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requireConstraintFailed checks if the error is a constraint failure with
|
||||||
|
// either SQLite and PostgreSQL error messages.
|
||||||
|
func requireConstraintFailed(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
require.Error(t, err)
|
||||||
|
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||||
|
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstraints(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
run func(*testing.T, *gorm.DB)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-duplicate-username-if-no-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = CreateUser(db, "user1")
|
||||||
|
requireConstraintFailed(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-username-and-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
requireConstraintFailed(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-oidc-duplicate-id",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "user1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user = types.User{
|
||||||
|
Model: gorm.Model{ID: 2},
|
||||||
|
Name: "user1.1",
|
||||||
|
}
|
||||||
|
user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
requireConstraintFailed(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-cli-then-oidc",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
_, err := CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow-duplicate-username-oidc-then-cli",
|
||||||
|
run: func(t *testing.T, db *gorm.DB) {
|
||||||
|
user := types.User{
|
||||||
|
Name: "user1",
|
||||||
|
ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.Save(&user).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = CreateUser(db, "user1") // Create CLI username
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name+"-postgres", func(t *testing.T) {
|
||||||
|
db := newPostgresTestDB(t)
|
||||||
|
tt.run(t, db.DB.Debug())
|
||||||
|
})
|
||||||
|
t.Run(tt.name+"-sqlite", func(t *testing.T) {
|
||||||
|
db, err := newSQLiteTestDB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating database: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.run(t, db.DB.Debug())
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
"tailscale.com/types/ptr"
|
"tailscale.com/types/ptr"
|
||||||
)
|
)
|
||||||
|
@ -457,7 +458,12 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
db := tt.dbFunc()
|
db := tt.dbFunc()
|
||||||
|
|
||||||
alloc, err := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategySequential)
|
alloc, err := NewIPAllocator(
|
||||||
|
db,
|
||||||
|
tt.prefix4,
|
||||||
|
tt.prefix6,
|
||||||
|
types.IPAllocationStrategySequential,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to set up ip alloc: %s", err)
|
t.Fatalf("failed to set up ip alloc: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -482,24 +488,29 @@ func TestBackfillIPAddresses(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||||
alloc, err := NewIPAllocator(db, ptr.To(tsaddr.CGNATRange()), ptr.To(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential)
|
alloc, err := NewIPAllocator(
|
||||||
|
db,
|
||||||
|
ptr.To(tsaddr.CGNATRange()),
|
||||||
|
ptr.To(tsaddr.TailscaleULARange()),
|
||||||
|
types.IPAllocationStrategySequential,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to set up ip alloc: %s", err)
|
t.Fatalf("failed to set up ip alloc: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that we do not give out 100.100.100.100
|
// Validate that we do not give out 100.100.100.100
|
||||||
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
||||||
|
|
||||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
||||||
|
|
||||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||||
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
||||||
t.Logf("chrome: %s", nextChrome.String())
|
t.Logf("chrome: %s", nextChrome.String())
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
assert.Equal(t, na("100.115.94.0"), *nextChrome)
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) {
|
func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||||
return getNode(rx, user, name)
|
return getNode(rx, uid, name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNode finds a Node by name and user and returns the Node struct.
|
// getNode finds a Node by name and user and returns the Node struct.
|
||||||
func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) {
|
func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
||||||
nodes, err := ListNodesByUser(tx, user)
|
nodes, err := ListNodesByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/net/tsaddr"
|
"tailscale.com/net/tsaddr"
|
||||||
|
@ -29,10 +30,10 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -50,7 +51,7 @@ func (s *Suite) TestGetNode(c *check.C) {
|
||||||
trx := db.DB.Save(node)
|
trx := db.DB.Save(node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -87,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -135,7 +136,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(user.Name, "testnode3")
|
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
_, err = db.GetNodeByID(0)
|
||||||
|
@ -189,7 +190,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
for _, name := range []string{"test", "admin"} {
|
for _, name := range []string{"test", "admin"} {
|
||||||
user, err := db.CreateUser(name)
|
user, err := db.CreateUser(name)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
stor = append(stor, base{user, pak})
|
stor = append(stor, base{user, pak})
|
||||||
}
|
}
|
||||||
|
@ -255,10 +256,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(testPeers), check.Equals, 9)
|
c.Assert(len(testPeers), check.Equals, 9)
|
||||||
|
|
||||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
|
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
|
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||||
|
@ -281,10 +282,10 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -302,7 +303,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
}
|
}
|
||||||
db.DB.Save(node)
|
db.DB.Save(node)
|
||||||
|
|
||||||
nodeFromDB, err := db.getNode("test", "testnode")
|
nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(nodeFromDB, check.NotNil)
|
c.Assert(nodeFromDB, check.NotNil)
|
||||||
|
|
||||||
|
@ -312,7 +313,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
|
||||||
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
err = db.NodeSetExpiry(nodeFromDB.ID, now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
nodeFromDB, err = db.getNode("test", "testnode")
|
nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
c.Assert(nodeFromDB.IsExpired(), check.Equals, true)
|
||||||
|
@ -322,10 +323,10 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
|
@ -348,7 +349,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
sTags := []string{"tag:test", "tag:foo"}
|
sTags := []string{"tag:test", "tag:foo"}
|
||||||
err = db.SetTags(node.ID, sTags)
|
err = db.SetTags(node.ID, sTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
|
||||||
|
|
||||||
|
@ -356,7 +357,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||||
err = db.SetTags(node.ID, eTags)
|
err = db.SetTags(node.ID, eTags)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(
|
c.Assert(
|
||||||
node.ForcedTags,
|
node.ForcedTags,
|
||||||
|
@ -367,7 +368,7 @@ func (s *Suite) TestSetTags(c *check.C) {
|
||||||
// test removing tags
|
// test removing tags
|
||||||
err = db.SetTags(node.ID, []string{})
|
err = db.SetTags(node.ID, []string{})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
node, err = db.getNode("test", "testnode")
|
node, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
|
||||||
}
|
}
|
||||||
|
@ -557,18 +558,18 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
adb, err := newTestDB()
|
adb, err := newSQLiteTestDB()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||||
|
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, pol)
|
assert.NotNil(t, pol)
|
||||||
|
|
||||||
user, err := adb.CreateUser("test")
|
user, err := adb.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
nodeKey := key.NewNode()
|
||||||
machineKey := key.NewMachine()
|
machineKey := key.NewMachine()
|
||||||
|
@ -590,21 +591,21 @@ func TestAutoApproveRoutes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
trx := adb.DB.Save(&node)
|
trx := adb.DB.Save(&node)
|
||||||
assert.NoError(t, trx.Error)
|
require.NoError(t, trx.Error)
|
||||||
|
|
||||||
sendUpdate, err := adb.SaveNodeRoutes(&node)
|
sendUpdate, err := adb.SaveNodeRoutes(&node)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, sendUpdate)
|
assert.False(t, sendUpdate)
|
||||||
|
|
||||||
node0ByID, err := adb.GetNodeByID(0)
|
node0ByID, err := adb.GetNodeByID(0)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// TODO(kradalby): Check state update
|
// TODO(kradalby): Check state update
|
||||||
err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
|
err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, enabledRoutes, len(tt.want))
|
assert.Len(t, enabledRoutes, len(tt.want))
|
||||||
|
|
||||||
tsaddr.SortPrefixes(enabledRoutes)
|
tsaddr.SortPrefixes(enabledRoutes)
|
||||||
|
@ -691,19 +692,19 @@ func generateRandomNumber(t *testing.T, max int64) int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestListEphemeralNodes(t *testing.T) {
|
func TestListEphemeralNodes(t *testing.T) {
|
||||||
db, err := newTestDB()
|
db, err := newSQLiteTestDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating db: %s", err)
|
t.Fatalf("creating db: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
|
@ -726,16 +727,16 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.DB.Save(&node).Error
|
err = db.DB.Save(&node).Error
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Save(&nodeEph).Error
|
err = db.DB.Save(&nodeEph).Error
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes, err := db.ListNodes()
|
nodes, err := db.ListNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ephemeralNodes, err := db.ListEphemeralNodes()
|
ephemeralNodes, err := db.ListEphemeralNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, nodes, 2)
|
assert.Len(t, nodes, 2)
|
||||||
assert.Len(t, ephemeralNodes, 1)
|
assert.Len(t, ephemeralNodes, 1)
|
||||||
|
@ -747,16 +748,16 @@ func TestListEphemeralNodes(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRenameNode(t *testing.T) {
|
func TestRenameNode(t *testing.T) {
|
||||||
db, err := newTestDB()
|
db, err := newSQLiteTestDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating db: %s", err)
|
t.Fatalf("creating db: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2, err := db.CreateUser("test2")
|
user2, err := db.CreateUser("test2")
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
|
@ -777,10 +778,10 @@ func TestRenameNode(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = db.DB.Save(&node).Error
|
err = db.DB.Save(&node).Error
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Save(&node2).Error
|
err = db.DB.Save(&node2).Error
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
err = db.DB.Transaction(func(tx *gorm.DB) error {
|
||||||
_, err := RegisterNode(tx, node, nil, nil)
|
_, err := RegisterNode(tx, node, nil, nil)
|
||||||
|
@ -790,10 +791,10 @@ func TestRenameNode(t *testing.T) {
|
||||||
_, err = RegisterNode(tx, node2, nil, nil)
|
_, err = RegisterNode(tx, node2, nil, nil)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes, err := db.ListNodes()
|
nodes, err := db.ListNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, nodes, 2)
|
assert.Len(t, nodes, 2)
|
||||||
|
|
||||||
|
@ -815,26 +816,26 @@ func TestRenameNode(t *testing.T) {
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[0].ID, "newname")
|
return RenameNode(tx, nodes[0].ID, "newname")
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes, err = db.ListNodes()
|
nodes, err = db.ListNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, nodes, 2)
|
assert.Len(t, nodes, 2)
|
||||||
assert.Equal(t, nodes[0].Hostname, "test")
|
assert.Equal(t, "test", nodes[0].Hostname)
|
||||||
assert.Equal(t, nodes[0].GivenName, "newname")
|
assert.Equal(t, "newname", nodes[0].GivenName)
|
||||||
|
|
||||||
// Nodes can reuse name that is no longer used
|
// Nodes can reuse name that is no longer used
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
return RenameNode(tx, nodes[1].ID, "test")
|
return RenameNode(tx, nodes[1].ID, "test")
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes, err = db.ListNodes()
|
nodes, err = db.ListNodes()
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, nodes, 2)
|
assert.Len(t, nodes, 2)
|
||||||
assert.Equal(t, nodes[0].Hostname, "test")
|
assert.Equal(t, "test", nodes[0].Hostname)
|
||||||
assert.Equal(t, nodes[0].GivenName, "newname")
|
assert.Equal(t, "newname", nodes[0].GivenName)
|
||||||
assert.Equal(t, nodes[1].GivenName, "test")
|
assert.Equal(t, "test", nodes[1].GivenName)
|
||||||
|
|
||||||
// Nodes cannot be renamed to used names
|
// Nodes cannot be renamed to used names
|
||||||
err = db.Write(func(tx *gorm.DB) error {
|
err = db.Write(func(tx *gorm.DB) error {
|
||||||
|
|
|
@ -23,29 +23,27 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||||
// TODO(kradalby): Should be ID, not name
|
uid types.UserID,
|
||||||
userName string,
|
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) {
|
||||||
return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags)
|
return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||||
func CreatePreAuthKey(
|
func CreatePreAuthKey(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
// TODO(kradalby): Should be ID, not name
|
uid types.UserID,
|
||||||
userName string,
|
|
||||||
reusable bool,
|
reusable bool,
|
||||||
ephemeral bool,
|
ephemeral bool,
|
||||||
expiration *time.Time,
|
expiration *time.Time,
|
||||||
aclTags []string,
|
aclTags []string,
|
||||||
) (*types.PreAuthKey, error) {
|
) (*types.PreAuthKey, error) {
|
||||||
user, err := GetUserByUsername(tx, userName)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -89,15 +87,15 @@ func CreatePreAuthKey(
|
||||||
return &key, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) {
|
||||||
return ListPreAuthKeys(rx, userName)
|
return ListPreAuthKeysByUser(rx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user.
|
||||||
func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) {
|
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
|
||||||
user, err := GetUserByUsername(tx, userName)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,14 +11,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
// ID does not exist
|
||||||
|
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
c.Assert(len(key.Key), check.Equals, 48)
|
c.Assert(len(key.Key), check.Equals, 48)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
c.Assert(key.User.ID, check.Equals, user.ID)
|
||||||
|
|
||||||
_, err = db.ListPreAuthKeys("bogus")
|
// ID does not exist
|
||||||
|
_, err = db.ListPreAuthKeys(1000000)
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
keys, err := db.ListPreAuthKeys(user.Name)
|
keys, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
// Make sure the User association is populated
|
// Make sure the User association is populated
|
||||||
c.Assert((keys)[0].User.Name, check.Equals, user.Name)
|
c.Assert((keys)[0].User.ID, check.Equals, user.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
|
@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now().Add(-5 * time.Second)
|
now := time.Now().Add(-5 * time.Second)
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -62,7 +63,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -74,7 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test4")
|
user, err := db.CreateUser("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -96,7 +97,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test5")
|
user, err := db.CreateUser("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -118,7 +119,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||||
|
@ -130,7 +131,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
user, err := db.CreateUser("test3")
|
user, err := db.CreateUser("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(pak.Expiration, check.IsNil)
|
c.Assert(pak.Expiration, check.IsNil)
|
||||||
|
|
||||||
|
@ -147,7 +148,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||||
user, err := db.CreateUser("test6")
|
user, err := db.CreateUser("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
pak.Used = true
|
pak.Used = true
|
||||||
db.DB.Save(&pak)
|
db.DB.Save(&pak)
|
||||||
|
@ -160,15 +161,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||||
user, err := db.CreateUser("test8")
|
user, err := db.CreateUser("test8")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
|
||||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||||
|
|
||||||
tags := []string{"tag:test1", "tag:test2"}
|
tags := []string{"tag:test1", "tag:test2"}
|
||||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
listedPaks, err := db.ListPreAuthKeys("test8")
|
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||||
sort.Sort(sort.StringSlice(gotTags))
|
sort.Sort(sort.StringSlice(gotTags))
|
||||||
|
|
|
@ -639,7 +639,7 @@ func EnableAutoApprovedRoutes(
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Str("user", node.User.Name).
|
Uint("user.id", node.User.ID).
|
||||||
Strs("routeApprovers", routeApprovers).
|
Strs("routeApprovers", routeApprovers).
|
||||||
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
Str("prefix", netip.Prefix(advertisedRoute.Prefix).String()).
|
||||||
Msg("looking up route for autoapproving")
|
Msg("looking up route for autoapproving")
|
||||||
|
@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
|
||||||
if approvedAlias == node.User.Username() {
|
if approvedAlias == node.User.Username() {
|
||||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
} else {
|
} else {
|
||||||
|
users, err := ListUsers(tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("looking up users to expand route alias: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||||
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
|
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_get_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_get_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||||
|
@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netip.ParsePrefix(
|
route, err := netip.ParsePrefix(
|
||||||
|
@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode("test", "test_enable_route_node")
|
_, err = db.getNode(types.UserID(user.ID), "test_enable_route_node")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
prefix, err := netip.ParsePrefix(
|
prefix, err := netip.ParsePrefix(
|
||||||
|
|
|
@ -1,12 +1,17 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
"zombiezen.com/go/postgrestest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test(t *testing.T) {
|
func Test(t *testing.T) {
|
||||||
|
@ -36,13 +41,15 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
db, err = newTestDB()
|
db, err = newSQLiteTestDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestDB() (*HSDatabase, error) {
|
// TODO(kradalby): make this a t.Helper when we dont depend
|
||||||
|
// on check test framework.
|
||||||
|
func newSQLiteTestDB() (*HSDatabase, error) {
|
||||||
var err error
|
var err error
|
||||||
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
|
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,7 +60,7 @@ func newTestDB() (*HSDatabase, error) {
|
||||||
|
|
||||||
db, err = NewHeadscaleDatabase(
|
db, err = NewHeadscaleDatabase(
|
||||||
types.DatabaseConfig{
|
types.DatabaseConfig{
|
||||||
Type: "sqlite3",
|
Type: types.DatabaseSqlite,
|
||||||
Sqlite: types.SqliteConfig{
|
Sqlite: types.SqliteConfig{
|
||||||
Path: tmpDir + "/headscale_test.db",
|
Path: tmpDir + "/headscale_test.db",
|
||||||
},
|
},
|
||||||
|
@ -67,3 +74,53 @@ func newTestDB() (*HSDatabase, error) {
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newPostgresTestDB(t *testing.T) *HSDatabase {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
tmpDir, err = os.MkdirTemp("", "headscale-db-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("database path: %s", tmpDir+"/headscale_test.db")
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
srv, err := postgrestest.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(srv.Cleanup)
|
||||||
|
|
||||||
|
u, err := srv.CreateDatabase(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("created local postgres: %s", u)
|
||||||
|
pu, _ := url.Parse(u)
|
||||||
|
|
||||||
|
pass, _ := pu.User.Password()
|
||||||
|
port, _ := strconv.Atoi(pu.Port())
|
||||||
|
|
||||||
|
db, err = NewHeadscaleDatabase(
|
||||||
|
types.DatabaseConfig{
|
||||||
|
Type: types.DatabasePostgres,
|
||||||
|
Postgres: types.PostgresConfig{
|
||||||
|
Host: pu.Hostname(),
|
||||||
|
User: pu.User.Username(),
|
||||||
|
Name: strings.TrimLeft(pu.Path, "/"),
|
||||||
|
Pass: pass,
|
||||||
|
Port: port,
|
||||||
|
Ssl: "disable",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
emptyCache(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
|
@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
user := types.User{}
|
user := types.User{
|
||||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
Name: name,
|
||||||
return nil, ErrUserExists
|
|
||||||
}
|
}
|
||||||
user.Name = name
|
|
||||||
if err := tx.Create(&user).Error; err != nil {
|
if err := tx.Create(&user).Error; err != nil {
|
||||||
return nil, fmt.Errorf("creating user: %w", err)
|
return nil, fmt.Errorf("creating user: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -40,21 +38,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return DestroyUser(tx, name)
|
return DestroyUser(tx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DestroyUser destroys a User. Returns error if the User does
|
// DestroyUser destroys a User. Returns error if the User does
|
||||||
// not exist or if there are nodes associated with it.
|
// not exist or if there are nodes associated with it.
|
||||||
func DestroyUser(tx *gorm.DB, name string) error {
|
func DestroyUser(tx *gorm.DB, uid types.UserID) error {
|
||||||
user, err := GetUserByUsername(tx, name)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrUserNotFound
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes, err := ListNodesByUser(tx, name)
|
nodes, err := ListNodesByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -62,7 +60,7 @@ func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
return ErrUserStillHasNodes
|
return ErrUserStillHasNodes
|
||||||
}
|
}
|
||||||
|
|
||||||
keys, err := ListPreAuthKeys(tx, name)
|
keys, err := ListPreAuthKeysByUser(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -80,17 +78,17 @@ func DestroyUser(tx *gorm.DB, name string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return RenameUser(tx, oldName, newName)
|
return RenameUser(tx, uid, newName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenameUser renames a User. Returns error if the User does
|
// RenameUser renames a User. Returns error if the User does
|
||||||
// not exist or if another User exists with the new name.
|
// not exist or if another User exists with the new name.
|
||||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||||
var err error
|
var err error
|
||||||
oldUser, err := GetUserByUsername(tx, oldName)
|
oldUser, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -98,50 +96,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = GetUserByUsername(tx, newName)
|
|
||||||
if err == nil {
|
|
||||||
return ErrUserExists
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrUserNotFound) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUser.Name = newName
|
oldUser.Name = newName
|
||||||
|
|
||||||
if result := tx.Save(&oldUser); result.Error != nil {
|
if err := tx.Save(&oldUser).Error; err != nil {
|
||||||
return result.Error
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||||
return GetUserByUsername(rx, name)
|
return GetUserByID(rx, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) {
|
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
||||||
user := types.User{}
|
user := types.User{}
|
||||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
if result := tx.First(&user, "id = ?", uid); errors.Is(
|
||||||
result.Error,
|
|
||||||
gorm.ErrRecordNotFound,
|
|
||||||
) {
|
|
||||||
return nil, ErrUserNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return &user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) {
|
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
|
||||||
return GetUserByID(rx, id)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) {
|
|
||||||
user := types.User{}
|
|
||||||
if result := tx.First(&user, "id = ?", id); errors.Is(
|
|
||||||
result.Error,
|
result.Error,
|
||||||
gorm.ErrRecordNotFound,
|
gorm.ErrRecordNotFound,
|
||||||
) {
|
) {
|
||||||
|
@ -169,54 +142,69 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||||
return ListUsers(rx)
|
return ListUsers(rx, where...)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUsers gets all the existing users.
|
// ListUsers gets all the existing users.
|
||||||
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||||
|
if len(where) > 1 {
|
||||||
|
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||||
|
}
|
||||||
|
|
||||||
|
var user *types.User
|
||||||
|
if len(where) == 1 {
|
||||||
|
user = where[0]
|
||||||
|
}
|
||||||
|
|
||||||
users := []types.User{}
|
users := []types.User{}
|
||||||
if err := tx.Find(&users).Error; err != nil {
|
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNodesByUser gets all the nodes in a given user.
|
// GetUserByName returns a user if the provided username is
|
||||||
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
// unique, and otherwise an error.
|
||||||
err := util.CheckForFQDNRules(name)
|
func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
||||||
if err != nil {
|
users, err := hsdb.ListUsers(&types.User{Name: name})
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
user, err := GetUserByUsername(tx, name)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
return nil, ErrUserNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) != 1 {
|
||||||
|
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListNodesByUser gets all the nodes in a given user.
|
||||||
|
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
|
||||||
nodes := types.Nodes{}
|
nodes := types.Nodes{}
|
||||||
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nodes, nil
|
return nodes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return AssignNodeToUser(tx, node, username)
|
return AssignNodeToUser(tx, node, uid)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignNodeToUser assigns a Node to a user.
|
// AssignNodeToUser assigns a Node to a user.
|
||||||
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error {
|
||||||
err := util.CheckForFQDNRules(username)
|
user, err := GetUserByID(tx, uid)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
user, err := GetUserByUsername(tx, username)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
|
@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test")
|
_, err = db.GetUserByID(types.UserID(user.ID))
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
err := db.DestroyUser("test")
|
err := db.DestroyUser(9998)
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||||
|
@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
user, err = db.CreateUser("test")
|
user, err = db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||||
trx := db.DB.Save(&node)
|
trx := db.DB.Save(&node)
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
|
|
||||||
err = db.DestroyUser("test")
|
err = db.DestroyUser(types.UserID(user.ID))
|
||||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,24 +72,29 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(users), check.Equals, 1)
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.RenameUser("test", "test-renamed")
|
err = db.RenameUser(types.UserID(userTest.ID), "test-renamed")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test")
|
users, err = db.ListUsers(&types.User{Name: "test"})
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, nil)
|
||||||
|
c.Assert(len(users), check.Equals, 0)
|
||||||
|
|
||||||
_, err = db.GetUserByName("test-renamed")
|
users, err = db.ListUsers(&types.User{Name: "test-renamed"})
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(users), check.Equals, 1)
|
||||||
|
|
||||||
err = db.RenameUser("test-does-not-exit", "test")
|
err = db.RenameUser(99988, "test")
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
userTest2, err := db.CreateUser("test2")
|
userTest2, err := db.CreateUser("test2")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||||
|
|
||||||
err = db.RenameUser("test2", "test-renamed")
|
want := "UNIQUE constraint failed"
|
||||||
c.Assert(err, check.Equals, ErrUserExists)
|
err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed")
|
||||||
|
if err == nil || !strings.Contains(err.Error(), want) {
|
||||||
|
c.Fatalf("expected failure with unique constraint, want: %q got: %q", want, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
|
@ -97,7 +104,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
newUser, err := db.CreateUser("new")
|
newUser, err := db.CreateUser("new")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
node := types.Node{
|
node := types.Node{
|
||||||
|
@ -111,15 +118,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||||
c.Assert(trx.Error, check.IsNil)
|
c.Assert(trx.Error, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
c.Assert(node.UserID, check.Equals, oldUser.ID)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, "non-existing-user")
|
err = db.AssignNodeToUser(&node, 9584849)
|
||||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||||
|
|
||||||
err = db.AssignNodeToUser(&node, newUser.Name)
|
err = db.AssignNodeToUser(&node, types.UserID(newUser.ID))
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(node.UserID, check.Equals, newUser.ID)
|
c.Assert(node.UserID, check.Equals, newUser.ID)
|
||||||
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
c.Assert(node.User.Name, check.Equals, newUser.Name)
|
||||||
|
|
|
@ -65,24 +65,34 @@ func (api headscaleV1APIServer) RenameUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameUserRequest,
|
request *v1.RenameUserRequest,
|
||||||
) (*v1.RenameUserResponse, error) {
|
) (*v1.RenameUserResponse, error) {
|
||||||
err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName())
|
oldUser, err := api.h.db.GetUserByName(request.GetOldName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := api.h.db.GetUserByName(request.GetNewName())
|
err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &v1.RenameUserResponse{User: user.Proto()}, nil
|
newUser, err := api.h.db.GetUserByName(request.GetNewName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &v1.RenameUserResponse{User: newUser.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api headscaleV1APIServer) DeleteUser(
|
func (api headscaleV1APIServer) DeleteUser(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteUserRequest,
|
request *v1.DeleteUserRequest,
|
||||||
) (*v1.DeleteUserResponse, error) {
|
) (*v1.DeleteUserResponse, error) {
|
||||||
err := api.h.db.DestroyUser(request.GetName())
|
user, err := api.h.db.GetUserByName(request.GetName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = api.h.db.DestroyUser(types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -131,8 +141,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
preAuthKey, err := api.h.db.CreatePreAuthKey(
|
||||||
request.GetUser(),
|
types.UserID(user.ID),
|
||||||
request.GetReusable(),
|
request.GetReusable(),
|
||||||
request.GetEphemeral(),
|
request.GetEphemeral(),
|
||||||
&expiration,
|
&expiration,
|
||||||
|
@ -168,7 +183,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListPreAuthKeysRequest,
|
request *v1.ListPreAuthKeysRequest,
|
||||||
) (*v1.ListPreAuthKeysResponse, error) {
|
) (*v1.ListPreAuthKeysResponse, error) {
|
||||||
preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser())
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -406,10 +426,20 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListNodesRequest,
|
request *v1.ListNodesRequest,
|
||||||
) (*v1.ListNodesResponse, error) {
|
) (*v1.ListNodesResponse, error) {
|
||||||
|
// TODO(kradalby): it looks like this can be simplified a lot,
|
||||||
|
// the filtering of nodes by user, vs nodes as a whole can
|
||||||
|
// probably be done once.
|
||||||
|
// TODO(kradalby): This should be done in one tx.
|
||||||
|
|
||||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return db.ListNodesByUser(rx, request.GetUser())
|
return db.ListNodesByUser(rx, types.UserID(user.ID))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -465,12 +495,18 @@ func (api headscaleV1APIServer) MoveNode(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveNodeRequest,
|
request *v1.MoveNodeRequest,
|
||||||
) (*v1.MoveNodeResponse, error) {
|
) (*v1.MoveNodeResponse, error) {
|
||||||
|
// TODO(kradalby): This should be done in one tx.
|
||||||
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = api.h.db.AssignNodeToUser(node, request.GetUser())
|
user, err := api.h.db.GetUserByName(request.GetUser())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -737,14 +773,18 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||||
}
|
}
|
||||||
|
users, err := api.h.db.ListUsers()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = pol.CompileFilterRules(nodes)
|
_, err = pol.CompileFilterRules(users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("verifying policy rules: %w", err)
|
return nil, fmt.Errorf("verifying policy rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if len(nodes) > 0 {
|
||||||
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
|
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||||
func (m *Mapper) fullMapResponse(
|
func (m *Mapper) fullMapResponse(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
|
users []types.User,
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
|
@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
|
||||||
pol,
|
pol,
|
||||||
node,
|
node,
|
||||||
capVer,
|
capVer,
|
||||||
|
users,
|
||||||
peers,
|
peers,
|
||||||
peers,
|
peers,
|
||||||
m.cfg,
|
m.cfg,
|
||||||
|
@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
users, err := m.db.ListUsers()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
|
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
users, err := m.db.ListUsers()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listing users for map response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
var removedIDs []tailcfg.NodeID
|
var removedIDs []tailcfg.NodeID
|
||||||
var changedIDs []types.NodeID
|
var changedIDs []types.NodeID
|
||||||
for nodeID, nodeChanged := range changed {
|
for nodeID, nodeChanged := range changed {
|
||||||
|
@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
|
||||||
pol,
|
pol,
|
||||||
node,
|
node,
|
||||||
mapRequest.Version,
|
mapRequest.Version,
|
||||||
|
users,
|
||||||
peers,
|
peers,
|
||||||
changedNodes,
|
changedNodes,
|
||||||
m.cfg,
|
m.cfg,
|
||||||
|
@ -508,16 +520,17 @@ func appendPeerChanges(
|
||||||
pol *policy.ACLPolicy,
|
pol *policy.ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
|
users []types.User,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
changed types.Nodes,
|
changed types.Nodes,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
) error {
|
) error {
|
||||||
packetFilter, err := pol.CompileFilterRules(append(peers, node))
|
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
|
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
|
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
|
||||||
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
|
||||||
|
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
|
||||||
|
|
||||||
mini := &types.Node{
|
mini := &types.Node{
|
||||||
ID: 0,
|
ID: 0,
|
||||||
MachineKey: mustMK(
|
MachineKey: mustMK(
|
||||||
|
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
IPv4: iap("100.64.0.1"),
|
IPv4: iap("100.64.0.1"),
|
||||||
Hostname: "mini",
|
Hostname: "mini",
|
||||||
GivenName: "mini",
|
GivenName: "mini",
|
||||||
UserID: 0,
|
UserID: user1.ID,
|
||||||
User: types.User{Name: "mini"},
|
User: user1,
|
||||||
ForcedTags: []string{},
|
ForcedTags: []string{},
|
||||||
AuthKey: &types.PreAuthKey{},
|
AuthKey: &types.PreAuthKey{},
|
||||||
LastSeen: &lastSeen,
|
LastSeen: &lastSeen,
|
||||||
|
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
IPv4: iap("100.64.0.2"),
|
IPv4: iap("100.64.0.2"),
|
||||||
Hostname: "peer1",
|
Hostname: "peer1",
|
||||||
GivenName: "peer1",
|
GivenName: "peer1",
|
||||||
UserID: 0,
|
UserID: user1.ID,
|
||||||
User: types.User{Name: "mini"},
|
User: user1,
|
||||||
ForcedTags: []string{},
|
ForcedTags: []string{},
|
||||||
LastSeen: &lastSeen,
|
LastSeen: &lastSeen,
|
||||||
Expiry: &expire,
|
Expiry: &expire,
|
||||||
|
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
IPv4: iap("100.64.0.3"),
|
IPv4: iap("100.64.0.3"),
|
||||||
Hostname: "peer2",
|
Hostname: "peer2",
|
||||||
GivenName: "peer2",
|
GivenName: "peer2",
|
||||||
UserID: 1,
|
UserID: user2.ID,
|
||||||
User: types.User{Name: "peer2"},
|
User: user2,
|
||||||
ForcedTags: []string{},
|
ForcedTags: []string{},
|
||||||
LastSeen: &lastSeen,
|
LastSeen: &lastSeen,
|
||||||
Expiry: &expire,
|
Expiry: &expire,
|
||||||
|
@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||||
got, err := mappy.fullMapResponse(
|
got, err := mappy.fullMapResponse(
|
||||||
tt.node,
|
tt.node,
|
||||||
tt.peers,
|
tt.peers,
|
||||||
|
[]types.User{user1, user2},
|
||||||
tt.pol,
|
tt.pol,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
|
@ -436,24 +436,41 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||||
) (*types.User, error) {
|
) (*types.User, error) {
|
||||||
var user *types.User
|
var user *types.User
|
||||||
var err error
|
var err error
|
||||||
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
|
user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
|
||||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
// This check is for legacy, if the user cannot be found by the OIDC identifier
|
||||||
// look it up by username. This should only be needed once.
|
// look it up by username. This should only be needed once.
|
||||||
if user == nil {
|
// This branch will presist for a number of versions after the OIDC migration and
|
||||||
user, err = a.db.GetUserByName(claims.Username)
|
// then be removed following a deprecation.
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain and migration is removed
|
||||||
|
// after #2170 is cleaned up.
|
||||||
|
if a.cfg.MapLegacyUsers && user == nil {
|
||||||
|
log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username")
|
||||||
|
if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil {
|
||||||
|
log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username")
|
||||||
|
user, err = a.db.GetUserByName(oldUsername)
|
||||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
return nil, fmt.Errorf("getting user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user exists, but it already has a provider identifier (OIDC sub), create a new user.
|
||||||
|
// This is to prevent users that have already been migrated to the new OIDC format
|
||||||
|
// to be updated with the new OIDC identifier inexplicitly which might be the cause of an
|
||||||
|
// account takeover.
|
||||||
|
if user != nil && user.ProviderIdentifier.Valid {
|
||||||
|
log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.")
|
||||||
|
user = &types.User{}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the user is still not found, create a new empty user.
|
// if the user is still not found, create a new empty user.
|
||||||
if user == nil {
|
if user == nil {
|
||||||
user = &types.User{}
|
user = &types.User{}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
user.FromClaim(claims)
|
user.FromClaim(claims)
|
||||||
err = a.db.DB.Save(user).Error
|
err = a.db.DB.Save(user).Error
|
||||||
|
@ -488,7 +505,7 @@ func (a *AuthProviderOIDC) registerNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby):
|
// TODO(kradalby):
|
||||||
// Rewrite in elem-go
|
// Rewrite in elem-go.
|
||||||
func renderOIDCCallbackTemplate(
|
func renderOIDCCallbackTemplate(
|
||||||
user *types.User,
|
user *types.User,
|
||||||
) (*bytes.Buffer, error) {
|
) (*bytes.Buffer, error) {
|
||||||
|
@ -502,3 +519,24 @@ func renderOIDCCallbackTemplate(
|
||||||
|
|
||||||
return &content, nil
|
return &content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// DEPRECATED: DO NOT USE
|
||||||
|
func getUserName(
|
||||||
|
claims *types.OIDCClaims,
|
||||||
|
stripEmaildomain bool,
|
||||||
|
) (string, error) {
|
||||||
|
if !claims.EmailVerified {
|
||||||
|
return "", fmt.Errorf("email not verified")
|
||||||
|
}
|
||||||
|
userName, err := util.NormalizeToFQDNRules(
|
||||||
|
claims.Email,
|
||||||
|
stripEmaildomain,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return userName, nil
|
||||||
|
}
|
||||||
|
|
|
@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
|
||||||
policy *ACLPolicy,
|
policy *ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
|
users []types.User,
|
||||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||||
// If there is no policy defined, we default to allow all
|
// If there is no policy defined, we default to allow all
|
||||||
if policy == nil {
|
if policy == nil {
|
||||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := policy.CompileFilterRules(append(peers, node))
|
rules, err := policy.CompileFilterRules(users, append(peers, node))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
||||||
|
|
||||||
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
|
sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
|
||||||
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
func (pol *ACLPolicy) CompileFilterRules(
|
func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
if pol == nil {
|
if pol == nil {
|
||||||
|
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
|
||||||
var srcIPs []string
|
var srcIPs []string
|
||||||
for srcIndex, src := range acl.Sources {
|
for srcIndex, src := range acl.Sources {
|
||||||
srcs, err := pol.expandSource(src, nodes)
|
srcs, err := pol.expandSource(src, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing policy, acl index: %d->%d: %w",
|
||||||
|
index,
|
||||||
|
srcIndex,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
srcIPs = append(srcIPs, srcs...)
|
srcIPs = append(srcIPs, srcs...)
|
||||||
}
|
}
|
||||||
|
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
|
||||||
expanded, err := pol.ExpandAlias(
|
expanded, err := pol.ExpandAlias(
|
||||||
nodes,
|
nodes,
|
||||||
|
users,
|
||||||
alias,
|
alias,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
||||||
|
|
||||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
|
users []types.User,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
) (*tailcfg.SSHPolicy, error) {
|
) (*tailcfg.SSHPolicy, error) {
|
||||||
if pol == nil {
|
if pol == nil {
|
||||||
|
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
for index, sshACL := range pol.SSHs {
|
for index, sshACL := range pol.SSHs {
|
||||||
var dest netipx.IPSetBuilder
|
var dest netipx.IPSetBuilder
|
||||||
for _, src := range sshACL.Destinations {
|
for _, src := range sshACL.Destinations {
|
||||||
expanded, err := pol.ExpandAlias(append(peers, node), src)
|
expanded, err := pol.ExpandAlias(append(peers, node), users, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
case "check":
|
case "check":
|
||||||
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, parsing check duration, index: %d: %w",
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
action = *checkAction
|
action = *checkAction
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, unknown action %q, index: %d: %w",
|
||||||
|
sshACL.Action,
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
||||||
|
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
} else {
|
} else {
|
||||||
expandedSrcs, err := pol.ExpandAlias(
|
expandedSrcs, err := pol.ExpandAlias(
|
||||||
peers,
|
peers,
|
||||||
|
users,
|
||||||
rawSrc,
|
rawSrc,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||||
// with the given src alias.
|
// with the given src alias.
|
||||||
func (pol *ACLPolicy) expandSource(
|
func (pol *ACLPolicy) expandSource(
|
||||||
src string,
|
src string,
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
ipSet, err := pol.ExpandAlias(nodes, src)
|
ipSet, err := pol.ExpandAlias(nodes, users, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
}
|
}
|
||||||
|
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
|
||||||
// and transform these in IPAddresses.
|
// and transform these in IPAddresses.
|
||||||
func (pol *ACLPolicy) ExpandAlias(
|
func (pol *ACLPolicy) ExpandAlias(
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
|
users []types.User,
|
||||||
alias string,
|
alias string,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
if isWildcard(alias) {
|
if isWildcard(alias) {
|
||||||
|
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
|
|
||||||
// if alias is a group
|
// if alias is a group
|
||||||
if isGroup(alias) {
|
if isGroup(alias) {
|
||||||
return pol.expandIPsFromGroup(alias, nodes)
|
return pol.expandIPsFromGroup(alias, users, nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is a tag
|
// if alias is a tag
|
||||||
if isTag(alias) {
|
if isTag(alias) {
|
||||||
return pol.expandIPsFromTag(alias, nodes)
|
return pol.expandIPsFromTag(alias, users, nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isAutoGroup(alias) {
|
if isAutoGroup(alias) {
|
||||||
|
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is a user
|
// if alias is a user
|
||||||
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil {
|
if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
|
||||||
return ips, err
|
return ips, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
if h, ok := pol.Hosts[alias]; ok {
|
if h, ok := pol.Hosts[alias]; ok {
|
||||||
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||||
|
|
||||||
return pol.ExpandAlias(nodes, h.String())
|
return pol.ExpandAlias(nodes, users, h.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// if alias is an IP
|
// if alias is an IP
|
||||||
|
@ -599,7 +620,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||||
// TODO(kradalby): It is quite hard to understand what this function is doing,
|
// TODO(kradalby): It is quite hard to understand what this function is doing,
|
||||||
// it seems like it trying to ensure that we dont include nodes that are tagged
|
// it seems like it trying to ensure that we dont include nodes that are tagged
|
||||||
// when we look up the nodes owned by a user.
|
// when we look up the nodes owned by a user.
|
||||||
// This should be refactored to be more clear as part of the Tags work in #1369
|
// This should be refactored to be more clear as part of the Tags work in #1369.
|
||||||
func excludeCorrectlyTaggedNodes(
|
func excludeCorrectlyTaggedNodes(
|
||||||
aclPolicy *ACLPolicy,
|
aclPolicy *ACLPolicy,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
|
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
|
||||||
|
|
||||||
func (pol *ACLPolicy) expandIPsFromGroup(
|
func (pol *ACLPolicy) expandIPsFromGroup(
|
||||||
group string,
|
group string,
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
var build netipx.IPSetBuilder
|
var build netipx.IPSetBuilder
|
||||||
|
|
||||||
users, err := pol.expandUsersFromGroup(group)
|
userTokens, err := pol.expandUsersFromGroup(group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &netipx.IPSet{}, err
|
return &netipx.IPSet{}, err
|
||||||
}
|
}
|
||||||
for _, user := range users {
|
for _, user := range userTokens {
|
||||||
filteredNodes := filterNodesByUser(nodes, user)
|
filteredNodes := filterNodesByUser(nodes, users, user)
|
||||||
for _, node := range filteredNodes {
|
for _, node := range filteredNodes {
|
||||||
node.AppendToIPSet(&build)
|
node.AppendToIPSet(&build)
|
||||||
}
|
}
|
||||||
|
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
|
||||||
|
|
||||||
func (pol *ACLPolicy) expandIPsFromTag(
|
func (pol *ACLPolicy) expandIPsFromTag(
|
||||||
alias string,
|
alias string,
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
var build netipx.IPSetBuilder
|
var build netipx.IPSetBuilder
|
||||||
|
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
||||||
|
|
||||||
// filter out nodes per tag owner
|
// filter out nodes per tag owner
|
||||||
for _, user := range owners {
|
for _, user := range owners {
|
||||||
nodes := filterNodesByUser(nodes, user)
|
nodes := filterNodesByUser(nodes, users, user)
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.Hostinfo == nil {
|
if node.Hostinfo == nil {
|
||||||
continue
|
continue
|
||||||
|
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
|
||||||
|
|
||||||
func (pol *ACLPolicy) expandIPsFromUser(
|
func (pol *ACLPolicy) expandIPsFromUser(
|
||||||
user string,
|
user string,
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) (*netipx.IPSet, error) {
|
) (*netipx.IPSet, error) {
|
||||||
var build netipx.IPSetBuilder
|
var build netipx.IPSetBuilder
|
||||||
|
|
||||||
filteredNodes := filterNodesByUser(nodes, user)
|
filteredNodes := filterNodesByUser(nodes, users, user)
|
||||||
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
|
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
|
||||||
|
|
||||||
// shortcurcuit if we have no nodes to get ips from.
|
// shortcurcuit if we have no nodes to get ips from.
|
||||||
|
@ -953,10 +977,43 @@ func (pol *ACLPolicy) TagsOfNode(
|
||||||
return validTags, invalidTags
|
return validTags, invalidTags
|
||||||
}
|
}
|
||||||
|
|
||||||
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
|
// filterNodesByUser returns a list of nodes that match the given userToken from a
|
||||||
|
// policy.
|
||||||
|
// Matching nodes are determined by first matching the user token to a user by checking:
|
||||||
|
// - If it is an ID that mactches the user database ID
|
||||||
|
// - It is the Provider Identifier from OIDC
|
||||||
|
// - It matches the username or email of a user
|
||||||
|
//
|
||||||
|
// If the token matches more than one user, zero nodes will returned.
|
||||||
|
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
|
||||||
var out types.Nodes
|
var out types.Nodes
|
||||||
|
|
||||||
|
var potentialUsers []types.User
|
||||||
|
for _, user := range users {
|
||||||
|
if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == userToken {
|
||||||
|
// If a user is matching with a known unique field,
|
||||||
|
// disgard all other users and only keep the current
|
||||||
|
// user.
|
||||||
|
potentialUsers = []types.User{user}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if user.Email == userToken {
|
||||||
|
potentialUsers = append(potentialUsers, user)
|
||||||
|
}
|
||||||
|
if user.Name == userToken {
|
||||||
|
potentialUsers = append(potentialUsers, user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(potentialUsers) != 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
user := potentialUsers[0]
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.User.Username() == user {
|
if node.User.ID == user.ID {
|
||||||
out = append(out, node)
|
out = append(out, node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -977,10 +1034,7 @@ func FilterNodesByACL(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Checking if %s can access %s", node.Hostname, peer.Hostname)
|
|
||||||
|
|
||||||
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
|
||||||
log.Printf("CAN ACCESS %s can access %s", node.Hostname, peer.Hostname)
|
|
||||||
result = append(result, peer)
|
result = append(result, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -13,7 +13,7 @@ func Windows(url string) *elem.Element {
|
||||||
elem.Text("headscale - Windows"),
|
elem.Text("headscale - Windows"),
|
||||||
),
|
),
|
||||||
elem.Body(attrs.Props{
|
elem.Body(attrs.Props{
|
||||||
attrs.Style : bodyStyle.ToInline(),
|
attrs.Style: bodyStyle.ToInline(),
|
||||||
},
|
},
|
||||||
headerOne("headscale: Windows configuration"),
|
headerOne("headscale: Windows configuration"),
|
||||||
elem.P(nil,
|
elem.P(nil,
|
||||||
|
@ -21,7 +21,8 @@ func Windows(url string) *elem.Element {
|
||||||
elem.A(attrs.Props{
|
elem.A(attrs.Props{
|
||||||
attrs.Href: "https://tailscale.com/download/windows",
|
attrs.Href: "https://tailscale.com/download/windows",
|
||||||
attrs.Rel: "noreferrer noopener",
|
attrs.Rel: "noreferrer noopener",
|
||||||
attrs.Target: "_blank"},
|
attrs.Target: "_blank",
|
||||||
|
},
|
||||||
elem.Text("Tailscale for Windows ")),
|
elem.Text("Tailscale for Windows ")),
|
||||||
elem.Text("and install it."),
|
elem.Text("and install it."),
|
||||||
),
|
),
|
||||||
|
|
|
@ -105,6 +105,7 @@ type Nameservers struct {
|
||||||
type SqliteConfig struct {
|
type SqliteConfig struct {
|
||||||
Path string
|
Path string
|
||||||
WriteAheadLog bool
|
WriteAheadLog bool
|
||||||
|
WALAutoCheckPoint int
|
||||||
}
|
}
|
||||||
|
|
||||||
type PostgresConfig struct {
|
type PostgresConfig struct {
|
||||||
|
@ -163,8 +164,10 @@ type OIDCConfig struct {
|
||||||
AllowedDomains []string
|
AllowedDomains []string
|
||||||
AllowedUsers []string
|
AllowedUsers []string
|
||||||
AllowedGroups []string
|
AllowedGroups []string
|
||||||
|
StripEmaildomain bool
|
||||||
Expiry time.Duration
|
Expiry time.Duration
|
||||||
UseExpiryFromToken bool
|
UseExpiryFromToken bool
|
||||||
|
MapLegacyUsers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type DERPConfig struct {
|
type DERPConfig struct {
|
||||||
|
@ -271,11 +274,14 @@ func LoadConfig(path string, isFile bool) error {
|
||||||
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
||||||
|
|
||||||
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
||||||
|
viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default
|
||||||
|
|
||||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||||
|
viper.SetDefault("oidc.strip_email_domain", true)
|
||||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||||
viper.SetDefault("oidc.expiry", "180d")
|
viper.SetDefault("oidc.expiry", "180d")
|
||||||
viper.SetDefault("oidc.use_expiry_from_token", false)
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
||||||
|
viper.SetDefault("oidc.map_legacy_users", true)
|
||||||
|
|
||||||
viper.SetDefault("logtail.enabled", false)
|
viper.SetDefault("logtail.enabled", false)
|
||||||
viper.SetDefault("randomize_client_port", false)
|
viper.SetDefault("randomize_client_port", false)
|
||||||
|
@ -319,14 +325,18 @@ func validateServerConfig() error {
|
||||||
depr.warn("dns_config.use_username_in_magic_dns")
|
depr.warn("dns_config.use_username_in_magic_dns")
|
||||||
depr.warn("dns.use_username_in_magic_dns")
|
depr.warn("dns.use_username_in_magic_dns")
|
||||||
|
|
||||||
depr.fatal("oidc.strip_email_domain")
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// depr.fatal("oidc.strip_email_domain")
|
||||||
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
||||||
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
||||||
|
|
||||||
depr.Log()
|
depr.Log()
|
||||||
|
|
||||||
for _, removed := range []string{
|
for _, removed := range []string{
|
||||||
"oidc.strip_email_domain",
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// "oidc.strip_email_domain",
|
||||||
"dns_config.use_username_in_musername_in_magic_dns",
|
"dns_config.use_username_in_musername_in_magic_dns",
|
||||||
} {
|
} {
|
||||||
if viper.IsSet(removed) {
|
if viper.IsSet(removed) {
|
||||||
|
@ -544,6 +554,7 @@ func databaseConfig() DatabaseConfig {
|
||||||
viper.GetString("database.sqlite.path"),
|
viper.GetString("database.sqlite.path"),
|
||||||
),
|
),
|
||||||
WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"),
|
WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"),
|
||||||
|
WALAutoCheckPoint: viper.GetInt("database.sqlite.wal_autocheckpoint"),
|
||||||
},
|
},
|
||||||
Postgres: PostgresConfig{
|
Postgres: PostgresConfig{
|
||||||
Host: viper.GetString("database.postgres.host"),
|
Host: viper.GetString("database.postgres.host"),
|
||||||
|
@ -897,6 +908,10 @@ func LoadServerConfig() (*Config, error) {
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
|
||||||
|
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
|
||||||
},
|
},
|
||||||
|
|
||||||
LogTail: logTailConfig,
|
LogTail: logTailConfig,
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
)
|
)
|
||||||
|
@ -36,8 +37,17 @@ func TestReadConfig(t *testing.T) {
|
||||||
MagicDNS: true,
|
MagicDNS: true,
|
||||||
BaseDomain: "example.com",
|
BaseDomain: "example.com",
|
||||||
Nameservers: Nameservers{
|
Nameservers: Nameservers{
|
||||||
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"},
|
Global: []string{
|
||||||
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}},
|
"1.1.1.1",
|
||||||
|
"1.0.0.1",
|
||||||
|
"2606:4700:4700::1111",
|
||||||
|
"2606:4700:4700::1001",
|
||||||
|
"https://dns.nextdns.io/abc123",
|
||||||
|
},
|
||||||
|
Split: map[string][]string{
|
||||||
|
"darp.headscale.net": {"1.1.1.1", "8.8.8.8"},
|
||||||
|
"foo.bar.com": {"1.1.1.1"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
ExtraRecords: []tailcfg.DNSRecord{
|
ExtraRecords: []tailcfg.DNSRecord{
|
||||||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||||
|
@ -92,8 +102,17 @@ func TestReadConfig(t *testing.T) {
|
||||||
MagicDNS: false,
|
MagicDNS: false,
|
||||||
BaseDomain: "example.com",
|
BaseDomain: "example.com",
|
||||||
Nameservers: Nameservers{
|
Nameservers: Nameservers{
|
||||||
Global: []string{"1.1.1.1", "1.0.0.1", "2606:4700:4700::1111", "2606:4700:4700::1001", "https://dns.nextdns.io/abc123"},
|
Global: []string{
|
||||||
Split: map[string][]string{"darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, "foo.bar.com": {"1.1.1.1"}},
|
"1.1.1.1",
|
||||||
|
"1.0.0.1",
|
||||||
|
"2606:4700:4700::1111",
|
||||||
|
"2606:4700:4700::1001",
|
||||||
|
"https://dns.nextdns.io/abc123",
|
||||||
|
},
|
||||||
|
Split: map[string][]string{
|
||||||
|
"darp.headscale.net": {"1.1.1.1", "8.8.8.8"},
|
||||||
|
"foo.bar.com": {"1.1.1.1"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
ExtraRecords: []tailcfg.DNSRecord{
|
ExtraRecords: []tailcfg.DNSRecord{
|
||||||
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
{Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"},
|
||||||
|
@ -187,7 +206,7 @@ func TestReadConfig(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
viper.Reset()
|
viper.Reset()
|
||||||
err := LoadConfig(tt.configPath, true)
|
err := LoadConfig(tt.configPath, true)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf, err := tt.setup(t)
|
conf, err := tt.setup(t)
|
||||||
|
|
||||||
|
@ -197,7 +216,7 @@ func TestReadConfig(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, conf); diff != "" {
|
if diff := cmp.Diff(tt.want, conf); diff != "" {
|
||||||
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
|
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
|
||||||
|
@ -277,10 +296,10 @@ func TestReadConfigFromEnv(t *testing.T) {
|
||||||
|
|
||||||
viper.Reset()
|
viper.Reset()
|
||||||
err := LoadConfig("testdata/minimal.yaml", true)
|
err := LoadConfig("testdata/minimal.yaml", true)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conf, err := tt.setup(t)
|
conf, err := tt.setup(t)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, conf); diff != "" {
|
if diff := cmp.Diff(tt.want, conf); diff != "" {
|
||||||
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
|
t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff)
|
||||||
|
@ -311,13 +330,25 @@ noise:
|
||||||
|
|
||||||
// Check configuration validation errors (1)
|
// Check configuration validation errors (1)
|
||||||
err = LoadConfig(tmpDir, false)
|
err = LoadConfig(tmpDir, false)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = validateServerConfig()
|
err = validateServerConfig()
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both")
|
assert.Contains(
|
||||||
assert.Contains(t, err.Error(), "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are")
|
t,
|
||||||
assert.Contains(t, err.Error(), "Fatal config error: server_url must start with https:// or http://")
|
err.Error(),
|
||||||
|
"Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both",
|
||||||
|
)
|
||||||
|
assert.Contains(
|
||||||
|
t,
|
||||||
|
err.Error(),
|
||||||
|
"Fatal config error: the only supported values for tls_letsencrypt_challenge_type are",
|
||||||
|
)
|
||||||
|
assert.Contains(
|
||||||
|
t,
|
||||||
|
err.Error(),
|
||||||
|
"Fatal config error: server_url must start with https:// or http://",
|
||||||
|
)
|
||||||
|
|
||||||
// Check configuration validation errors (2)
|
// Check configuration validation errors (2)
|
||||||
configYaml = []byte(`---
|
configYaml = []byte(`---
|
||||||
|
@ -332,7 +363,7 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||||
}
|
}
|
||||||
err = LoadConfig(tmpDir, false)
|
err = LoadConfig(tmpDir, false)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OK
|
// OK
|
||||||
|
|
|
@ -26,7 +26,7 @@ type PreAuthKey struct {
|
||||||
|
|
||||||
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||||
protoKey := v1.PreAuthKey{
|
protoKey := v1.PreAuthKey{
|
||||||
User: key.User.Name,
|
User: key.User.Username(),
|
||||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||||
Key: key.Key,
|
Key: key.Key,
|
||||||
Ephemeral: key.Ephemeral,
|
Ephemeral: key.Ephemeral,
|
||||||
|
|
|
@ -2,7 +2,9 @@ package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/mail"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -34,10 +36,14 @@ func (u Users) String() string {
|
||||||
// that contain our machines.
|
// that contain our machines.
|
||||||
type User struct {
|
type User struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
|
// The index `idx_name_provider_identifier` is to enforce uniqueness
|
||||||
|
// between Name and ProviderIdentifier. This ensures that
|
||||||
|
// you can have multiple users with the same name in OIDC,
|
||||||
|
// but not if you only run with CLI users.
|
||||||
|
|
||||||
// Username for the user, is used if email is empty
|
// Username for the user, is used if email is empty
|
||||||
// Should not be used, please use Username().
|
// Should not be used, please use Username().
|
||||||
Name string `gorm:"unique"`
|
Name string
|
||||||
|
|
||||||
// Typically the full name of the user
|
// Typically the full name of the user
|
||||||
DisplayName string
|
DisplayName string
|
||||||
|
@ -49,7 +55,7 @@ type User struct {
|
||||||
// Unique identifier of the user from OIDC,
|
// Unique identifier of the user from OIDC,
|
||||||
// comes from `sub` claim in the OIDC token
|
// comes from `sub` claim in the OIDC token
|
||||||
// and is used to lookup the user.
|
// and is used to lookup the user.
|
||||||
ProviderIdentifier string `gorm:"index"`
|
ProviderIdentifier sql.NullString
|
||||||
|
|
||||||
// Provider is the origin of the user account,
|
// Provider is the origin of the user account,
|
||||||
// same as RegistrationMethod, without authkey.
|
// same as RegistrationMethod, without authkey.
|
||||||
|
@ -66,7 +72,12 @@ type User struct {
|
||||||
// should be used throughout headscale, in information returned to the
|
// should be used throughout headscale, in information returned to the
|
||||||
// user and the Policy engine.
|
// user and the Policy engine.
|
||||||
func (u *User) Username() string {
|
func (u *User) Username() string {
|
||||||
return cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
|
return cmp.Or(
|
||||||
|
u.Email,
|
||||||
|
u.Name,
|
||||||
|
u.ProviderIdentifier.String,
|
||||||
|
strconv.FormatUint(uint64(u.ID), 10),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||||
|
@ -122,7 +133,7 @@ func (u *User) Proto() *v1.User {
|
||||||
CreatedAt: timestamppb.New(u.CreatedAt),
|
CreatedAt: timestamppb.New(u.CreatedAt),
|
||||||
DisplayName: u.DisplayName,
|
DisplayName: u.DisplayName,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
ProviderId: u.ProviderIdentifier,
|
ProviderId: u.ProviderIdentifier.String,
|
||||||
Provider: u.Provider,
|
Provider: u.Provider,
|
||||||
ProfilePicUrl: u.ProfilePicURL,
|
ProfilePicUrl: u.ProfilePicURL,
|
||||||
}
|
}
|
||||||
|
@ -131,6 +142,7 @@ func (u *User) Proto() *v1.User {
|
||||||
type OIDCClaims struct {
|
type OIDCClaims struct {
|
||||||
// Sub is the user's unique identifier at the provider.
|
// Sub is the user's unique identifier at the provider.
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
|
Iss string `json:"iss"`
|
||||||
|
|
||||||
// Name is the user's full name.
|
// Name is the user's full name.
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
|
@ -141,13 +153,27 @@ type OIDCClaims struct {
|
||||||
Username string `json:"preferred_username,omitempty"`
|
Username string `json:"preferred_username,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *OIDCClaims) Identifier() string {
|
||||||
|
return c.Iss + "/" + c.Sub
|
||||||
|
}
|
||||||
|
|
||||||
// FromClaim overrides a User from OIDC claims.
|
// FromClaim overrides a User from OIDC claims.
|
||||||
// All fields will be updated, except for the ID.
|
// All fields will be updated, except for the ID.
|
||||||
func (u *User) FromClaim(claims *OIDCClaims) {
|
func (u *User) FromClaim(claims *OIDCClaims) {
|
||||||
u.ProviderIdentifier = claims.Sub
|
err := util.CheckForFQDNRules(claims.Username)
|
||||||
u.DisplayName = claims.Name
|
if err == nil {
|
||||||
u.Email = claims.Email
|
|
||||||
u.Name = claims.Username
|
u.Name = claims.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.EmailVerified {
|
||||||
|
_, err = mail.ParseAddress(claims.Email)
|
||||||
|
if err == nil {
|
||||||
|
u.Email = claims.Email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true}
|
||||||
|
u.DisplayName = claims.Name
|
||||||
u.ProfilePicURL = claims.ProfilePictureURL
|
u.ProfilePicURL = claims.ProfilePictureURL
|
||||||
u.Provider = util.RegisterMethodOIDC
|
u.Provider = util.RegisterMethodOIDC
|
||||||
}
|
}
|
||||||
|
|
|
@ -182,3 +182,33 @@ func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||||
|
|
||||||
return fqdns
|
return fqdns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Reintroduce when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
// DEPRECATED: DO NOT USE
|
||||||
|
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||||
|
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||||
|
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||||
|
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
name = strings.ReplaceAll(name, "'", "")
|
||||||
|
atIdx := strings.Index(name, "@")
|
||||||
|
if stripEmailDomain && atIdx > 0 {
|
||||||
|
name = name[:atIdx]
|
||||||
|
} else {
|
||||||
|
name = strings.ReplaceAll(name, "@", ".")
|
||||||
|
}
|
||||||
|
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||||
|
|
||||||
|
for _, elt := range strings.Split(name, ".") {
|
||||||
|
if len(elt) > LabelHostnameLength {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"label %v is more than 63 chars: %w",
|
||||||
|
elt,
|
||||||
|
ErrInvalidUserName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, nil
|
||||||
|
}
|
||||||
|
|
|
@ -4,12 +4,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerateRandomStringDNSSafe(t *testing.T) {
|
func TestGenerateRandomStringDNSSafe(t *testing.T) {
|
||||||
for i := 0; i < 100000; i++ {
|
for i := 0; i < 100000; i++ {
|
||||||
str, err := GenerateRandomStringDNSSafe(8)
|
str, err := GenerateRandomStringDNSSafe(8)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, str, 8)
|
assert.Len(t, str, 8)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var veryLargeDestination = []string{
|
var veryLargeDestination = []string{
|
||||||
|
@ -54,7 +55,7 @@ func aclScenario(
|
||||||
) *Scenario {
|
) *Scenario {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": clientsPerUser,
|
"user1": clientsPerUser,
|
||||||
|
@ -77,10 +78,10 @@ func aclScenario(
|
||||||
hsic.WithACLPolicy(policy),
|
hsic.WithACLPolicy(policy),
|
||||||
hsic.WithTestName("acl"),
|
hsic.WithTestName("acl"),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||||
assertNoErrListFQDN(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return scenario
|
return scenario
|
||||||
}
|
}
|
||||||
|
@ -267,7 +268,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
for name, testCase := range tests {
|
for name, testCase := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
spec := testCase.users
|
spec := testCase.users
|
||||||
|
|
||||||
|
@ -275,22 +276,22 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
[]tsic.Option{},
|
[]tsic.Option{},
|
||||||
hsic.WithACLPolicy(&testCase.policy),
|
hsic.WithACLPolicy(&testCase.policy),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
allClients, err := scenario.ListTailscaleClients()
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"])
|
err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"])
|
||||||
assertNoErrSync(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
status, err := client.Status()
|
status, err := client.Status()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user := status.User[status.Self.UserID].LoginName
|
user := status.User[status.Self.UserID].LoginName
|
||||||
|
|
||||||
assert.Equal(t, (testCase.want[user]), len(status.Peer))
|
assert.Len(t, status.Peer, (testCase.want[user]))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -319,23 +320,23 @@ func TestACLAllowUser80Dst(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test that user1 can visit all user2
|
// Test that user1 can visit all user2
|
||||||
for _, client := range user1Clients {
|
for _, client := range user1Clients {
|
||||||
for _, peer := range user2Clients {
|
for _, peer := range user2Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,14 +344,14 @@ func TestACLAllowUser80Dst(t *testing.T) {
|
||||||
for _, client := range user2Clients {
|
for _, client := range user2Clients {
|
||||||
for _, peer := range user1Clients {
|
for _, peer := range user1Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -376,10 +377,10 @@ func TestACLDenyAllPort80(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
allClients, err := scenario.ListTailscaleClients()
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
allHostnames, err := scenario.ListTailscaleClientsFQDNs()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, client := range allClients {
|
for _, client := range allClients {
|
||||||
for _, hostname := range allHostnames {
|
for _, hostname := range allHostnames {
|
||||||
|
@ -394,7 +395,7 @@ func TestACLDenyAllPort80(t *testing.T) {
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -420,23 +421,23 @@ func TestACLAllowUserDst(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test that user1 can visit all user2
|
// Test that user1 can visit all user2
|
||||||
for _, client := range user1Clients {
|
for _, client := range user1Clients {
|
||||||
for _, peer := range user2Clients {
|
for _, peer := range user2Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -444,14 +445,14 @@ func TestACLAllowUserDst(t *testing.T) {
|
||||||
for _, client := range user2Clients {
|
for _, client := range user2Clients {
|
||||||
for _, peer := range user1Clients {
|
for _, peer := range user1Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -476,23 +477,23 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test that user1 can visit all user2
|
// Test that user1 can visit all user2
|
||||||
for _, client := range user1Clients {
|
for _, client := range user1Clients {
|
||||||
for _, peer := range user2Clients {
|
for _, peer := range user2Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -500,14 +501,14 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||||
for _, client := range user2Clients {
|
for _, client := range user2Clients {
|
||||||
for _, peer := range user1Clients {
|
for _, peer := range user1Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -537,23 +538,23 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test that user1 can visit all user2
|
// Test that user1 can visit all user2
|
||||||
for _, client := range user1Clients {
|
for _, client := range user1Clients {
|
||||||
for _, peer := range user2Clients {
|
for _, peer := range user2Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -561,14 +562,14 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||||
for _, client := range user2Clients {
|
for _, client := range user2Clients {
|
||||||
for _, peer := range user1Clients {
|
for _, peer := range user1Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -679,10 +680,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test1ip4 := netip.MustParseAddr("100.64.0.1")
|
test1ip4 := netip.MustParseAddr("100.64.0.1")
|
||||||
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
||||||
test1, err := scenario.FindTailscaleClientByIP(test1ip6)
|
test1, err := scenario.FindTailscaleClientByIP(test1ip6)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
test1fqdn, err := test1.FQDN()
|
test1fqdn, err := test1.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String())
|
test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String())
|
||||||
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
|
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
|
||||||
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
||||||
|
@ -690,10 +691,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test2ip4 := netip.MustParseAddr("100.64.0.2")
|
test2ip4 := netip.MustParseAddr("100.64.0.2")
|
||||||
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
|
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
|
||||||
test2, err := scenario.FindTailscaleClientByIP(test2ip6)
|
test2, err := scenario.FindTailscaleClientByIP(test2ip6)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
test2fqdn, err := test2.FQDN()
|
test2fqdn, err := test2.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String())
|
test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String())
|
||||||
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
|
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
|
||||||
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
||||||
|
@ -701,10 +702,10 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3ip4 := netip.MustParseAddr("100.64.0.3")
|
test3ip4 := netip.MustParseAddr("100.64.0.3")
|
||||||
test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3")
|
test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3")
|
||||||
test3, err := scenario.FindTailscaleClientByIP(test3ip6)
|
test3, err := scenario.FindTailscaleClientByIP(test3ip6)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
test3fqdn, err := test3.FQDN()
|
test3fqdn, err := test3.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String())
|
test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String())
|
||||||
test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String())
|
test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String())
|
||||||
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
|
test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn)
|
||||||
|
@ -719,7 +720,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3ip4URL,
|
test3ip4URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test1.Curl(test3ip6URL)
|
result, err = test1.Curl(test3ip6URL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -730,7 +731,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3ip6URL,
|
test3ip6URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test1.Curl(test3fqdnURL)
|
result, err = test1.Curl(test3fqdnURL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -741,7 +742,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3fqdnURL,
|
test3fqdnURL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// test2 can query test3
|
// test2 can query test3
|
||||||
result, err = test2.Curl(test3ip4URL)
|
result, err = test2.Curl(test3ip4URL)
|
||||||
|
@ -753,7 +754,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3ip4URL,
|
test3ip4URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test3ip6URL)
|
result, err = test2.Curl(test3ip6URL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -764,7 +765,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3ip6URL,
|
test3ip6URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test3fqdnURL)
|
result, err = test2.Curl(test3fqdnURL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -775,33 +776,33 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test3fqdnURL,
|
test3fqdnURL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// test3 cannot query test1
|
// test3 cannot query test1
|
||||||
result, err = test3.Curl(test1ip4URL)
|
result, err = test3.Curl(test1ip4URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test3.Curl(test1ip6URL)
|
result, err = test3.Curl(test1ip6URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test3.Curl(test1fqdnURL)
|
result, err = test3.Curl(test1fqdnURL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
// test3 cannot query test2
|
// test3 cannot query test2
|
||||||
result, err = test3.Curl(test2ip4URL)
|
result, err = test3.Curl(test2ip4URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test3.Curl(test2ip6URL)
|
result, err = test3.Curl(test2ip6URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test3.Curl(test2fqdnURL)
|
result, err = test3.Curl(test2fqdnURL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
// test1 can query test2
|
// test1 can query test2
|
||||||
result, err = test1.Curl(test2ip4URL)
|
result, err = test1.Curl(test2ip4URL)
|
||||||
|
@ -814,7 +815,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
|
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
result, err = test1.Curl(test2ip6URL)
|
result, err = test1.Curl(test2ip6URL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
t,
|
t,
|
||||||
|
@ -824,7 +825,7 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test2ip6URL,
|
test2ip6URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test1.Curl(test2fqdnURL)
|
result, err = test1.Curl(test2fqdnURL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -835,20 +836,20 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||||
test2fqdnURL,
|
test2fqdnURL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// test2 cannot query test1
|
// test2 cannot query test1
|
||||||
result, err = test2.Curl(test1ip4URL)
|
result, err = test2.Curl(test1ip4URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test1ip6URL)
|
result, err = test2.Curl(test1ip6URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test1fqdnURL)
|
result, err = test2.Curl(test1fqdnURL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -946,10 +947,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1")
|
||||||
test1, err := scenario.FindTailscaleClientByIP(test1ip)
|
test1, err := scenario.FindTailscaleClientByIP(test1ip)
|
||||||
assert.NotNil(t, test1)
|
assert.NotNil(t, test1)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
test1fqdn, err := test1.FQDN()
|
test1fqdn, err := test1.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String())
|
test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String())
|
||||||
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
|
test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String())
|
||||||
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn)
|
||||||
|
@ -958,10 +959,10 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
|
test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2")
|
||||||
test2, err := scenario.FindTailscaleClientByIP(test2ip)
|
test2, err := scenario.FindTailscaleClientByIP(test2ip)
|
||||||
assert.NotNil(t, test2)
|
assert.NotNil(t, test2)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
test2fqdn, err := test2.FQDN()
|
test2fqdn, err := test2.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String())
|
test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String())
|
||||||
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
|
test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String())
|
||||||
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn)
|
||||||
|
@ -976,7 +977,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
test2ipURL,
|
test2ipURL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test1.Curl(test2ip6URL)
|
result, err = test1.Curl(test2ip6URL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -987,7 +988,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
test2ip6URL,
|
test2ip6URL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test1.Curl(test2fqdnURL)
|
result, err = test1.Curl(test2fqdnURL)
|
||||||
assert.Lenf(
|
assert.Lenf(
|
||||||
|
@ -998,19 +999,19 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||||
test2fqdnURL,
|
test2fqdnURL,
|
||||||
result,
|
result,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test1ipURL)
|
result, err = test2.Curl(test1ipURL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test1ip6URL)
|
result, err = test2.Curl(test1ip6URL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
||||||
result, err = test2.Curl(test1fqdnURL)
|
result, err = test2.Curl(test1fqdnURL)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1020,7 +1021,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1046,19 +1047,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
"HEADSCALE_POLICY_MODE": "database",
|
"HEADSCALE_POLICY_MODE": "database",
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = scenario.ListTailscaleClientsFQDNs()
|
_, err = scenario.ListTailscaleClientsFQDNs()
|
||||||
assertNoErrListFQDN(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleSync()
|
err = scenario.WaitForTailscaleSync()
|
||||||
assertNoErrSync(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user1Clients, err := scenario.ListTailscaleClients("user1")
|
user1Clients, err := scenario.ListTailscaleClients("user1")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
user2Clients, err := scenario.ListTailscaleClients("user2")
|
user2Clients, err := scenario.ListTailscaleClients("user2")
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
all := append(user1Clients, user2Clients...)
|
all := append(user1Clients, user2Clients...)
|
||||||
|
|
||||||
|
@ -1070,19 +1071,19 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
p := policy.ACLPolicy{
|
p := policy.ACLPolicy{
|
||||||
ACLs: []policy.ACL{
|
ACLs: []policy.ACL{
|
||||||
|
@ -1100,7 +1101,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
policyFilePath := "/etc/headscale/policy.json"
|
policyFilePath := "/etc/headscale/policy.json"
|
||||||
|
|
||||||
err = headscale.WriteFile(policyFilePath, pBytes)
|
err = headscale.WriteFile(policyFilePath, pBytes)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// No policy is present at this time.
|
// No policy is present at this time.
|
||||||
// Add a new policy from a file.
|
// Add a new policy from a file.
|
||||||
|
@ -1113,7 +1114,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
policyFilePath,
|
policyFilePath,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get the current policy and check
|
// Get the current policy and check
|
||||||
// if it is the same as the one we set.
|
// if it is the same as the one we set.
|
||||||
|
@ -1129,7 +1130,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
},
|
},
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, output.ACLs, 1)
|
assert.Len(t, output.ACLs, 1)
|
||||||
|
|
||||||
|
@ -1141,14 +1142,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
for _, client := range user1Clients {
|
for _, client := range user1Clients {
|
||||||
for _, peer := range user2Clients {
|
for _, peer := range user2Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Len(t, result, 13)
|
assert.Len(t, result, 13)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1156,14 +1157,14 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
|
||||||
for _, client := range user2Clients {
|
for _, client := range user2Clients {
|
||||||
for _, peer := range user1Clients {
|
for _, peer := range user1Clients {
|
||||||
fqdn, err := peer.FQDN()
|
fqdn, err := peer.FQDN()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
url := fmt.Sprintf("http://%s/etc/hostname", fqdn)
|
||||||
t.Logf("url from %s to %s", client.Hostname(), url)
|
t.Logf("url from %s to %s", client.Hostname(), url)
|
||||||
|
|
||||||
result, err := client.Curl(url)
|
result, err := client.Curl(url)
|
||||||
assert.Empty(t, result)
|
assert.Empty(t, result)
|
||||||
assert.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package integration
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -10,14 +11,19 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
|
"github.com/oauth2-proxy/mockoidc"
|
||||||
"github.com/ory/dockertest/v3"
|
"github.com/ory/dockertest/v3"
|
||||||
"github.com/ory/dockertest/v3/docker"
|
"github.com/ory/dockertest/v3/docker"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
|
@ -48,20 +54,34 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
scenario := AuthOIDCScenario{
|
scenario := AuthOIDCScenario{
|
||||||
Scenario: baseScenario,
|
Scenario: baseScenario,
|
||||||
}
|
}
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
// defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
// Logins to MockOIDC is served by a queue with a strict order,
|
||||||
|
// if we use more than one node per user, the order of the logins
|
||||||
|
// will not be deterministic and the test will fail.
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": len(MustTestVersions),
|
"user1": 1,
|
||||||
|
"user2": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL)
|
mockusers := []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
oidcMockUser("user2", false),
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
oidcMap := map[string]string{
|
oidcMap := map[string]string{
|
||||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
// TODO(kradalby): Remove when strip_email_domain is removed
|
||||||
|
// after #2170 is cleaned up
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
@ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
|
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
var listUsers []v1.User
|
||||||
|
err = executeAndUnmarshal(headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"users",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&listUsers,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
want := []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "", // Unverified
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: oidcConfig.Issuer + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].Id < listUsers[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Fatalf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test is really flaky.
|
// This test is really flaky.
|
||||||
|
@ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": 3,
|
"user1": 1,
|
||||||
|
"user2": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL)
|
oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{
|
||||||
|
oidcMockUser("user1", true),
|
||||||
|
oidcMockUser("user2", false),
|
||||||
|
})
|
||||||
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
oidcMap := map[string]string{
|
oidcMap := map[string]string{
|
||||||
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
@ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
assertTailscaleNodesLogout(t, allClients)
|
assertTailscaleNodesLogout(t, allClients)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby):
|
||||||
|
// - Test that creates a new user when one exists when migration is turned off
|
||||||
|
// - Test that takes over a user when one exists when migration is turned on
|
||||||
|
// - But email is not verified
|
||||||
|
// - stripped email domain on/off
|
||||||
|
func TestOIDC024UserCreation(t *testing.T) {
|
||||||
|
IntegrationSkip(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config map[string]string
|
||||||
|
emailVerified bool
|
||||||
|
cliUsers []string
|
||||||
|
oidcUsers []string
|
||||||
|
want func(iss string) []v1.User
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-migration-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-migration-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-strip-domains-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-strip-domains-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1", "user2"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-no-strip-domains-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
},
|
||||||
|
emailVerified: true,
|
||||||
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
// Hmm I think we will have to overwrite the initial name here
|
||||||
|
// createuser with "user1.headscale.net", but oidc with "user1"
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1",
|
||||||
|
Email: "user1@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user2",
|
||||||
|
Email: "user2@headscale.net",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "migration-no-strip-domains-not-verified-email",
|
||||||
|
config: map[string]string{
|
||||||
|
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "1",
|
||||||
|
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
|
||||||
|
},
|
||||||
|
emailVerified: false,
|
||||||
|
cliUsers: []string{"user1.headscale.net", "user2.headscale.net"},
|
||||||
|
oidcUsers: []string{"user1", "user2"},
|
||||||
|
want: func(iss string) []v1.User {
|
||||||
|
return []v1.User{
|
||||||
|
{
|
||||||
|
Id: "1",
|
||||||
|
Name: "user1.headscale.net",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "2",
|
||||||
|
Name: "user1",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "3",
|
||||||
|
Name: "user2.headscale.net",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "4",
|
||||||
|
Name: "user2",
|
||||||
|
Provider: "oidc",
|
||||||
|
ProviderId: iss + "/user2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
scenario := AuthOIDCScenario{
|
||||||
|
Scenario: baseScenario,
|
||||||
|
}
|
||||||
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
|
spec := map[string]int{}
|
||||||
|
for _, user := range tt.cliUsers {
|
||||||
|
spec[user] = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var mockusers []mockoidc.MockUser
|
||||||
|
for _, user := range tt.oidcUsers {
|
||||||
|
mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified))
|
||||||
|
}
|
||||||
|
|
||||||
|
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
|
||||||
|
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
|
||||||
|
defer scenario.mockOIDC.Close()
|
||||||
|
|
||||||
|
oidcMap := map[string]string{
|
||||||
|
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
|
||||||
|
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
|
||||||
|
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
|
||||||
|
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range tt.config {
|
||||||
|
oidcMap[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
spec,
|
||||||
|
hsic.WithTestName("oidcmigration"),
|
||||||
|
hsic.WithConfigEnv(oidcMap),
|
||||||
|
hsic.WithTLS(),
|
||||||
|
hsic.WithHostnameAsServerURL(),
|
||||||
|
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
|
||||||
|
)
|
||||||
|
assertNoErrHeadscaleEnv(t, err)
|
||||||
|
|
||||||
|
// Ensure that the nodes have logged in, this is what
|
||||||
|
// triggers user creation via OIDC.
|
||||||
|
err = scenario.WaitForTailscaleSync()
|
||||||
|
assertNoErrSync(t, err)
|
||||||
|
|
||||||
|
headscale, err := scenario.Headscale()
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
want := tt.want(oidcConfig.Issuer)
|
||||||
|
|
||||||
|
var listUsers []v1.User
|
||||||
|
err = executeAndUnmarshal(headscale,
|
||||||
|
[]string{
|
||||||
|
"headscale",
|
||||||
|
"users",
|
||||||
|
"list",
|
||||||
|
"--output",
|
||||||
|
"json",
|
||||||
|
},
|
||||||
|
&listUsers,
|
||||||
|
)
|
||||||
|
assertNoErr(t, err)
|
||||||
|
|
||||||
|
sort.Slice(listUsers, func(i, j int) bool {
|
||||||
|
return listUsers[i].Id < listUsers[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" {
|
||||||
|
t.Errorf("unexpected users: %s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
users map[string]int,
|
users map[string]int,
|
||||||
opts ...hsic.Option,
|
opts ...hsic.Option,
|
||||||
|
@ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
}
|
}
|
||||||
|
|
||||||
for userName, clientCount := range users {
|
for userName, clientCount := range users {
|
||||||
|
if clientCount != 1 {
|
||||||
|
// OIDC scenario only supports one client per user.
|
||||||
|
// This is because the MockOIDC server can only serve login
|
||||||
|
// requests based on a queue it has been given on startup.
|
||||||
|
// We currently only populates it with one login request per user.
|
||||||
|
return fmt.Errorf("client count must be 1 for OIDC scenario.")
|
||||||
|
}
|
||||||
log.Printf("creating user %s with %d clients", userName, clientCount)
|
log.Printf("creating user %s with %d clients", userName, clientCount)
|
||||||
err = s.CreateUser(userName)
|
err = s.CreateUser(userName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) {
|
func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
|
||||||
port, err := dockertestutil.RandomFreeHostPort()
|
port, err := dockertestutil.RandomFreeHostPort()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("could not find an open port: %s", err)
|
log.Fatalf("could not find an open port: %s", err)
|
||||||
|
@ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
||||||
|
|
||||||
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
hostname := fmt.Sprintf("hs-oidcmock-%s", hash)
|
||||||
|
|
||||||
|
usersJSON, err := json.Marshal(users)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
mockOidcOptions := &dockertest.RunOptions{
|
mockOidcOptions := &dockertest.RunOptions{
|
||||||
Name: hostname,
|
Name: hostname,
|
||||||
Cmd: []string{"headscale", "mockoidc"},
|
Cmd: []string{"headscale", "mockoidc"},
|
||||||
|
@ -219,11 +596,12 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf
|
||||||
"MOCKOIDC_CLIENT_ID=superclient",
|
"MOCKOIDC_CLIENT_ID=superclient",
|
||||||
"MOCKOIDC_CLIENT_SECRET=supersecret",
|
"MOCKOIDC_CLIENT_SECRET=supersecret",
|
||||||
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
|
fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()),
|
||||||
|
fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
headscaleBuildOptions := &dockertest.BuildOptions{
|
headscaleBuildOptions := &dockertest.BuildOptions{
|
||||||
Dockerfile: "Dockerfile.debug",
|
Dockerfile: hsic.IntegrationTestDockerFileName,
|
||||||
ContextDir: dockerContextPath,
|
ContextDir: dockerContextPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,7 +688,6 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||||
|
|
||||||
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
|
||||||
|
|
||||||
if err := s.pool.Retry(func() error {
|
|
||||||
log.Printf("%s logging in with url", c.Hostname())
|
log.Printf("%s logging in with url", c.Hostname())
|
||||||
httpClient := &http.Client{Transport: insecureTransport}
|
httpClient := &http.Client{Transport: insecureTransport}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -329,6 +706,8 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
log.Printf("body: %s", body)
|
||||||
|
|
||||||
return errStatusCodeNotOK
|
return errStatusCodeNotOK
|
||||||
}
|
}
|
||||||
|
@ -342,13 +721,7 @@ func (s *AuthOIDCScenario) runTailscaleUp(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Finished request for %s to join tailnet", c.Hostname())
|
log.Printf("Finished request for %s to join tailnet", c.Hostname())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) {
|
||||||
assert.Equal(t, "NeedsLogin", status.BackendState)
|
assert.Equal(t, "NeedsLogin", status.BackendState)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser {
|
||||||
|
return mockoidc.MockUser{
|
||||||
|
Subject: username,
|
||||||
|
PreferredUsername: username,
|
||||||
|
Email: fmt.Sprintf("%s@headscale.net", username),
|
||||||
|
EmailVerified: emailVerified,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/juanfont/headscale/integration/hsic"
|
"github.com/juanfont/headscale/integration/hsic"
|
||||||
"github.com/juanfont/headscale/integration/tsic"
|
"github.com/juanfont/headscale/integration/tsic"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
|
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
|
||||||
|
@ -34,7 +35,7 @@ func TestUserCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -43,10 +44,10 @@ func TestUserCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var listUsers []v1.User
|
var listUsers []v1.User
|
||||||
err = executeAndUnmarshal(headscale,
|
err = executeAndUnmarshal(headscale,
|
||||||
|
@ -59,7 +60,7 @@ func TestUserCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listUsers,
|
&listUsers,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result := []string{listUsers[0].GetName(), listUsers[1].GetName()}
|
result := []string{listUsers[0].GetName(), listUsers[1].GetName()}
|
||||||
sort.Strings(result)
|
sort.Strings(result)
|
||||||
|
@ -81,7 +82,7 @@ func TestUserCommand(t *testing.T) {
|
||||||
"newname",
|
"newname",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var listAfterRenameUsers []v1.User
|
var listAfterRenameUsers []v1.User
|
||||||
err = executeAndUnmarshal(headscale,
|
err = executeAndUnmarshal(headscale,
|
||||||
|
@ -94,7 +95,7 @@ func TestUserCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAfterRenameUsers,
|
&listAfterRenameUsers,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()}
|
result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()}
|
||||||
sort.Strings(result)
|
sort.Strings(result)
|
||||||
|
@ -114,7 +115,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
count := 3
|
count := 3
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -122,13 +123,13 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys := make([]*v1.PreAuthKey, count)
|
keys := make([]*v1.PreAuthKey, count)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index := 0; index < count; index++ {
|
for index := 0; index < count; index++ {
|
||||||
var preAuthKey v1.PreAuthKey
|
var preAuthKey v1.PreAuthKey
|
||||||
|
@ -150,7 +151,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&preAuthKey,
|
&preAuthKey,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys[index] = &preAuthKey
|
keys[index] = &preAuthKey
|
||||||
}
|
}
|
||||||
|
@ -171,7 +172,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedPreAuthKeys,
|
&listedPreAuthKeys,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||||
assert.Len(t, listedPreAuthKeys, 4)
|
assert.Len(t, listedPreAuthKeys, 4)
|
||||||
|
@ -212,7 +213,9 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"})
|
tags := listedPreAuthKeys[index].GetAclTags()
|
||||||
|
sort.Strings(tags)
|
||||||
|
assert.Equal(t, []string{"tag:test1", "tag:test2"}, tags)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test key expiry
|
// Test key expiry
|
||||||
|
@ -226,7 +229,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
listedPreAuthKeys[1].GetKey(),
|
listedPreAuthKeys[1].GetKey(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var listedPreAuthKeysAfterExpire []v1.PreAuthKey
|
var listedPreAuthKeysAfterExpire []v1.PreAuthKey
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -242,7 +245,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedPreAuthKeysAfterExpire,
|
&listedPreAuthKeysAfterExpire,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
|
assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now()))
|
||||||
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
|
assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now()))
|
||||||
|
@ -256,7 +259,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||||
user := "pre-auth-key-without-exp-user"
|
user := "pre-auth-key-without-exp-user"
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -264,10 +267,10 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var preAuthKey v1.PreAuthKey
|
var preAuthKey v1.PreAuthKey
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -284,7 +287,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||||
},
|
},
|
||||||
&preAuthKey,
|
&preAuthKey,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var listedPreAuthKeys []v1.PreAuthKey
|
var listedPreAuthKeys []v1.PreAuthKey
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -300,7 +303,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedPreAuthKeys,
|
&listedPreAuthKeys,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||||
assert.Len(t, listedPreAuthKeys, 2)
|
assert.Len(t, listedPreAuthKeys, 2)
|
||||||
|
@ -319,7 +322,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
user := "pre-auth-key-reus-ephm-user"
|
user := "pre-auth-key-reus-ephm-user"
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -327,10 +330,10 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var preAuthReusableKey v1.PreAuthKey
|
var preAuthReusableKey v1.PreAuthKey
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -347,7 +350,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
},
|
},
|
||||||
&preAuthReusableKey,
|
&preAuthReusableKey,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var preAuthEphemeralKey v1.PreAuthKey
|
var preAuthEphemeralKey v1.PreAuthKey
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -364,7 +367,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
},
|
},
|
||||||
&preAuthEphemeralKey,
|
&preAuthEphemeralKey,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, preAuthEphemeralKey.GetEphemeral())
|
assert.True(t, preAuthEphemeralKey.GetEphemeral())
|
||||||
assert.False(t, preAuthEphemeralKey.GetReusable())
|
assert.False(t, preAuthEphemeralKey.GetReusable())
|
||||||
|
@ -383,7 +386,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedPreAuthKeys,
|
&listedPreAuthKeys,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// There is one key created by "scenario.CreateHeadscaleEnv"
|
// There is one key created by "scenario.CreateHeadscaleEnv"
|
||||||
assert.Len(t, listedPreAuthKeys, 3)
|
assert.Len(t, listedPreAuthKeys, 3)
|
||||||
|
@ -397,7 +400,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
user2 := "user2"
|
user2 := "user2"
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -413,10 +416,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
hsic.WithTLS(),
|
hsic.WithTLS(),
|
||||||
hsic.WithHostnameAsServerURL(),
|
hsic.WithHostnameAsServerURL(),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var user2Key v1.PreAuthKey
|
var user2Key v1.PreAuthKey
|
||||||
|
|
||||||
|
@ -438,10 +441,10 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&user2Key,
|
&user2Key,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
allClients, err := scenario.ListTailscaleClients()
|
allClients, err := scenario.ListTailscaleClients()
|
||||||
assertNoErrListClients(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, allClients, 1)
|
assert.Len(t, allClients, 1)
|
||||||
|
|
||||||
|
@ -449,22 +452,22 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
|
|
||||||
// Log out from user1
|
// Log out from user1
|
||||||
err = client.Logout()
|
err = client.Logout()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = scenario.WaitForTailscaleLogout()
|
err = scenario.WaitForTailscaleLogout()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
status, err := client.Status()
|
status, err := client.Status()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
if status.BackendState == "Starting" || status.BackendState == "Running" {
|
if status.BackendState == "Starting" || status.BackendState == "Running" {
|
||||||
t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
|
t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
|
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
status, err = client.Status()
|
status, err = client.Status()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
if status.BackendState != "Running" {
|
if status.BackendState != "Running" {
|
||||||
t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
|
t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
|
||||||
}
|
}
|
||||||
|
@ -485,7 +488,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listNodes,
|
&listNodes,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, listNodes, 1)
|
assert.Len(t, listNodes, 1)
|
||||||
|
|
||||||
assert.Equal(t, "user2", listNodes[0].GetUser().GetName())
|
assert.Equal(t, "user2", listNodes[0].GetUser().GetName())
|
||||||
|
@ -498,7 +501,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
count := 5
|
count := 5
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -507,10 +510,10 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
keys := make([]string, count)
|
keys := make([]string, count)
|
||||||
|
|
||||||
|
@ -526,7 +529,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, apiResult)
|
assert.NotEmpty(t, apiResult)
|
||||||
|
|
||||||
keys[idx] = apiResult
|
keys[idx] = apiResult
|
||||||
|
@ -545,7 +548,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedAPIKeys,
|
&listedAPIKeys,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listedAPIKeys, 5)
|
assert.Len(t, listedAPIKeys, 5)
|
||||||
|
|
||||||
|
@ -601,7 +604,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
listedAPIKeys[idx].GetPrefix(),
|
listedAPIKeys[idx].GetPrefix(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
|
expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true
|
||||||
}
|
}
|
||||||
|
@ -617,7 +620,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedAfterExpireAPIKeys,
|
&listedAfterExpireAPIKeys,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index := range listedAfterExpireAPIKeys {
|
for index := range listedAfterExpireAPIKeys {
|
||||||
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
|
if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok {
|
||||||
|
@ -643,7 +646,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
"--prefix",
|
"--prefix",
|
||||||
listedAPIKeys[0].GetPrefix(),
|
listedAPIKeys[0].GetPrefix(),
|
||||||
})
|
})
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var listedAPIKeysAfterDelete []v1.ApiKey
|
var listedAPIKeysAfterDelete []v1.ApiKey
|
||||||
err = executeAndUnmarshal(headscale,
|
err = executeAndUnmarshal(headscale,
|
||||||
|
@ -656,7 +659,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listedAPIKeysAfterDelete,
|
&listedAPIKeysAfterDelete,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listedAPIKeysAfterDelete, 4)
|
assert.Len(t, listedAPIKeysAfterDelete, 4)
|
||||||
}
|
}
|
||||||
|
@ -666,7 +669,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -674,17 +677,17 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
machineKeys := []string{
|
machineKeys := []string{
|
||||||
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||||
"mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
|
"mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c",
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(machineKeys))
|
nodes := make([]*v1.Node, len(machineKeys))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index, machineKey := range machineKeys {
|
for index, machineKey := range machineKeys {
|
||||||
_, err := headscale.Execute(
|
_, err := headscale.Execute(
|
||||||
|
@ -702,7 +705,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -720,7 +723,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes[index] = &node
|
nodes[index] = &node
|
||||||
}
|
}
|
||||||
|
@ -739,7 +742,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
|
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
|
||||||
|
|
||||||
|
@ -753,7 +756,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
"--output", "json",
|
"--output", "json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
|
require.ErrorContains(t, err, "tag must start with the string 'tag:'")
|
||||||
|
|
||||||
// Test list all nodes after added seconds
|
// Test list all nodes after added seconds
|
||||||
resultMachines := make([]*v1.Node, len(machineKeys))
|
resultMachines := make([]*v1.Node, len(machineKeys))
|
||||||
|
@ -767,7 +770,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&resultMachines,
|
&resultMachines,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
found := false
|
found := false
|
||||||
for _, node := range resultMachines {
|
for _, node := range resultMachines {
|
||||||
if node.GetForcedTags() != nil {
|
if node.GetForcedTags() != nil {
|
||||||
|
@ -778,9 +781,8 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(
|
assert.True(
|
||||||
t,
|
t,
|
||||||
true,
|
|
||||||
found,
|
found,
|
||||||
"should find a node with the tag 'tag:test' in the list of nodes",
|
"should find a node with the tag 'tag:test' in the list of nodes",
|
||||||
)
|
)
|
||||||
|
@ -791,18 +793,22 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": 1,
|
"user1": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags"))
|
err = scenario.CreateHeadscaleEnv(
|
||||||
assertNoErr(t, err)
|
spec,
|
||||||
|
[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
|
||||||
|
hsic.WithTestName("cliadvtags"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test list all nodes after added seconds
|
// Test list all nodes after added seconds
|
||||||
resultMachines := make([]*v1.Node, spec["user1"])
|
resultMachines := make([]*v1.Node, spec["user1"])
|
||||||
|
@ -817,7 +823,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&resultMachines,
|
&resultMachines,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
found := false
|
found := false
|
||||||
for _, node := range resultMachines {
|
for _, node := range resultMachines {
|
||||||
if node.GetInvalidTags() != nil {
|
if node.GetInvalidTags() != nil {
|
||||||
|
@ -828,9 +834,8 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(
|
assert.True(
|
||||||
t,
|
t,
|
||||||
true,
|
|
||||||
found,
|
found,
|
||||||
"should not find a node with the tag 'tag:test' in the list of nodes",
|
"should not find a node with the tag 'tag:test' in the list of nodes",
|
||||||
)
|
)
|
||||||
|
@ -841,14 +846,18 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
"user1": 1,
|
"user1": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(
|
err = scenario.CreateHeadscaleEnv(
|
||||||
|
spec,
|
||||||
|
[]tsic.Option{tsic.WithTags([]string{"tag:exists"})},
|
||||||
|
hsic.WithTestName("cliadvtags"),
|
||||||
|
hsic.WithACLPolicy(
|
||||||
&policy.ACLPolicy{
|
&policy.ACLPolicy{
|
||||||
ACLs: []policy.ACL{
|
ACLs: []policy.ACL{
|
||||||
{
|
{
|
||||||
|
@ -861,11 +870,12 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
||||||
"tag:exists": {"user1"},
|
"tag:exists": {"user1"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
))
|
),
|
||||||
assertNoErr(t, err)
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test list all nodes after added seconds
|
// Test list all nodes after added seconds
|
||||||
resultMachines := make([]*v1.Node, spec["user1"])
|
resultMachines := make([]*v1.Node, spec["user1"])
|
||||||
|
@ -880,7 +890,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&resultMachines,
|
&resultMachines,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
found := false
|
found := false
|
||||||
for _, node := range resultMachines {
|
for _, node := range resultMachines {
|
||||||
if node.GetValidTags() != nil {
|
if node.GetValidTags() != nil {
|
||||||
|
@ -891,9 +901,8 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(
|
assert.True(
|
||||||
t,
|
t,
|
||||||
true,
|
|
||||||
found,
|
found,
|
||||||
"should not find a node with the tag 'tag:exists' in the list of nodes",
|
"should not find a node with the tag 'tag:exists' in the list of nodes",
|
||||||
)
|
)
|
||||||
|
@ -904,7 +913,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -913,10 +922,10 @@ func TestNodeCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Pregenerated machine keys
|
// Pregenerated machine keys
|
||||||
machineKeys := []string{
|
machineKeys := []string{
|
||||||
|
@ -927,7 +936,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
"mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
"mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084",
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(machineKeys))
|
nodes := make([]*v1.Node, len(machineKeys))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index, machineKey := range machineKeys {
|
for index, machineKey := range machineKeys {
|
||||||
_, err := headscale.Execute(
|
_, err := headscale.Execute(
|
||||||
|
@ -945,7 +954,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -963,7 +972,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes[index] = &node
|
nodes[index] = &node
|
||||||
}
|
}
|
||||||
|
@ -983,7 +992,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAll,
|
&listAll,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAll, 5)
|
assert.Len(t, listAll, 5)
|
||||||
|
|
||||||
|
@ -1004,7 +1013,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
"mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584",
|
"mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584",
|
||||||
}
|
}
|
||||||
otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys))
|
otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index, machineKey := range otherUserMachineKeys {
|
for index, machineKey := range otherUserMachineKeys {
|
||||||
_, err := headscale.Execute(
|
_, err := headscale.Execute(
|
||||||
|
@ -1022,7 +1031,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1040,7 +1049,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
otherUserMachines[index] = &node
|
otherUserMachines[index] = &node
|
||||||
}
|
}
|
||||||
|
@ -1060,7 +1069,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAllWithotherUser,
|
&listAllWithotherUser,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// All nodes, nodes + otherUser
|
// All nodes, nodes + otherUser
|
||||||
assert.Len(t, listAllWithotherUser, 7)
|
assert.Len(t, listAllWithotherUser, 7)
|
||||||
|
@ -1086,7 +1095,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listOnlyotherUserMachineUser,
|
&listOnlyotherUserMachineUser,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listOnlyotherUserMachineUser, 2)
|
assert.Len(t, listOnlyotherUserMachineUser, 2)
|
||||||
|
|
||||||
|
@ -1118,7 +1127,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
"--force",
|
"--force",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test: list main user after node is deleted
|
// Test: list main user after node is deleted
|
||||||
var listOnlyMachineUserAfterDelete []v1.Node
|
var listOnlyMachineUserAfterDelete []v1.Node
|
||||||
|
@ -1135,7 +1144,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listOnlyMachineUserAfterDelete,
|
&listOnlyMachineUserAfterDelete,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listOnlyMachineUserAfterDelete, 4)
|
assert.Len(t, listOnlyMachineUserAfterDelete, 4)
|
||||||
}
|
}
|
||||||
|
@ -1145,7 +1154,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1153,10 +1162,10 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Pregenerated machine keys
|
// Pregenerated machine keys
|
||||||
machineKeys := []string{
|
machineKeys := []string{
|
||||||
|
@ -1184,7 +1193,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1202,7 +1211,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes[index] = &node
|
nodes[index] = &node
|
||||||
}
|
}
|
||||||
|
@ -1221,7 +1230,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAll,
|
&listAll,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAll, 5)
|
assert.Len(t, listAll, 5)
|
||||||
|
|
||||||
|
@ -1241,7 +1250,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
fmt.Sprintf("%d", listAll[idx].GetId()),
|
fmt.Sprintf("%d", listAll[idx].GetId()),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var listAllAfterExpiry []v1.Node
|
var listAllAfterExpiry []v1.Node
|
||||||
|
@ -1256,7 +1265,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAllAfterExpiry,
|
&listAllAfterExpiry,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAllAfterExpiry, 5)
|
assert.Len(t, listAllAfterExpiry, 5)
|
||||||
|
|
||||||
|
@ -1272,7 +1281,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1280,10 +1289,10 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Pregenerated machine keys
|
// Pregenerated machine keys
|
||||||
machineKeys := []string{
|
machineKeys := []string{
|
||||||
|
@ -1294,7 +1303,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
"mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe",
|
||||||
}
|
}
|
||||||
nodes := make([]*v1.Node, len(machineKeys))
|
nodes := make([]*v1.Node, len(machineKeys))
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for index, machineKey := range machineKeys {
|
for index, machineKey := range machineKeys {
|
||||||
_, err := headscale.Execute(
|
_, err := headscale.Execute(
|
||||||
|
@ -1312,7 +1321,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1330,7 +1339,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
nodes[index] = &node
|
nodes[index] = &node
|
||||||
}
|
}
|
||||||
|
@ -1349,7 +1358,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAll,
|
&listAll,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAll, 5)
|
assert.Len(t, listAll, 5)
|
||||||
|
|
||||||
|
@ -1370,7 +1379,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
fmt.Sprintf("newnode-%d", idx+1),
|
fmt.Sprintf("newnode-%d", idx+1),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Contains(t, res, "Node renamed")
|
assert.Contains(t, res, "Node renamed")
|
||||||
}
|
}
|
||||||
|
@ -1387,7 +1396,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAllAfterRename,
|
&listAllAfterRename,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAllAfterRename, 5)
|
assert.Len(t, listAllAfterRename, 5)
|
||||||
|
|
||||||
|
@ -1408,7 +1417,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
strings.Repeat("t", 64),
|
strings.Repeat("t", 64),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.ErrorContains(t, err, "not be over 63 chars")
|
require.ErrorContains(t, err, "not be over 63 chars")
|
||||||
|
|
||||||
var listAllAfterRenameAttempt []v1.Node
|
var listAllAfterRenameAttempt []v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1422,7 +1431,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&listAllAfterRenameAttempt,
|
&listAllAfterRenameAttempt,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, listAllAfterRenameAttempt, 5)
|
assert.Len(t, listAllAfterRenameAttempt, 5)
|
||||||
|
|
||||||
|
@ -1438,7 +1447,7 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1447,10 +1456,10 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins"))
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Randomly generated node key
|
// Randomly generated node key
|
||||||
machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa"
|
machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa"
|
||||||
|
@ -1470,7 +1479,7 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var node v1.Node
|
var node v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1488,11 +1497,11 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, uint64(1), node.GetId())
|
assert.Equal(t, uint64(1), node.GetId())
|
||||||
assert.Equal(t, "nomad-node", node.GetName())
|
assert.Equal(t, "nomad-node", node.GetName())
|
||||||
assert.Equal(t, node.GetUser().GetName(), "old-user")
|
assert.Equal(t, "old-user", node.GetUser().GetName())
|
||||||
|
|
||||||
nodeID := fmt.Sprintf("%d", node.GetId())
|
nodeID := fmt.Sprintf("%d", node.GetId())
|
||||||
|
|
||||||
|
@ -1511,9 +1520,9 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, node.GetUser().GetName(), "new-user")
|
assert.Equal(t, "new-user", node.GetUser().GetName())
|
||||||
|
|
||||||
var allNodes []v1.Node
|
var allNodes []v1.Node
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
|
@ -1527,13 +1536,13 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&allNodes,
|
&allNodes,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, allNodes, 1)
|
assert.Len(t, allNodes, 1)
|
||||||
|
|
||||||
assert.Equal(t, allNodes[0].GetId(), node.GetId())
|
assert.Equal(t, allNodes[0].GetId(), node.GetId())
|
||||||
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
|
assert.Equal(t, allNodes[0].GetUser(), node.GetUser())
|
||||||
assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user")
|
assert.Equal(t, "new-user", allNodes[0].GetUser().GetName())
|
||||||
|
|
||||||
_, err = headscale.Execute(
|
_, err = headscale.Execute(
|
||||||
[]string{
|
[]string{
|
||||||
|
@ -1548,12 +1557,12 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.ErrorContains(
|
require.ErrorContains(
|
||||||
t,
|
t,
|
||||||
err,
|
err,
|
||||||
"user not found",
|
"user not found",
|
||||||
)
|
)
|
||||||
assert.Equal(t, node.GetUser().GetName(), "new-user")
|
assert.Equal(t, "new-user", node.GetUser().GetName())
|
||||||
|
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
headscale,
|
headscale,
|
||||||
|
@ -1570,9 +1579,9 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, node.GetUser().GetName(), "old-user")
|
assert.Equal(t, "old-user", node.GetUser().GetName())
|
||||||
|
|
||||||
err = executeAndUnmarshal(
|
err = executeAndUnmarshal(
|
||||||
headscale,
|
headscale,
|
||||||
|
@ -1589,9 +1598,9 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&node,
|
&node,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, node.GetUser().GetName(), "old-user")
|
assert.Equal(t, "old-user", node.GetUser().GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicyCommand(t *testing.T) {
|
func TestPolicyCommand(t *testing.T) {
|
||||||
|
@ -1599,7 +1608,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1614,10 +1623,10 @@ func TestPolicyCommand(t *testing.T) {
|
||||||
"HEADSCALE_POLICY_MODE": "database",
|
"HEADSCALE_POLICY_MODE": "database",
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
p := policy.ACLPolicy{
|
p := policy.ACLPolicy{
|
||||||
ACLs: []policy.ACL{
|
ACLs: []policy.ACL{
|
||||||
|
@ -1637,7 +1646,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||||
policyFilePath := "/etc/headscale/policy.json"
|
policyFilePath := "/etc/headscale/policy.json"
|
||||||
|
|
||||||
err = headscale.WriteFile(policyFilePath, pBytes)
|
err = headscale.WriteFile(policyFilePath, pBytes)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// No policy is present at this time.
|
// No policy is present at this time.
|
||||||
// Add a new policy from a file.
|
// Add a new policy from a file.
|
||||||
|
@ -1651,7 +1660,7 @@ func TestPolicyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get the current policy and check
|
// Get the current policy and check
|
||||||
// if it is the same as the one we set.
|
// if it is the same as the one we set.
|
||||||
|
@ -1667,11 +1676,11 @@ func TestPolicyCommand(t *testing.T) {
|
||||||
},
|
},
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Len(t, output.TagOwners, 1)
|
assert.Len(t, output.TagOwners, 1)
|
||||||
assert.Len(t, output.ACLs, 1)
|
assert.Len(t, output.ACLs, 1)
|
||||||
assert.Equal(t, output.TagOwners["tag:exists"], []string{"policy-user"})
|
assert.Equal(t, []string{"policy-user"}, output.TagOwners["tag:exists"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPolicyBrokenConfigCommand(t *testing.T) {
|
func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
|
@ -1679,7 +1688,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario(dockertestMaxWait())
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
defer scenario.ShutdownAssertNoPanics(t)
|
defer scenario.ShutdownAssertNoPanics(t)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -1694,10 +1703,10 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
"HEADSCALE_POLICY_MODE": "database",
|
"HEADSCALE_POLICY_MODE": "database",
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
headscale, err := scenario.Headscale()
|
headscale, err := scenario.Headscale()
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
p := policy.ACLPolicy{
|
p := policy.ACLPolicy{
|
||||||
ACLs: []policy.ACL{
|
ACLs: []policy.ACL{
|
||||||
|
@ -1719,7 +1728,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
policyFilePath := "/etc/headscale/policy.json"
|
policyFilePath := "/etc/headscale/policy.json"
|
||||||
|
|
||||||
err = headscale.WriteFile(policyFilePath, pBytes)
|
err = headscale.WriteFile(policyFilePath, pBytes)
|
||||||
assertNoErr(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// No policy is present at this time.
|
// No policy is present at this time.
|
||||||
// Add a new policy from a file.
|
// Add a new policy from a file.
|
||||||
|
@ -1732,7 +1741,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
policyFilePath,
|
policyFilePath,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.ErrorContains(t, err, "verifying policy rules: invalid action")
|
require.ErrorContains(t, err, "verifying policy rules: invalid action")
|
||||||
|
|
||||||
// The new policy was invalid, the old one should still be in place, which
|
// The new policy was invalid, the old one should still be in place, which
|
||||||
// is none.
|
// is none.
|
||||||
|
@ -1745,5 +1754,5 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
|
||||||
"json",
|
"json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.ErrorContains(t, err, "acl policy not found")
|
require.ErrorContains(t, err, "acl policy not found")
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ func ExecuteCommand(
|
||||||
select {
|
select {
|
||||||
case res := <-resultChan:
|
case res := <-resultChan:
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
return stdout.String(), stderr.String(), res.err
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.exitCode != 0 {
|
if res.exitCode != 0 {
|
||||||
|
@ -83,12 +83,12 @@ func ExecuteCommand(
|
||||||
// log.Println("stdout: ", stdout.String())
|
// log.Println("stdout: ", stdout.String())
|
||||||
// log.Println("stderr: ", stderr.String())
|
// log.Println("stderr: ", stderr.String())
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), ErrDockertestCommandFailed
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), nil
|
return stdout.String(), stderr.String(), nil
|
||||||
case <-time.After(execConfig.timeout):
|
case <-time.After(execConfig.timeout):
|
||||||
|
|
||||||
return stdout.String(), stderr.String(), ErrDockertestCommandTimeout
|
return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"tailscale.com/client/tailscale/apitype"
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -244,7 +245,11 @@ func TestEphemeral(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEphemeralInAlternateTimezone(t *testing.T) {
|
func TestEphemeralInAlternateTimezone(t *testing.T) {
|
||||||
testEphemeralWithOptions(t, hsic.WithTestName("ephemeral-tz"), hsic.WithTimezone("America/Los_Angeles"))
|
testEphemeralWithOptions(
|
||||||
|
t,
|
||||||
|
hsic.WithTestName("ephemeral-tz"),
|
||||||
|
hsic.WithTimezone("America/Los_Angeles"),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) {
|
||||||
|
@ -1164,10 +1169,10 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||||
},
|
},
|
||||||
&nodeList,
|
&nodeList,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, nodeList, 2)
|
assert.Len(t, nodeList, 2)
|
||||||
assert.True(t, nodeList[0].Online)
|
assert.True(t, nodeList[0].GetOnline())
|
||||||
assert.True(t, nodeList[1].Online)
|
assert.True(t, nodeList[1].GetOnline())
|
||||||
|
|
||||||
// Delete the first node, which is online
|
// Delete the first node, which is online
|
||||||
_, err = headscale.Execute(
|
_, err = headscale.Execute(
|
||||||
|
@ -1177,13 +1182,13 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||||
"delete",
|
"delete",
|
||||||
"--identifier",
|
"--identifier",
|
||||||
// Delete the last added machine
|
// Delete the last added machine
|
||||||
fmt.Sprintf("%d", nodeList[0].Id),
|
fmt.Sprintf("%d", nodeList[0].GetId()),
|
||||||
"--output",
|
"--output",
|
||||||
"json",
|
"json",
|
||||||
"--force",
|
"--force",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
@ -1200,9 +1205,8 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) {
|
||||||
},
|
},
|
||||||
&nodeListAfter,
|
&nodeListAfter,
|
||||||
)
|
)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, nodeListAfter, 1)
|
assert.Len(t, nodeListAfter, 1)
|
||||||
assert.True(t, nodeListAfter[0].Online)
|
assert.True(t, nodeListAfter[0].GetOnline())
|
||||||
assert.Equal(t, nodeList[1].Id, nodeListAfter[0].Id)
|
assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,7 @@ const (
|
||||||
tlsCertPath = "/etc/headscale/tls.cert"
|
tlsCertPath = "/etc/headscale/tls.cert"
|
||||||
tlsKeyPath = "/etc/headscale/tls.key"
|
tlsKeyPath = "/etc/headscale/tls.key"
|
||||||
headscaleDefaultPort = 8080
|
headscaleDefaultPort = 8080
|
||||||
|
IntegrationTestDockerFileName = "Dockerfile.integration"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
|
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
|
||||||
|
@ -303,7 +304,7 @@ func New(
|
||||||
}
|
}
|
||||||
|
|
||||||
headscaleBuildOptions := &dockertest.BuildOptions{
|
headscaleBuildOptions := &dockertest.BuildOptions{
|
||||||
Dockerfile: "Dockerfile.debug",
|
Dockerfile: IntegrationTestDockerFileName,
|
||||||
ContextDir: dockerContextPath,
|
ContextDir: dockerContextPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,9 +92,9 @@ func TestEnablingRoutes(t *testing.T) {
|
||||||
assert.Len(t, routes, 3)
|
assert.Len(t, routes, 3)
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.True(t, route.GetAdvertised())
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.False(t, route.GetEnabled())
|
||||||
assert.Equal(t, false, route.GetIsPrimary())
|
assert.False(t, route.GetIsPrimary())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that no routes has been sent to the client,
|
// Verify that no routes has been sent to the client,
|
||||||
|
@ -139,9 +139,9 @@ func TestEnablingRoutes(t *testing.T) {
|
||||||
assert.Len(t, enablingRoutes, 3)
|
assert.Len(t, enablingRoutes, 3)
|
||||||
|
|
||||||
for _, route := range enablingRoutes {
|
for _, route := range enablingRoutes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.True(t, route.GetAdvertised())
|
||||||
assert.Equal(t, true, route.GetEnabled())
|
assert.True(t, route.GetEnabled())
|
||||||
assert.Equal(t, true, route.GetIsPrimary())
|
assert.True(t, route.GetIsPrimary())
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
|
@ -212,18 +212,18 @@ func TestEnablingRoutes(t *testing.T) {
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
for _, route := range disablingRoutes {
|
for _, route := range disablingRoutes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.True(t, route.GetAdvertised())
|
||||||
|
|
||||||
if route.GetId() == routeToBeDisabled.GetId() {
|
if route.GetId() == routeToBeDisabled.GetId() {
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.False(t, route.GetEnabled())
|
||||||
|
|
||||||
// since this is the only route of this cidr,
|
// since this is the only route of this cidr,
|
||||||
// it will not failover, and remain Primary
|
// it will not failover, and remain Primary
|
||||||
// until something can replace it.
|
// until something can replace it.
|
||||||
assert.Equal(t, true, route.GetIsPrimary())
|
assert.True(t, route.GetIsPrimary())
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, true, route.GetEnabled())
|
assert.True(t, route.GetEnabled())
|
||||||
assert.Equal(t, true, route.GetIsPrimary())
|
assert.True(t, route.GetIsPrimary())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,9 +342,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
t.Logf("initial routes %#v", routes)
|
t.Logf("initial routes %#v", routes)
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.True(t, route.GetAdvertised())
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.False(t, route.GetEnabled())
|
||||||
assert.Equal(t, false, route.GetIsPrimary())
|
assert.False(t, route.GetIsPrimary())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that no routes has been sent to the client,
|
// Verify that no routes has been sent to the client,
|
||||||
|
@ -391,14 +391,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, enablingRoutes, 2)
|
assert.Len(t, enablingRoutes, 2)
|
||||||
|
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
assert.True(t, enablingRoutes[0].GetAdvertised())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
assert.True(t, enablingRoutes[0].GetEnabled())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
|
assert.True(t, enablingRoutes[0].GetIsPrimary(), "both subnet routers are up, expected r1 to be primary")
|
||||||
|
|
||||||
// Node 2 is not primary
|
// Node 2 is not primary
|
||||||
assert.Equal(t, true, enablingRoutes[1].GetAdvertised())
|
assert.True(t, enablingRoutes[1].GetAdvertised())
|
||||||
assert.Equal(t, true, enablingRoutes[1].GetEnabled())
|
assert.True(t, enablingRoutes[1].GetEnabled())
|
||||||
assert.Equal(t, false, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
|
assert.False(t, enablingRoutes[1].GetIsPrimary(), "both subnet routers are up, expected r2 to be non-primary")
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine
|
// Verify that the client has routes from the primary machine
|
||||||
srs1, err := subRouter1.Status()
|
srs1, err := subRouter1.Status()
|
||||||
|
@ -446,14 +446,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, routesAfterMove, 2)
|
assert.Len(t, routesAfterMove, 2)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterMove[0].GetAdvertised())
|
assert.True(t, routesAfterMove[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterMove[0].GetEnabled())
|
assert.True(t, routesAfterMove[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
assert.False(t, routesAfterMove[0].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetAdvertised())
|
assert.True(t, routesAfterMove[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetEnabled())
|
assert.True(t, routesAfterMove[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
assert.True(t, routesAfterMove[1].GetIsPrimary(), "r1 is down, expected r2 to be primary")
|
||||||
|
|
||||||
srs2, err = subRouter2.Status()
|
srs2, err = subRouter2.Status()
|
||||||
|
|
||||||
|
@ -501,16 +501,16 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, routesAfterBothDown, 2)
|
assert.Len(t, routesAfterBothDown, 2)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
|
assert.True(t, routesAfterBothDown[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
|
assert.True(t, routesAfterBothDown[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
assert.False(t, routesAfterBothDown[0].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
// if the node goes down, but no other suitable route is
|
// if the node goes down, but no other suitable route is
|
||||||
// available, keep the last known good route.
|
// available, keep the last known good route.
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
|
assert.True(t, routesAfterBothDown[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
|
assert.True(t, routesAfterBothDown[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
assert.True(t, routesAfterBothDown[1].GetIsPrimary(), "r1 and r2 is down, expected r2 to _still_ be primary")
|
||||||
|
|
||||||
// TODO(kradalby): Check client status
|
// TODO(kradalby): Check client status
|
||||||
// Both are expected to be down
|
// Both are expected to be down
|
||||||
|
@ -560,14 +560,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, routesAfter1Up, 2)
|
assert.Len(t, routesAfter1Up, 2)
|
||||||
|
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
|
assert.True(t, routesAfter1Up[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
|
assert.True(t, routesAfter1Up[0].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
assert.True(t, routesAfter1Up[0].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
||||||
|
|
||||||
// Node 2 is not primary
|
// Node 2 is not primary
|
||||||
assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
|
assert.True(t, routesAfter1Up[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
|
assert.True(t, routesAfter1Up[1].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
assert.False(t, routesAfter1Up[1].GetIsPrimary(), "r1 is back up, expected r1 to become be primary")
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -614,14 +614,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, routesAfter2Up, 2)
|
assert.Len(t, routesAfter2Up, 2)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
|
assert.True(t, routesAfter2Up[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
|
assert.True(t, routesAfter2Up[0].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
assert.True(t, routesAfter2Up[0].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
|
assert.True(t, routesAfter2Up[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
|
assert.True(t, routesAfter2Up[1].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
assert.False(t, routesAfter2Up[1].GetIsPrimary(), "r1 and r2 is back up, expected r1 to _still_ be primary")
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -677,14 +677,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
|
t.Logf("routes after disabling r1 %#v", routesAfterDisabling1)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
|
assert.True(t, routesAfterDisabling1[0].GetAdvertised())
|
||||||
assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
|
assert.False(t, routesAfterDisabling1[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary())
|
assert.False(t, routesAfterDisabling1[0].GetIsPrimary())
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised())
|
assert.True(t, routesAfterDisabling1[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled())
|
assert.True(t, routesAfterDisabling1[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary())
|
assert.True(t, routesAfterDisabling1[1].GetIsPrimary())
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -735,14 +735,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
assert.Len(t, routesAfterEnabling1, 2)
|
assert.Len(t, routesAfterEnabling1, 2)
|
||||||
|
|
||||||
// Node 1 is not primary
|
// Node 1 is not primary
|
||||||
assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised())
|
assert.True(t, routesAfterEnabling1[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled())
|
assert.True(t, routesAfterEnabling1[0].GetEnabled())
|
||||||
assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary())
|
assert.False(t, routesAfterEnabling1[0].GetIsPrimary())
|
||||||
|
|
||||||
// Node 2 is primary
|
// Node 2 is primary
|
||||||
assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised())
|
assert.True(t, routesAfterEnabling1[1].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled())
|
assert.True(t, routesAfterEnabling1[1].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary())
|
assert.True(t, routesAfterEnabling1[1].GetIsPrimary())
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -795,9 +795,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
|
t.Logf("routes after deleting r2 %#v", routesAfterDeleting2)
|
||||||
|
|
||||||
// Node 1 is primary
|
// Node 1 is primary
|
||||||
assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
|
assert.True(t, routesAfterDeleting2[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled())
|
assert.True(t, routesAfterDeleting2[0].GetEnabled())
|
||||||
assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary())
|
assert.True(t, routesAfterDeleting2[0].GetIsPrimary())
|
||||||
|
|
||||||
// Verify that the route is announced from subnet router 1
|
// Verify that the route is announced from subnet router 1
|
||||||
clientStatus, err = client.Status()
|
clientStatus, err = client.Status()
|
||||||
|
@ -893,9 +893,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
|
|
||||||
// All routes should be auto approved and enabled
|
// All routes should be auto approved and enabled
|
||||||
assert.Equal(t, true, routes[0].GetAdvertised())
|
assert.True(t, routes[0].GetAdvertised())
|
||||||
assert.Equal(t, true, routes[0].GetEnabled())
|
assert.True(t, routes[0].GetEnabled())
|
||||||
assert.Equal(t, true, routes[0].GetIsPrimary())
|
assert.True(t, routes[0].GetIsPrimary())
|
||||||
|
|
||||||
// Stop advertising route
|
// Stop advertising route
|
||||||
command = []string{
|
command = []string{
|
||||||
|
@ -924,9 +924,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||||
assert.Len(t, notAdvertisedRoutes, 1)
|
assert.Len(t, notAdvertisedRoutes, 1)
|
||||||
|
|
||||||
// Route is no longer advertised
|
// Route is no longer advertised
|
||||||
assert.Equal(t, false, notAdvertisedRoutes[0].GetAdvertised())
|
assert.False(t, notAdvertisedRoutes[0].GetAdvertised())
|
||||||
assert.Equal(t, false, notAdvertisedRoutes[0].GetEnabled())
|
assert.False(t, notAdvertisedRoutes[0].GetEnabled())
|
||||||
assert.Equal(t, true, notAdvertisedRoutes[0].GetIsPrimary())
|
assert.True(t, notAdvertisedRoutes[0].GetIsPrimary())
|
||||||
|
|
||||||
// Advertise route again
|
// Advertise route again
|
||||||
command = []string{
|
command = []string{
|
||||||
|
@ -955,9 +955,9 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||||
assert.Len(t, reAdvertisedRoutes, 1)
|
assert.Len(t, reAdvertisedRoutes, 1)
|
||||||
|
|
||||||
// All routes should be auto approved and enabled
|
// All routes should be auto approved and enabled
|
||||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetAdvertised())
|
assert.True(t, reAdvertisedRoutes[0].GetAdvertised())
|
||||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetEnabled())
|
assert.True(t, reAdvertisedRoutes[0].GetEnabled())
|
||||||
assert.Equal(t, true, reAdvertisedRoutes[0].GetIsPrimary())
|
assert.True(t, reAdvertisedRoutes[0].GetIsPrimary())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoApprovedSubRoute2068(t *testing.T) {
|
func TestAutoApprovedSubRoute2068(t *testing.T) {
|
||||||
|
@ -1163,9 +1163,9 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||||
assert.Len(t, routes, 1)
|
assert.Len(t, routes, 1)
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
assert.Equal(t, true, route.GetAdvertised())
|
assert.True(t, route.GetAdvertised())
|
||||||
assert.Equal(t, false, route.GetEnabled())
|
assert.False(t, route.GetEnabled())
|
||||||
assert.Equal(t, false, route.GetIsPrimary())
|
assert.False(t, route.GetIsPrimary())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that no routes has been sent to the client,
|
// Verify that no routes has been sent to the client,
|
||||||
|
@ -1212,9 +1212,9 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||||
assert.Len(t, enablingRoutes, 1)
|
assert.Len(t, enablingRoutes, 1)
|
||||||
|
|
||||||
// Node 1 has active route
|
// Node 1 has active route
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetAdvertised())
|
assert.True(t, enablingRoutes[0].GetAdvertised())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetEnabled())
|
assert.True(t, enablingRoutes[0].GetEnabled())
|
||||||
assert.Equal(t, true, enablingRoutes[0].GetIsPrimary())
|
assert.True(t, enablingRoutes[0].GetIsPrimary())
|
||||||
|
|
||||||
// Verify that the client has routes from the primary machine
|
// Verify that the client has routes from the primary machine
|
||||||
srs1, _ := subRouter1.Status()
|
srs1, _ := subRouter1.Status()
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
)
|
)
|
||||||
|
@ -205,11 +206,11 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
||||||
|
|
||||||
if t != nil {
|
if t != nil {
|
||||||
stdout, err := os.ReadFile(stdoutPath)
|
stdout, err := os.ReadFile(stdoutPath)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotContains(t, string(stdout), "panic")
|
assert.NotContains(t, string(stdout), "panic")
|
||||||
|
|
||||||
stderr, err := os.ReadFile(stderrPath)
|
stderr, err := os.ReadFile(stderrPath)
|
||||||
assert.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotContains(t, string(stderr), "panic")
|
assert.NotContains(t, string(stderr), "panic")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
|
---
|
||||||
site_name: Headscale
|
site_name: Headscale
|
||||||
site_url: https://juanfont.github.io/headscale
|
site_url: https://juanfont.github.io/headscale/
|
||||||
edit_uri: blob/main/docs/ # Change the master branch to main as we are using main as a main branch
|
edit_uri: blob/main/docs/ # Change the master branch to main as we are using main as a main branch
|
||||||
site_author: Headscale authors
|
site_author: Headscale authors
|
||||||
site_description: >-
|
site_description: >-
|
||||||
|
|
Loading…
Reference in a new issue