Compare commits

...

17 commits

Author SHA1 Message Date
Kristoffer Dalby
d858ddf1f0
Merge 014ee87066 into edf9e25001 2024-11-22 13:30:07 +01:00
ArcticLampyrid
edf9e25001
feat: support client verify for derp (add integration tests) (#2046)
Some checks are pending
Build / build (push) Waiting to run
Build documentation / build (push) Waiting to run
Build documentation / deploy (push) Blocked by required conditions
Tests / test (push) Waiting to run
* feat: support client verify for derp

* docs: fix doc for integration test

* tests: add integration test for DERP verify endpoint

* tests: use `tailcfg.DERPMap` instead of `[]byte`

* refactor: introduce func `ContainsNodeKey`

* tests(dsic): use string builder for cmd args

* ci: fix tests order

* tests: fix derper failure

* chore: cleanup

* tests(verify-client): perfer to use `CreateHeadscaleEnv`

* refactor(verify-client): simplify error handling

* tests: fix `TestDERPVerifyEndpoint`

* refactor: make `doVerify` a seperated func

---------

Co-authored-by: 117503445 <t117503445@gmail.com>
2024-11-22 13:23:05 +01:00
Motiejus Jakštys
c6336adb01
config: loosen up BaseDomain and ServerURL checks (#2248)
* config: loosen up BaseDomain and ServerURL checks

Requirements [here][1]:

> OK:
> server_url: headscale.com, base: clients.headscale.com
> server_url: headscale.com, base: headscale.net
>
> Not OK:
> server_url: server.headscale.com, base: headscale.com
>
> Essentially we have to prevent the possibility where the headscale
> server has a URL which can also be assigned to a node.
>
> So for the Not OK scenario:
>
> if the server is: server.headscale.com, and a node joins with the name
> server, it will be assigned server.headscale.com and that will break
> the connection for nodes which will now try to connect to that node
> instead of the headscale server.

Fixes #2210

[1]: https://github.com/juanfont/headscale/issues/2210#issuecomment-2488165187

* server_url and base_domain: re-word error message, fix a one-off bug and add a test case for the bug.

* lint

* lint again
2024-11-22 13:21:44 +01:00
enoperm
5fbf3f8327
Websocket derp test fixes (#2247)
* integration testing: add and validate build-time options for tailscale head

* fixup! integration testing: add and validate build-time options for tailscale head

integration testing: comply with linter

* fixup! fixup! integration testing: add and validate build-time options for tailscale head

integration testing: tsic.New must never return nil

* fixup! fixup! fixup! integration testing: add and validate build-time options for tailscale head

* minor fixes
2024-11-22 11:57:01 +01:00
Kristoffer Dalby
014ee87066
update integration test build
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-29 15:02:53 -04:00
Kristoffer Dalby
a942fcf50a
fix autoapprove
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-29 08:40:15 -04:00
Kristoffer Dalby
24f3895b2b
update error string
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 13:58:55 -04:00
Kristoffer Dalby
85a038cfca
tags approved via acl
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 13:53:45 -04:00
Kristoffer Dalby
dbf2faa4bf
fix nil in router
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 12:56:56 -04:00
Kristoffer Dalby
7f665023d8
fix nil pointer in oidc for policy
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-27 11:50:47 -04:00
Kristoffer Dalby
f2ab5e05c9
split out reading policy and applying
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 18:29:27 -04:00
Kristoffer Dalby
50b62ddfb3
fix loading policy manager
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 18:19:14 -04:00
Kristoffer Dalby
8d5b04f3d3
hook up user and node changes to policy
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 12:53:04 -05:00
Kristoffer Dalby
19bc8b6e01
report if filter has changed
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 12:27:12 -05:00
Kristoffer Dalby
8ecba121cc
remove unused args
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 11:43:18 -05:00
Kristoffer Dalby
6afb554e20
wrap policy in policy manager interface
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-26 11:42:14 -05:00
Kristoffer Dalby
50165ce9e1
resolve user identifier to stable ID
currently, the policy approach node to user matching
with a quite naive approach looking at the username
provided in the policy and matched it with the username
on the nodes. This worked ok as long as usernames were
unique and did not change.

As usernames are no longer guarenteed to be unique in
an OIDC environment we cant rely on this.

This changes the mechanism that matches the user string
(now user token) with nodes:

- first find all potential users by looking up:
  - database ID
  - provider ID (OIDC)
  - username/email

If more than one user is matching, then the query is
rejected, and zero matching nodes are returned.

When a single user is found, the node is matched against
the User database ID, which are also present on the actual
node.

This means that from this commit, users can use the following
to identify users in the policy:
- provider identity (iss + sub)
- username
- email
- database id

There are more changes coming to this, so it is not recommended
to start using any of these new abilities, with the exception
of email, which will not change since it includes an @.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-10-23 22:31:37 -05:00
34 changed files with 1945 additions and 486 deletions

View file

@ -30,14 +30,14 @@ jobs:
- TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand
- TestNodeTagCommand
- TestNodeAdvertiseTagNoACLCommand
- TestNodeAdvertiseTagWithACLCommand
- TestNodeAdvertiseTagCommand
- TestNodeCommand
- TestNodeExpireCommand
- TestNodeRenameCommand
- TestNodeMoveCommand
- TestPolicyCommand
- TestPolicyBrokenConfigCommand
- TestDERPVerifyEndpoint
- TestResolveMagicDNS
- TestValidateResolvConf
- TestDERPServerScenario

View file

@ -23,6 +23,7 @@
- Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198)
- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199)
- Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232)
- Loosened up `server_url` and `base_domain` check. It was overly strict in some cases.
## 0.23.0 (2024-09-18)

19
Dockerfile.derper Normal file
View file

@ -0,0 +1,19 @@
# For testing purposes only
FROM golang:alpine AS build-env
WORKDIR /go/src
RUN apk add --no-cache git
ARG VERSION_BRANCH=main
RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1
WORKDIR /go/src/tailscale
ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -v ./cmd/derper
FROM alpine:3.18
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
COPY --from=build-env /go/bin/* /usr/local/bin/
ENTRYPOINT [ "/usr/local/bin/derper" ]

View file

@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH=""
ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -ldflags="\
ARG BUILD_TAGS=""
RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
-X tailscale.com/version.longStamp=$VERSION_LONG \
-X tailscale.com/version.shortStamp=$VERSION_SHORT \
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \

View file

@ -88,7 +88,8 @@ type Headscale struct {
DERPMap *tailcfg.DERPMap
DERPServer *derpServer.DERPServer
ACLPolicy *policy.ACLPolicy
polManOnce sync.Once
polMan policy.PolicyManager
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
}
})
if err = app.loadPolicyManager(); err != nil {
return nil, fmt.Errorf("failed to load ACL policy: %w", err)
}
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.db,
app.nodeNotifier,
app.ipAlloc,
app.polMan,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -457,6 +463,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
Methods(http.MethodGet)
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
@ -473,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
users, err := db.ListUsers()
if err != nil {
return err
}
changed, err := polMan.SetUsers(users)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// Maybe we should attempt a new in memory state and not go via the DB?
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
nodes, err := db.ListNodes()
if err != nil {
return err
}
changed, err := polMan.SetNodes(nodes)
if err != nil {
return err
}
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
notif.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
return nil
}
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
if profilingEnabled {
@ -488,19 +542,13 @@ func (h *Headscale) Serve() error {
}
}
var err error
if err = h.loadACLPolicy(); err != nil {
return fmt.Errorf("failed to load ACL policy: %w", err)
}
if dumpConfig {
spew.Dump(h.cfg)
}
// Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server
@ -770,12 +818,21 @@ func (h *Headscale) Serve() error {
Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config")
// TODO(kradalby): Reload config on SIGHUP
if err := h.loadACLPolicy(); err != nil {
log.Error().Err(err).Msg("failed to reload ACL policy")
if err := h.loadPolicyManager(); err != nil {
log.Error().Err(err).Msg("failed to reload Policy")
}
if h.ACLPolicy != nil {
pol, err := h.policyBytes()
if err != nil {
log.Error().Err(err).Msg("failed to get policy blob")
}
changed, err := h.polMan.SetPolicy(pol)
if err != nil {
log.Error().Err(err).Msg("failed to set new policy")
}
if changed {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
@ -994,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}
func (h *Headscale) loadACLPolicy() error {
var (
pol *policy.ACLPolicy
err error
)
// policyBytes returns the appropriate policy for the
// current configuration as a []byte array.
func (h *Headscale) policyBytes() ([]byte, error) {
switch h.cfg.Policy.Mode {
case types.PolicyModeFile:
path := h.cfg.Policy.Path
// It is fine to start headscale without a policy file.
if len(path) == 0 {
return nil
return nil, nil
}
absPath := util.AbsolutePathFromConfigPath(path)
pol, err = policy.LoadACLPolicyFromPath(absPath)
policyFile, err := os.Open(absPath)
if err != nil {
return fmt.Errorf("failed to load ACL policy from file: %w", err)
return nil, err
}
defer policyFile.Close()
return io.ReadAll(policyFile)
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil, nil
}
return nil, err
}
return []byte(p.Data), err
}
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
}
func (h *Headscale) loadPolicyManager() error {
var errOut error
h.polManOnce.Do(func() {
// Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes
@ -1025,42 +1101,35 @@ func (h *Headscale) loadACLPolicy() error {
// allowed to be written to the database.
nodes, err := h.db.ListNodes()
if err != nil {
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
return
}
users, err := h.db.ListUsers()
if err != nil {
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
return
}
_, err = pol.CompileFilterRules(nodes)
pol, err := h.policyBytes()
if err != nil {
return fmt.Errorf("verifying policy rules: %w", err)
errOut = fmt.Errorf("loading policy bytes: %w", err)
return
}
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
if err != nil {
errOut = fmt.Errorf("creating policy manager: %w", err)
return
}
if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
_, err = h.polMan.SSHPolicy(nodes[0])
if err != nil {
return fmt.Errorf("verifying SSH rules: %w", err)
errOut = fmt.Errorf("verifying SSH rules: %w", err)
return
}
}
})
case types.PolicyModeDB:
p, err := h.db.GetPolicy()
if err != nil {
if errors.Is(err, types.ErrPolicyNotFound) {
return nil
}
return fmt.Errorf("failed to get policy from database: %w", err)
}
pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
if err != nil {
return fmt.Errorf("failed to parse policy: %w", err)
}
default:
log.Fatal().
Str("mode", string(h.cfg.Policy.Mode)).
Msg("Unknown ACL policy mode")
}
h.ACLPolicy = pol
return nil
return errOut
}

View file

@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
return
}
err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
}
err = h.db.Write(func(tx *gorm.DB) error {

View file

@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(testPeers), check.Equals, 9)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
c.Assert(err, check.IsNil)
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
adb, err := newTestDB()
assert.NoError(t, err)
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
assert.NoError(t, err)
assert.NotNil(t, pol)
user, err := adb.CreateUser("test")
assert.NoError(t, err)
@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) {
node0ByID, err := adb.GetNodeByID(0)
assert.NoError(t, err)
users, err := adb.ListUsers()
assert.NoError(t, err)
nodes, err := adb.ListNodes()
assert.NoError(t, err)
pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
assert.NoError(t, err)
// TODO(kradalby): Check state update
err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
assert.NoError(t, err)
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)

View file

@ -598,18 +598,18 @@ func failoverRoute(
}
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
aclPolicy *policy.ACLPolicy,
polMan policy.PolicyManager,
node *types.Node,
) error {
return hsdb.Write(func(tx *gorm.DB) error {
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
return EnableAutoApprovedRoutes(tx, polMan, node)
})
}
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
func EnableAutoApprovedRoutes(
tx *gorm.DB,
aclPolicy *policy.ACLPolicy,
polMan policy.PolicyManager,
node *types.Node,
) error {
if node.IPv4 == nil && node.IPv6 == nil {
@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
continue
}
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
netip.Prefix(advertisedRoute.Prefix),
)
if err != nil {
return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
}
routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
log.Trace().
Str("node", node.Hostname).
@ -649,7 +644,7 @@ func EnableAutoApprovedRoutes(
approvedRoutes = append(approvedRoutes, advertisedRoute)
} else {
// TODO(kradalby): figure out how to get this to depend on less stuff
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
approvedIps, err := polMan.ExpandAlias(approvedAlias)
if err != nil {
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
}

View file

@ -21,7 +21,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
return nil, err
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@ -87,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return &v1.DeleteUserResponse{}, nil
}
@ -221,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, err
}
err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
if err != nil {
return nil, fmt.Errorf("updating resources using node: %w", err)
}
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@ -450,10 +464,7 @@ func (api headscaleV1APIServer) ListNodes(
resp.Online = true
}
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
node,
)
resp.InvalidTags = invalidTags
validTags := api.h.polMan.Tags(node)
resp.ValidTags = validTags
response[index] = resp
}
@ -723,11 +734,6 @@ func (api headscaleV1APIServer) SetPolicy(
p := request.GetPolicy()
pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
if err != nil {
return nil, fmt.Errorf("loading ACL policy file: %w", err)
}
// Validate and reject configuration that would error when applied
// when creating a map response. This requires nodes, so there is still
// a scenario where they might be allowed if the server has no nodes
@ -737,14 +743,13 @@ func (api headscaleV1APIServer) SetPolicy(
if err != nil {
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
}
_, err = pol.CompileFilterRules(nodes)
changed, err := api.h.polMan.SetPolicy([]byte(p))
if err != nil {
return nil, fmt.Errorf("verifying policy rules: %w", err)
return nil, fmt.Errorf("setting policy: %w", err)
}
if len(nodes) > 0 {
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
_, err = api.h.polMan.SSHPolicy(nodes[0])
if err != nil {
return nil, fmt.Errorf("verifying SSH rules: %w", err)
}
@ -755,12 +760,13 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
api.h.ACLPolicy = pol
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
// Only send update if the packet filter has changed.
if changed {
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateFullUpdate,
})
}
response := &v1.SetPolicyResponse{
Policy: updated.Data,

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
@ -56,6 +57,65 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
}
func (h *Headscale) handleVerifyRequest(
req *http.Request,
) (bool, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return false, fmt.Errorf("cannot read request body: %w", err)
}
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
}
nodes, err := h.db.ListNodes()
if err != nil {
return false, fmt.Errorf("cannot list nodes: %w", err)
}
return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil
}
// see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, Derp use verifyClientsURL to verify whether a client is allowed to connect to the DERP server.
func (h *Headscale) VerifyHandler(
writer http.ResponseWriter,
req *http.Request,
) {
if req.Method != http.MethodPost {
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
return
}
log.Debug().
Str("handler", "/verify").
Msg("verify client")
allow, err := h.handleVerifyRequest(req)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to verify client")
http.Error(writer, "Internal error", http.StatusInternalServerError)
}
resp := tailcfg.DERPAdmitClientResponse{
Allow: allow,
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(

View file

@ -55,6 +55,7 @@ type Mapper struct {
cfg *types.Config
derpMap *tailcfg.DERPMap
notif *notifier.Notifier
polMan policy.PolicyManager
uid string
created time.Time
@ -71,6 +72,7 @@ func NewMapper(
cfg *types.Config,
derpMap *tailcfg.DERPMap,
notif *notifier.Notifier,
polMan policy.PolicyManager,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
@ -79,6 +81,7 @@ func NewMapper(
cfg: cfg,
derpMap: derpMap,
notif: notif,
polMan: polMan,
uid: uid,
created: time.Now(),
@ -153,10 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func (m *Mapper) fullMapResponse(
node *types.Node,
peers types.Nodes,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
resp, err := m.baseWithConfigMapResponse(node, capVer)
if err != nil {
return nil, err
}
@ -164,11 +166,10 @@ func (m *Mapper) fullMapResponse(
err = appendPeerChanges(
resp,
true, // full change
pol,
m.polMan,
node,
capVer,
peers,
peers,
m.cfg,
)
if err != nil {
@ -182,7 +183,6 @@ func (m *Mapper) fullMapResponse(
func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
peers, err := m.ListPeers(node.ID)
@ -190,7 +190,7 @@ func (m *Mapper) FullMapResponse(
return nil, err
}
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
if err != nil {
return nil, err
}
@ -204,10 +204,9 @@ func (m *Mapper) FullMapResponse(
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
@ -243,7 +242,6 @@ func (m *Mapper) PeerChangedResponse(
node *types.Node,
changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
messages ...string,
) ([]byte, error) {
resp := m.baseMapResponse()
@ -273,10 +271,9 @@ func (m *Mapper) PeerChangedResponse(
err = appendPeerChanges(
&resp,
false, // partial change
pol,
m.polMan,
node,
mapRequest.Version,
peers,
changedNodes,
m.cfg,
)
@ -303,7 +300,7 @@ func (m *Mapper) PeerChangedResponse(
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
if err != nil {
return nil, err
}
@ -318,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node *types.Node,
changed []*tailcfg.PeerChange,
pol *policy.ACLPolicy,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
@ -447,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
node *types.Node,
pol *policy.ACLPolicy,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
tailnode, err := tailNode(node, capVer, pol, m.cfg)
tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
if err != nil {
return nil, err
}
@ -505,34 +500,30 @@ func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
pol *policy.ACLPolicy,
polMan policy.PolicyManager,
node *types.Node,
capVer tailcfg.CapabilityVersion,
peers types.Nodes,
changed types.Nodes,
cfg *types.Config,
) error {
packetFilter, err := pol.CompileFilterRules(append(peers, node))
if err != nil {
return err
}
filter := polMan.Filter()
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
sshPolicy, err := polMan.SSHPolicy(node)
if err != nil {
return err
}
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
if len(packetFilter) > 0 {
changed = policy.FilterNodesByACL(node, changed, packetFilter)
if len(filter) > 0 {
changed = policy.FilterNodesByACL(node, changed, filter)
}
profiles := generateUserProfiles(node, changed)
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(changed, capVer, pol, cfg)
tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
if err != nil {
return err
}
@ -557,7 +548,7 @@ func appendPeerChanges(
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, packetFilter),
"base": policy.ReduceFilterRules(node, filter),
}
} else {
// This is a hack to avoid sending an empty list of packet filters.
@ -565,11 +556,11 @@ func appendPeerChanges(
// be omitted, causing the client to consider it unchanged, keeping the
// previous packet filter. Worst case, this can cause a node that previously
// has access to a node to _not_ loose access if an empty (allow none) is sent.
reduced := policy.ReduceFilterRules(node, packetFilter)
reduced := policy.ReduceFilterRules(node, filter)
if len(reduced) > 0 {
resp.PacketFilter = reduced
} else {
resp.PacketFilter = packetFilter
resp.PacketFilter = filter
}
}

View file

@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
mini := &types.Node{
ID: 0,
MachineKey: mustMK(
@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: 0,
User: types.User{Name: "mini"},
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.2"),
Hostname: "peer1",
GivenName: "peer1",
UserID: 0,
User: types.User{Name: "mini"},
UserID: user1.ID,
User: user1,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
IPv4: iap("100.64.0.3"),
Hostname: "peer2",
GivenName: "peer2",
UserID: 1,
User: types.User{Name: "peer2"},
UserID: user2.ID,
User: user2,
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
@ -458,17 +461,19 @@ func Test_fullMapResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
mappy := NewMapper(
nil,
tt.cfg,
tt.derpMap,
nil,
polMan,
)
got, err := mappy.fullMapResponse(
tt.node,
tt.peers,
tt.pol,
0,
)

View file

@ -14,7 +14,7 @@ import (
func tailNodes(
nodes types.Nodes,
capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy,
polMan policy.PolicyManager,
cfg *types.Config,
) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, len(nodes))
@ -23,7 +23,7 @@ func tailNodes(
node, err := tailNode(
node,
capVer,
pol,
polMan,
cfg,
)
if err != nil {
@ -40,7 +40,7 @@ func tailNodes(
func tailNode(
node *types.Node,
capVer tailcfg.CapabilityVersion,
pol *policy.ACLPolicy,
polMan policy.PolicyManager,
cfg *types.Config,
) (*tailcfg.Node, error) {
addrs := node.Prefixes()
@ -81,7 +81,7 @@ func tailNode(
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
}
tags, _ := pol.TagsOfNode(node)
tags := polMan.Tags(node)
tags = lo.Uniq(append(tags, node.ForcedTags...))
tNode := tailcfg.Node{

View file

@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
cfg := &types.Config{
BaseDomain: tt.baseDomain,
DNSConfig: tt.dnsConfig,
@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
got, err := tailNode(
tt.node,
0,
tt.pol,
polMan,
cfg,
)
@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
tn, err := tailNode(
node,
0,
&policy.ACLPolicy{},
&policy.PolicyManagerV1{},
&types.Config{},
)
if err != nil {

View file

@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
registrationCache *zcache.Cache[string, key.MachinePublic]
notifier *notifier.Notifier
ipAlloc *db.IPAllocator
polMan policy.PolicyManager
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
db *db.HSDatabase,
notif *notifier.Notifier,
ipAlloc *db.IPAllocator,
polMan policy.PolicyManager,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
registrationCache: registrationCache,
notifier: notif,
ipAlloc: ipAlloc,
polMan: polMan,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return nil, fmt.Errorf("creating or updating user: %w", err)
}
err = usersChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return nil, fmt.Errorf("updating resources using user: %w", err)
}
return user, nil
}
@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode(
return fmt.Errorf("could not register node: %w", err)
}
err = nodesChangedHook(a.db, a.polMan, a.notifier)
if err != nil {
return fmt.Errorf("updating resources using node: %w", err)
}
return nil
}

View file

@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
policy *ACLPolicy,
node *types.Node,
peers types.Nodes,
users []types.User,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all
if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}
rules, err := policy.CompileFilterRules(append(peers, node))
rules, err := policy.CompileFilterRules(users, append(peers, node))
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) CompileFilterRules(
users []types.User,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
var srcIPs []string
for srcIndex, src := range acl.Sources {
srcs, err := pol.expandSource(src, nodes)
srcs, err := pol.expandSource(src, users, nodes)
if err != nil {
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
return nil, fmt.Errorf(
"parsing policy, acl index: %d->%d: %w",
index,
srcIndex,
err,
)
}
srcIPs = append(srcIPs, srcs...)
}
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
expanded, err := pol.ExpandAlias(
nodes,
users,
alias,
)
if err != nil {
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
func (pol *ACLPolicy) CompileSSHPolicy(
node *types.Node,
users []types.User,
peers types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, node), src)
expanded, err := pol.ExpandAlias(append(peers, node), users, src)
if err != nil {
return nil, err
}
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err)
return nil, fmt.Errorf(
"parsing SSH policy, parsing check duration, index: %d: %w",
index,
err,
)
} else {
action = *checkAction
}
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
return nil, fmt.Errorf(
"parsing SSH policy, unknown action %q, index: %d: %w",
sshACL.Action,
index,
err,
)
}
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
users,
rawSrc,
)
if err != nil {
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// with the given src alias.
func (pol *ACLPolicy) expandSource(
src string,
users []types.User,
nodes types.Nodes,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(nodes, src)
ipSet, err := pol.ExpandAlias(nodes, users, src)
if err != nil {
return []string{}, err
}
@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
// and transform these in IPAddresses.
func (pol *ACLPolicy) ExpandAlias(
nodes types.Nodes,
users []types.User,
alias string,
) (*netipx.IPSet, error) {
if isWildcard(alias) {
@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group
if isGroup(alias) {
return pol.expandIPsFromGroup(alias, nodes)
return pol.expandIPsFromGroup(alias, users, nodes)
}
// if alias is a tag
if isTag(alias) {
return pol.expandIPsFromTag(alias, nodes)
return pol.expandIPsFromTag(alias, users, nodes)
}
if isAutoGroup(alias) {
@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
}
// if alias is a user
if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil {
if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
return ips, err
}
@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(nodes, h.String())
return pol.ExpandAlias(nodes, users, h.String())
}
// if alias is an IP
@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(
func (pol *ACLPolicy) expandIPsFromGroup(
group string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
users, err := pol.expandUsersFromGroup(group)
userTokens, err := pol.expandUsersFromGroup(group)
if err != nil {
return &netipx.IPSet{}, err
}
for _, user := range users {
filteredNodes := filterNodesByUser(nodes, user)
for _, user := range userTokens {
filteredNodes := filterNodesByUser(nodes, users, user)
for _, node := range filteredNodes {
node.AppendToIPSet(&build)
}
@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(
func (pol *ACLPolicy) expandIPsFromTag(
alias string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(
// filter out nodes per tag owner
for _, user := range owners {
nodes := filterNodesByUser(nodes, user)
nodes := filterNodesByUser(nodes, users, user)
for _, node := range nodes {
if node.Hostinfo == nil {
continue
@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(
func (pol *ACLPolicy) expandIPsFromUser(
user string,
users []types.User,
nodes types.Nodes,
) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
filteredNodes := filterNodesByUser(nodes, user)
filteredNodes := filterNodesByUser(nodes, users, user)
filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)
// shortcurcuit if we have no nodes to get ips from.
@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode(
return validTags, invalidTags
}
func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
// filterNodesByUser returns a list of nodes that match the given userToken from a
// policy.
// Matching nodes are determined by first matching the user token to a user by checking:
// - If it is an ID that mactches the user database ID
// - It is the Provider Identifier from OIDC
// - It matches the username or email of a user
//
// If the token matches more than one user, zero nodes will returned.
func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
var out types.Nodes
var potentialUsers []types.User
for _, user := range users {
if user.ProviderIdentifier == userToken {
potentialUsers = append(potentialUsers, user)
break
}
if user.Email == userToken {
potentialUsers = append(potentialUsers, user)
}
if user.Name == userToken {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) != 1 {
return nil
}
user := potentialUsers[0]
for _, node := range nodes {
if node.User.Username() == user {
if node.User.ID == user.ID {
out = append(out, node)
}
}

View file

@ -2,8 +2,10 @@ package policy
import (
"errors"
"math/rand/v2"
"net/netip"
"slices"
"sort"
"testing"
"github.com/google/go-cmp/cmp"
@ -14,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"go4.org/netipx"
"gopkg.in/check.v1"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
)
@ -375,18 +378,24 @@ func TestParsing(t *testing.T) {
return
}
rules, err := pol.CompileFilterRules(types.Nodes{
&types.Node{
IPv4: iap("100.100.100.100"),
user := types.User{
Model: gorm.Model{ID: 1},
Name: "testuser",
}
rules, err := pol.CompileFilterRules(
[]types.User{
user,
},
&types.Node{
IPv4: iap("200.200.200.200"),
User: types.User{
Name: "testuser",
types.Nodes{
&types.Node{
IPv4: iap("100.100.100.100"),
},
Hostinfo: &tailcfg.Hostinfo{},
},
})
&types.Node{
IPv4: iap("200.200.200.200"),
User: user,
Hostinfo: &tailcfg.Hostinfo{},
},
})
if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
c.Assert(pol.ACLs, check.HasLen, 6)
c.Assert(err, check.IsNil)
rules, err := pol.CompileFilterRules(types.Nodes{})
rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
c.Assert(err, check.NotNil)
c.Assert(rules, check.IsNil)
}
@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
},
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
}
@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
},
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
}
@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
},
}
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
}
@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) {
}
func Test_listNodesInUser(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "marc"},
{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
}
type args struct {
nodes types.Nodes
user string
@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) {
name: "1 node in user",
args: args{
nodes: types.Nodes{
&types.Node{User: types.User{Name: "joe"}},
&types.Node{User: users[1]},
},
user: "joe",
},
want: types.Nodes{
&types.Node{User: types.User{Name: "joe"}},
&types.Node{User: users[1]},
},
},
{
name: "3 nodes, 2 in user",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}},
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
},
},
{
name: "5 nodes, 0 in user",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: types.User{Name: "joe"}},
&types.Node{ID: 2, User: types.User{Name: "marc"}},
&types.Node{ID: 3, User: types.User{Name: "marc"}},
&types.Node{ID: 4, User: types.User{Name: "marc"}},
&types.Node{ID: 5, User: types.User{Name: "marc"}},
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[0]},
&types.Node{ID: 3, User: users[0]},
&types.Node{ID: 4, User: users[0]},
&types.Node{ID: 5, User: users[0]},
},
user: "mickael",
},
want: nil,
},
{
name: "match-by-provider-ident",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[2]},
},
},
{
name: "match-by-email",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[1]},
},
},
{
name: "multi-match-is-zero",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[1]},
&types.Node{ID: 2, User: users[2]},
&types.Node{ID: 3, User: users[3]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-email-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match email, then provider id
&types.Node{ID: 3, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "multi-username-first-match-is-zero",
args: args{
nodes: types.Nodes{
// First match username, then provider id
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 2, User: users[2]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-duplicate-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael",
},
want: nil,
},
{
name: "all-users-unique-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "marc",
},
want: types.Nodes{
&types.Node{ID: 1, User: users[0]},
},
},
{
name: "all-users-no-username-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "mikael@headscale.net",
},
want: nil,
},
{
name: "all-users-duplicate-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "joe@headscale.net",
},
want: types.Nodes{
&types.Node{ID: 2, User: users[1]},
},
},
{
name: "all-users-no-email-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "not-working@headscale.net",
},
want: nil,
},
{
name: "all-users-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "http://oidc.org/1234",
},
want: types.Nodes{
&types.Node{ID: 3, User: users[2]},
},
},
{
name: "all-users-no-provider-id-random-order",
args: args{
nodes: types.Nodes{
&types.Node{ID: 1, User: users[0]},
&types.Node{ID: 2, User: users[1]},
&types.Node{ID: 3, User: users[2]},
&types.Node{ID: 4, User: users[3]},
&types.Node{ID: 5, User: users[4]},
},
user: "http://oidc.org/4321",
},
want: nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := filterNodesByUser(test.args.nodes, test.args.user)
for range 1000 {
ns := test.args.nodes
rand.Shuffle(len(ns), func(i, j int) {
ns[i], ns[j] = ns[j], ns[i]
})
got := filterNodesByUser(ns, users, test.args.user)
sort.Slice(got, func(i, j int) bool {
return got[i].ID < got[j].ID
})
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
}
}
})
}
@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) {
return s
}
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "joe"},
{Model: gorm.Model{ID: 2}, Name: "marc"},
{Model: gorm.Model{ID: 3}, Name: "mickael"},
}
type field struct {
pol ACLPolicy
}
@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"},
User: users[0],
},
},
},
@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) {
nodes: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.1"),
User: types.User{Name: "joe"},
User: users[0],
ForcedTags: []string{"tag:hr-webserver"},
},
&types.Node{
IPv4: iap("100.64.0.2"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "mickael"},
User: users[2],
},
},
},
@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) {
},
&types.Node{
IPv4: iap("100.64.0.3"),
User: types.User{Name: "marc"},
User: users[1],
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
IPv4: iap("100.64.0.4"),
User: types.User{Name: "joe"},
User: users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
},
@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
got, err := test.field.pol.ExpandAlias(
test.args.nodes,
users,
test.args.alias,
)
if (err != nil) != test.wantErr {
@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.field.pol.CompileFilterRules(
[]types.User{},
tt.args.nodes,
)
if (err != nil) != tt.wantErr {
@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) {
}
func TestReduceFilterRules(t *testing.T) {
users := []types.User{
{Model: gorm.Model{ID: 1}, Name: "mickael"},
{Model: gorm.Model{ID: 2}, Name: "user1"},
{Model: gorm.Model{ID: 3}, Name: "user2"},
{Model: gorm.Model{ID: 4}, Name: "user100"},
}
tests := []struct {
name string
node *types.Node
@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: types.User{Name: "mickael"},
User: users[0],
},
peers: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: types.User{Name: "mickael"},
User: users[0],
},
},
want: []tailcfg.FilterRule{},
@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
peers: types.Nodes{
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
// "internal" exit node
&types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.2"),
IPv6: iap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: users[2],
},
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"),
User: types.User{Name: "user100"},
User: users[3],
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: users[1],
},
},
want: []tailcfg.FilterRule{
@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules(
users,
append(tt.peers, tt.node),
)
@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" {
@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
RequestTags: []string{"tag:test"},
}
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{
ID: 0,
Hostname: "testnodes",
IPv4: iap("100.64.0.1"),
UserID: 0,
User: types.User{
Name: "user1",
},
ID: 0,
Hostname: "testnodes",
IPv4: iap("100.64.0.1"),
UserID: 0,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
}
@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
assert.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) {
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Name: "user1",
Model: gorm.Model{ID: 1},
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Name: "user1",
Model: gorm.Model{ID: 1},
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
// c.Assert(rules[0].DstPorts, check.HasLen, 1)
// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
assert.NoError(t, err)
want := []tailcfg.FilterRule{
@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) {
Hostname: "webserver",
RequestTags: []string{"tag:webapp"},
}
user := types.User{
Model: gorm.Model{ID: 1},
Name: "user1",
}
node := &types.Node{
ID: 1,
Hostname: "webserver",
IPv4: iap("100.64.0.1"),
UserID: 1,
User: types.User{
Name: "user1",
},
ID: 1,
Hostname: "webserver",
IPv4: iap("100.64.0.1"),
UserID: 1,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo,
}
@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) {
}
nodes2 := &types.Node{
ID: 2,
Hostname: "user",
IPv4: iap("100.64.0.2"),
UserID: 1,
User: types.User{
Name: "user1",
},
ID: 2,
Hostname: "user",
IPv4: iap("100.64.0.2"),
UserID: 1,
User: user,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &hostInfo2,
}
@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) {
},
}
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
assert.NoError(t, err)
want := []tailcfg.FilterRule{

181
hscontrol/policy/pm.go Normal file
View file

@ -0,0 +1,181 @@
package policy
import (
"fmt"
"io"
"net/netip"
"os"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
)
type PolicyManager interface {
Filter() []tailcfg.FilterRule
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
Tags(*types.Node) []string
ApproversForRoute(netip.Prefix) []string
ExpandAlias(string) (*netipx.IPSet, error)
SetPolicy([]byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes types.Nodes) (bool, error)
}
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
policyFile, err := os.Open(path)
if err != nil {
return nil, err
}
defer policyFile.Close()
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return nil, err
}
return NewPolicyManager(policyBytes, users, nodes)
}
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
var pol *ACLPolicy
var err error
if polB != nil && len(polB) > 0 {
pol, err = LoadACLPolicyFromBytes(polB)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
}
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
_, err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
pm := PolicyManagerV1{
pol: pol,
users: users,
nodes: nodes,
}
_, err := pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
type PolicyManagerV1 struct {
mu sync.Mutex
pol *ACLPolicy
users []types.User
nodes types.Nodes
filterHash deephash.Sum
filter []tailcfg.FilterRule
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return false, fmt.Errorf("compiling filter rules: %w", err)
}
filterHash := deephash.Hash(&filter)
if filterHash == pm.filterHash {
return false, nil
}
pm.filter = filter
pm.filterHash = filterHash
return true, nil
}
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
pol, err := LoadACLPolicyFromBytes(polB)
if err != nil {
return false, fmt.Errorf("parsing policy: %w", err)
}
pm.mu.Lock()
defer pm.mu.Unlock()
pm.pol = pol
return pm.updateLocked()
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
if pm == nil {
return nil
}
tags, _ := pm.pol.TagsOfNode(node)
return tags
}
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
// TODO(kradalby): This can be a parse error of the address in the policy,
// in the new policy this will be typed and not a problem, in this policy
// we will just return empty list
if pm.pol == nil {
return nil
}
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
return approvers
}
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
if err != nil {
return nil, err
}
return ips, nil
}

158
hscontrol/policy/pm_test.go Normal file
View file

@ -0,0 +1,158 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func TestPolicySetChange(t *testing.T) {
users := []types.User{
{
Model: gorm.Model{ID: 1},
Name: "testuser",
},
}
tests := []struct {
name string
users []types.User
nodes types.Nodes
policy []byte
wantUsersChange bool
wantNodesChange bool
wantPolicyChange bool
wantFilter []tailcfg.FilterRule
}{
{
name: "set-nodes",
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantNodesChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users",
users: users,
wantUsersChange: false,
wantFilter: []tailcfg.FilterRule{
{
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-users-and-node",
users: users,
nodes: types.Nodes{
{
IPv4: iap("100.64.0.2"),
User: users[0],
},
},
wantUsersChange: false,
wantNodesChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.2/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
},
},
},
{
name: "set-policy",
policy: []byte(`
{
"acls": [
{
"action": "accept",
"src": [
"100.64.0.61",
],
"dst": [
"100.64.0.62:*",
],
},
],
}
`),
wantPolicyChange: true,
wantFilter: []tailcfg.FilterRule{
{
SrcIPs: []string{"100.64.0.61/32"},
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pol := `
{
"groups": {
"group:example": [
"testuser",
],
},
"hosts": {
"host-1": "100.64.0.1",
"subnet-1": "100.100.101.100/24",
},
"acls": [
{
"action": "accept",
"src": [
"group:example",
],
"dst": [
"host-1:*",
],
},
],
}
`
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
require.NoError(t, err)
if tt.policy != nil {
change, err := pm.SetPolicy(tt.policy)
require.NoError(t, err)
assert.Equal(t, tt.wantPolicyChange, change)
}
if tt.users != nil {
change, err := pm.SetUsers(tt.users)
require.NoError(t, err)
assert.Equal(t, tt.wantUsersChange, change)
}
if tt.nodes != nil {
change, err := pm.SetNodes(tt.nodes)
require.NoError(t, err)
assert.Equal(t, tt.wantNodesChange, change)
}
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View file

@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() {
return
}
if m.h.ACLPolicy != nil {
// TODO(kradalby): Only update the node that has actually changed
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
if m.h.polMan != nil {
// update routes with peer information
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
if err != nil {
m.errf(err, "Error running auto approved routes")
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() {
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy)
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)

View file

@ -28,8 +28,9 @@ const (
maxDuration time.Duration = 1<<63 - 1
)
var errOidcMutuallyExclusive = errors.New(
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
)
type IPAllocationStrategy string
@ -827,11 +828,10 @@ func LoadServerConfig() (*Config, error) {
// - DERP run on their own domains
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
if dnsConfig.BaseDomain != "" &&
strings.Contains(serverURL, dnsConfig.BaseDomain) {
return nil, errors.New(
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
)
if dnsConfig.BaseDomain != "" {
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
return nil, err
}
}
return &Config{
@ -924,6 +924,37 @@ func LoadServerConfig() (*Config, error) {
}, nil
}
// BaseDomain cannot be a suffix of the server URL.
// This is because Tailscale takes over the domain in BaseDomain,
// causing the headscale server and DERP to be unreachable.
// For Tailscale upstream, the following is true:
// - DERP run on their own domains.
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com.
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net).
func isSafeServerURL(serverURL, baseDomain string) error {
server, err := url.Parse(serverURL)
if err != nil {
return err
}
serverDomainParts := strings.Split(server.Host, ".")
baseDomainParts := strings.Split(baseDomain, ".")
if len(serverDomainParts) <= len(baseDomainParts) {
return nil
}
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range len(baseDomainParts) {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil
}
}
return errServerURLSuffix
}
type deprecator struct {
warns set.Set[string]
fatals set.Set[string]

View file

@ -1,6 +1,7 @@
package types
import (
"fmt"
"os"
"path/filepath"
"testing"
@ -139,7 +140,7 @@ func TestReadConfig(t *testing.T) {
return LoadServerConfig()
},
want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
wantErr: errServerURLSuffix.Error(),
},
{
name: "base-domain-not-in-server-url",
@ -333,3 +334,64 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
err = LoadConfig(tmpDir, false)
assert.NoError(t, err)
}
// OK
// server_url: headscale.com, base: clients.headscale.com
// server_url: headscale.com, base: headscale.net
//
// NOT OK
// server_url: server.headscale.com, base: headscale.com.
func TestSafeServerURL(t *testing.T) {
tests := []struct {
serverURL, baseDomain,
wantErr string
}{
{
serverURL: "https://example.com",
baseDomain: "example.org",
},
{
serverURL: "https://headscale.com",
baseDomain: "headscale.com",
},
{
serverURL: "https://headscale.com",
baseDomain: "clients.headscale.com",
},
{
serverURL: "https://headscale.com",
baseDomain: "clients.subdomain.headscale.com",
},
{
serverURL: "https://headscale.kristoffer.com",
baseDomain: "mybase",
},
{
serverURL: "https://server.headscale.com",
baseDomain: "headscale.com",
wantErr: errServerURLSuffix.Error(),
},
{
serverURL: "https://server.subdomain.headscale.com",
baseDomain: "headscale.com",
wantErr: errServerURLSuffix.Error(),
},
{
serverURL: "http://foo\x00",
wantErr: `parse "http://foo\x00": net/url: invalid control character in URL`,
},
}
for _, tt := range tests {
testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain)
t.Run(testName, func(t *testing.T) {
err := isSafeServerURL(tt.serverURL, tt.baseDomain)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
return
}
assert.NoError(t, err)
})
}
}

View file

@ -223,6 +223,16 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
return found
}
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
for _, node := range nodes {
if node.NodeKey == nodeKey {
return true
}
}
return false
}
func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{
Id: uint64(node.ID),

View file

@ -8,7 +8,7 @@ prefixes:
database:
type: sqlite3
server_url: "https://derp.no"
server_url: "https://server.derp.no"
dns:
magic_dns: true

View file

@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located
## Running integration tests locally
The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner:
The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner:
```
act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml
act pull_request -W .github/workflows/test-integration.yaml
```
Alternatively, the `docker run` command in each GitHub workflow file can be used.

View file

@ -13,6 +13,7 @@ import (
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
)
func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@ -786,117 +787,85 @@ func TestNodeTagCommand(t *testing.T) {
)
}
func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"])
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--tags",
"--output", "json",
tests := []struct {
name string
policy *policy.ACLPolicy
wantTag bool
}{
{
name: "no-policy",
wantTag: false,
},
&resultMachines,
)
assert.Nil(t, err)
found := false
for _, node := range resultMachines {
if node.GetInvalidTags() != nil {
for _, tag := range node.GetInvalidTags() {
if tag == "tag:test" {
found = true
}
}
}
}
assert.Equal(
t,
true,
found,
"should not find a node with the tag 'tag:test' in the list of nodes",
)
}
func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(
&policy.ACLPolicy{
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
{
name: "with-policy",
policy: &policy.ACLPolicy{
ACLs: []policy.ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
TagOwners: map[string][]string{
"tag:test": {"user1"},
},
},
TagOwners: map[string][]string{
"tag:exists": {"user1"},
},
wantTag: true,
},
))
assertNoErr(t, err)
}
headscale, err := scenario.Headscale()
assertNoErr(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
// defer scenario.ShutdownAssertNoPanics(t)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"])
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--tags",
"--output", "json",
},
&resultMachines,
)
assert.Nil(t, err)
found := false
for _, node := range resultMachines {
if node.GetValidTags() != nil {
for _, tag := range node.GetValidTags() {
if tag == "tag:exists" {
found = true
spec := map[string]int{
"user1": 1,
}
err = scenario.CreateHeadscaleEnv(spec,
[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
hsic.WithTestName("cliadvtags"),
hsic.WithACLPolicy(tt.policy),
)
assertNoErr(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, spec["user1"])
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--tags",
"--output", "json",
},
&resultMachines,
)
assert.Nil(t, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
found = slices.Contains(tags, "tag:test")
}
}
}
assert.Equalf(
t,
tt.wantTag,
found,
"'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag,
)
})
}
assert.Equal(
t,
true,
found,
"should not find a node with the tag 'tag:exists' in the list of nodes",
)
}
func TestNodeCommand(t *testing.T) {
@ -1732,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
policyFilePath,
},
)
assert.ErrorContains(t, err, "verifying policy rules: invalid action")
assert.ErrorContains(t, err, "compiling filter rules: invalid action")
// The new policy was invalid, the old one should still be in place, which
// is none.

View file

@ -0,0 +1,96 @@
package integration
import (
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"testing"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
)
func TestDERPVerifyEndpoint(t *testing.T) {
IntegrationSkip(t)
// Generate random hostname for the headscale instance
hash, err := util.GenerateRandomStringDNSSafe(6)
assertNoErr(t, err)
testName := "derpverify"
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
headscalePort := 8080
// Create cert for headscale
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
assertNoErr(t, err)
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
derper, err := scenario.CreateDERPServer("head",
dsic.WithCACert(certHeadscale),
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
)
assertNoErr(t, err)
derpMap := tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
900: {
RegionID: 900,
RegionCode: "test-derpverify",
RegionName: "TestDerpVerify",
Nodes: []*tailcfg.DERPNode{
{
Name: "TestDerpVerify",
RegionID: 900,
HostName: derper.GetHostname(),
STUNPort: derper.GetSTUNPort(),
STUNOnly: false,
DERPPort: derper.GetDERPPort(),
},
},
},
},
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())},
hsic.WithHostname(hostname),
hsic.WithPort(headscalePort),
hsic.WithCustomTLS(certHeadscale, keyHeadscale),
hsic.WithHostnameAsServerURL(),
hsic.WithDERPConfig(derpMap))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
for _, client := range allClients {
report, err := client.DebugDERPRegion("test-derpverify")
assertNoErr(t, err)
successful := false
for _, line := range report.Info {
if strings.Contains(line, "Successfully established a DERP connection with node") {
successful = true
break
}
}
if !successful {
stJSON, err := json.Marshal(report)
assertNoErr(t, err)
t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON))
}
}
}

321
integration/dsic/dsic.go Normal file
View file

@ -0,0 +1,321 @@
package dsic
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
const (
dsicHashLength = 6
dockerContextPath = "../."
caCertRoot = "/usr/local/share/ca-certificates"
DERPerCertRoot = "/usr/local/share/derper-certs"
dockerExecuteTimeout = 60 * time.Second
)
var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK")
// DERPServerInContainer represents DERP Server in Container (DSIC).
type DERPServerInContainer struct {
version string
hostname string
pool *dockertest.Pool
container *dockertest.Resource
network *dockertest.Network
stunPort int
derpPort int
caCerts [][]byte
tlsCert []byte
tlsKey []byte
withExtraHosts []string
withVerifyClientURL string
workdir string
}
// Option represent optional settings that can be given to a
// DERPer instance.
type Option = func(c *DERPServerInContainer)
// WithCACert adds it to the trusted surtificate of the Tailscale container.
func WithCACert(cert []byte) Option {
return func(dsic *DERPServerInContainer) {
dsic.caCerts = append(dsic.caCerts, cert)
}
}
// WithOrCreateNetwork sets the Docker container network to use with
// the DERPer instance, if the parameter is nil, a new network,
// isolating the DERPer, will be created. If a network is
// passed, the DERPer instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *DERPServerInContainer) {
if network != nil {
tsic.network = network
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
tsic.hostname+"-network",
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network
}
}
// WithDockerWorkdir allows the docker working directory to be set.
func WithDockerWorkdir(dir string) Option {
return func(tsic *DERPServerInContainer) {
tsic.workdir = dir
}
}
// WithVerifyClientURL sets the URL to verify the client.
func WithVerifyClientURL(url string) Option {
return func(tsic *DERPServerInContainer) {
tsic.withVerifyClientURL = url
}
}
// WithExtraHosts adds extra hosts to the container.
func WithExtraHosts(hosts []string) Option {
return func(tsic *DERPServerInContainer) {
tsic.withExtraHosts = hosts
}
}
// New returns a new TailscaleInContainer instance.
func New(
pool *dockertest.Pool,
version string,
network *dockertest.Network,
opts ...Option,
) (*DERPServerInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
if err != nil {
return nil, err
}
hostname := fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
if err != nil {
return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
}
dsic := &DERPServerInContainer{
version: version,
hostname: hostname,
pool: pool,
network: network,
tlsCert: tlsCert,
tlsKey: tlsKey,
stunPort: 3478, //nolint
derpPort: 443, //nolint
}
for _, opt := range opts {
opt(dsic)
}
var cmdArgs strings.Builder
fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname)
fmt.Fprintf(&cmdArgs, " --certmode=manual")
fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot)
fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
fmt.Fprintf(&cmdArgs, " --stun=true")
fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
if dsic.withVerifyClientURL != "" {
fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
}
runOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{dsic.network},
ExtraHosts: dsic.withExtraHosts,
// we currently need to give us some time to inject the certificate further down.
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},
ExposedPorts: []string{
"80/tcp",
fmt.Sprintf("%d/tcp", dsic.derpPort),
fmt.Sprintf("%d/udp", dsic.stunPort),
},
}
if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
// dockertest isnt very good at handling containers that has already
// been created, this is an attempt to make sure this container isnt
// present.
err = pool.RemoveContainerByName(hostname)
if err != nil {
return nil, err
}
var container *dockertest.Resource
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.derper",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
}
switch version {
case "head":
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
Name: "VERSION_BRANCH",
Value: "main",
})
default:
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
Name: "VERSION_BRANCH",
Value: "v" + version,
})
}
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
return nil, fmt.Errorf(
"%s could not start tailscale DERPer container (version: %s): %w",
hostname,
version,
err,
)
}
log.Printf("Created %s container\n", hostname)
dsic.container = container
for i, cert := range dsic.caCerts {
err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsCert) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsKey) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
if err != nil {
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
}
}
return dsic, nil
}
// Shutdown stops and cleans up the DERPer container.
func (t *DERPServerInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
if err != nil {
log.Printf(
"Failed to save log from %s: %s",
t.hostname,
fmt.Errorf("failed to save log: %w", err),
)
}
return t.pool.Purge(t.container)
}
// GetCert returns the TLS certificate of the DERPer instance.
func (t *DERPServerInContainer) GetCert() []byte {
return t.tlsCert
}
// Hostname returns the hostname of the DERPer instance.
func (t *DERPServerInContainer) Hostname() string {
return t.hostname
}
// Version returns the running DERPer version of the instance.
func (t *DERPServerInContainer) Version() string {
return t.version
}
// ID returns the Docker container ID of the DERPServerInContainer
// instance.
func (t *DERPServerInContainer) ID() string {
return t.container.Container.ID
}
func (t *DERPServerInContainer) GetHostname() string {
return t.hostname
}
// GetSTUNPort returns the STUN port of the DERPer instance.
func (t *DERPServerInContainer) GetSTUNPort() int {
return t.stunPort
}
// GetDERPPort returns the DERP port of the DERPer instance.
func (t *DERPServerInContainer) GetDERPPort() int {
return t.derpPort
}
// WaitForRunning blocks until the DERPer instance is ready to be used.
func (t *DERPServerInContainer) WaitForRunning() error {
url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/"
log.Printf("waiting for DERPer to be ready at %s", url)
insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint
client := &http.Client{Transport: insecureTransport}
return t.pool.Retry(func() error {
resp, err := client.Get(url) //nolint
if err != nil {
return fmt.Errorf("headscale is not ready: %w", err)
}
if resp.StatusCode != http.StatusOK {
return errDERPerStatusCodeNotOk
}
return nil
})
}
// ConnectToNetwork connects the DERPer instance to a network.
func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error {
return t.container.ConnectToNetwork(network)
}
// WriteFile save file inside the container.
func (t *DERPServerInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *DERPServerInContainer) SaveLog(path string) error {
_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)
return err
}

View file

@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
spec := map[string]ClientsSpec{
"user1": {
Plain: 0,
WebsocketDERP: len(MustTestVersions),
WebsocketDERP: 2,
},
}
@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
if clientCount.WebsocketDERP > 0 {
// Containers that use DERP-over-WebSocket
// Note that these clients *must* be built
// from source, which is currently
// only done for HEAD.
err = s.CreateTailscaleIsolatedNodesInUser(
hash,
userName,
"all",
tsic.VersionHead,
clientCount.WebsocketDERP,
tsic.WithWebsocketDERP(true),
)
@ -307,7 +310,7 @@ func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
cert := hsServer.GetCert()
opts = append(opts,
tsic.WithHeadscaleTLS(cert),
tsic.WithCACert(cert),
)
user.createWaitGroup.Go(func() error {

View file

@ -1,19 +1,12 @@
package hsic
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"net/url"
@ -32,11 +25,14 @@ import (
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"gopkg.in/yaml.v3"
"tailscale.com/tailcfg"
)
const (
hsicHashLength = 6
dockerContextPath = "../."
caCertRoot = "/usr/local/share/ca-certificates"
aclPolicyPath = "/etc/headscale/acl.hujson"
tlsCertPath = "/etc/headscale/tls.cert"
tlsKeyPath = "/etc/headscale/tls.key"
@ -64,6 +60,7 @@ type HeadscaleInContainer struct {
// optional config
port int
extraPorts []string
caCerts [][]byte
hostPortBindings map[string][]string
aclPolicy *policy.ACLPolicy
env map[string]string
@ -81,6 +78,10 @@ type Option = func(c *HeadscaleInContainer)
// HeadscaleInContainer instance.
func WithACLPolicy(acl *policy.ACLPolicy) Option {
return func(hsic *HeadscaleInContainer) {
if acl == nil {
return
}
// TODO(kradalby): Move somewhere appropriate
hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
@ -88,18 +89,29 @@ func WithACLPolicy(acl *policy.ACLPolicy) Option {
}
}
// WithCACert adds it to the trusted surtificate of the container.
func WithCACert(cert []byte) Option {
return func(hsic *HeadscaleInContainer) {
hsic.caCerts = append(hsic.caCerts, cert)
}
}
// WithTLS creates certificates and enables HTTPS.
func WithTLS() Option {
return func(hsic *HeadscaleInContainer) {
cert, key, err := createCertificate(hsic.hostname)
cert, key, err := integrationutil.CreateCertificate(hsic.hostname)
if err != nil {
log.Fatalf("failed to create certificates for headscale test: %s", err)
}
// TODO(kradalby): Move somewhere appropriate
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
hsic.tlsCert = cert
hsic.tlsKey = key
}
}
// WithCustomTLS uses the given certificates for the Headscale instance.
func WithCustomTLS(cert, key []byte) Option {
return func(hsic *HeadscaleInContainer) {
hsic.tlsCert = cert
hsic.tlsKey = key
}
@ -146,6 +158,13 @@ func WithTestName(testName string) Option {
}
}
// WithHostname sets the hostname of the Headscale instance.
func WithHostname(hostname string) Option {
return func(hsic *HeadscaleInContainer) {
hsic.hostname = hostname
}
}
// WithHostnameAsServerURL sets the Headscale ServerURL based on
// the Hostname.
func WithHostnameAsServerURL() Option {
@ -203,6 +222,34 @@ func WithEmbeddedDERPServerOnly() Option {
}
}
// WithDERPConfig configures Headscale use a custom
// DERP server only.
func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
return func(hsic *HeadscaleInContainer) {
contents, err := yaml.Marshal(derpMap)
if err != nil {
log.Fatalf("failed to marshal DERP map: %s", err)
return
}
hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml"
hsic.filesInContainer = append(hsic.filesInContainer,
fileInContainer{
path: "/etc/headscale/derp.yml",
contents: contents,
})
// Disable global DERP server and embedded DERP server
hsic.env["HEADSCALE_DERP_URLS"] = ""
hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false"
// Envknob for enabling DERP debug logs
hsic.env["DERP_DEBUG_LOGS"] = "true"
hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true"
}
}
// WithTuning allows changing the tuning settings easily.
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
return func(hsic *HeadscaleInContainer) {
@ -300,6 +347,10 @@ func New(
"HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1",
"HEADSCALE_DEBUG_DUMP_CONFIG=1",
}
if hsic.hasTLS() {
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
}
for key, value := range hsic.env {
env = append(env, fmt.Sprintf("%s=%s", key, value))
}
@ -313,7 +364,7 @@ func New(
// Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some
// to inject the headscale configuration further down.
Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"},
Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; update-ca-certificates ; headscale serve ; /bin/sleep 30"},
Env: env,
}
@ -351,6 +402,14 @@ func New(
hsic.container = container
// Write the CA certificates to the container
for i, cert := range hsic.caCerts {
err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML()))
if err != nil {
return nil, fmt.Errorf("failed to write headscale config to container: %w", err)
@ -749,86 +808,3 @@ func (t *HeadscaleInContainer) SendInterrupt() error {
return nil
}
// nolint
func createCertificate(hostname string) ([]byte, []byte, error) {
// From:
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Minute),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
DNSNames: []string{hostname},
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
cert,
ca,
&certPrivKey.PublicKey,
caPrivKey,
)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, err
}
certPrivKeyPEM := new(bytes.Buffer)
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
return nil, nil, err
}
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}

View file

@ -3,9 +3,16 @@ package integrationutil
import (
"archive/tar"
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"math/big"
"path/filepath"
"time"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3"
@ -93,3 +100,86 @@ func FetchPathFromContainer(
return buf.Bytes(), nil
}
// nolint
func CreateCertificate(hostname string) ([]byte, []byte, error) {
// From:
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Minute),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
DNSNames: []string{hostname},
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
cert,
ca,
&certPrivKey.PublicKey,
caPrivKey,
)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, err
}
certPrivKeyPEM := new(bytes.Buffer)
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
return nil, nil, err
}
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}

View file

@ -14,6 +14,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3"
@ -140,6 +141,7 @@ type Scenario struct {
// TODO(kradalby): support multiple headcales for later, currently only
// use one.
controlServers *xsync.MapOf[string, ControlServer]
derpServers []*dsic.DERPServerInContainer
users map[string]*User
@ -224,6 +226,13 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
}
}
for _, derp := range s.derpServers {
err := derp.Shutdown()
if err != nil {
log.Printf("failed to tear down derp server: %s", err)
}
}
if err := s.pool.RemoveNetwork(s.network); err != nil {
log.Printf("failed to remove network: %s", err)
}
@ -352,7 +361,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
hostname := headscale.GetHostname()
opts = append(opts,
tsic.WithHeadscaleTLS(cert),
tsic.WithCACert(cert),
tsic.WithHeadscaleName(hostname),
)
@ -651,3 +660,20 @@ func (s *Scenario) WaitForTailscaleLogout() error {
return nil
}
// CreateDERPServer creates a new DERP server in a container.
func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
derp, err := dsic.New(s.pool, version, s.network, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create DERP server: %w", err)
}
err = derp.WaitForRunning()
if err != nil {
return nil, fmt.Errorf("failed to reach DERP server: %w", err)
}
s.derpServers = append(s.derpServers, derp)
return derp, nil
}

View file

@ -30,6 +30,7 @@ type TailscaleClient interface {
FQDN() (string, error)
Status(...bool) (*ipnstate.Status, error)
Netmap() (*netmap.NetworkMap, error)
DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
Netcheck() (*netcheck.Report, error)
WaitForNeedsLogin() error
WaitForRunning() error

View file

@ -12,6 +12,7 @@ import (
"net/netip"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
@ -32,7 +33,7 @@ const (
defaultPingTimeout = 300 * time.Millisecond
defaultPingCount = 10
dockerContextPath = "../."
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
caCertRoot = "/usr/local/share/ca-certificates"
dockerExecuteTimeout = 60 * time.Second
)
@ -44,6 +45,11 @@ var (
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
errTailscaleNotConnected = errors.New("tailscale not connected")
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
errInvalidClientConfig = errors.New("verifiably invalid client config requested")
)
const (
VersionHead = "head"
)
func errTailscaleStatus(hostname string, err error) error {
@ -65,7 +71,7 @@ type TailscaleInContainer struct {
fqdn string
// optional config
headscaleCert []byte
caCerts [][]byte
headscaleHostname string
withWebsocketDERP bool
withSSH bool
@ -74,17 +80,23 @@ type TailscaleInContainer struct {
withExtraHosts []string
workdir string
netfilter string
// build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig
}
type TailscaleInContainerBuildConfig struct {
tags []string
}
// Option represent optional settings that can be given to a
// Tailscale instance.
type Option = func(c *TailscaleInContainer)
// WithHeadscaleTLS takes the certificate of the Headscale instance
// and adds it to the trusted surtificate of the Tailscale container.
func WithHeadscaleTLS(cert []byte) Option {
// WithCACert adds it to the trusted surtificate of the Tailscale container.
func WithCACert(cert []byte) Option {
return func(tsic *TailscaleInContainer) {
tsic.headscaleCert = cert
tsic.caCerts = append(tsic.caCerts, cert)
}
}
@ -113,7 +125,7 @@ func WithOrCreateNetwork(network *dockertest.Network) Option {
}
// WithHeadscaleName set the name of the headscale instance,
// mostly useful in combination with TLS and WithHeadscaleTLS.
// mostly useful in combination with TLS and WithCACert.
func WithHeadscaleName(hsName string) Option {
return func(tsic *TailscaleInContainer) {
tsic.headscaleHostname = hsName
@ -175,6 +187,22 @@ func WithNetfilter(state string) Option {
}
}
// WithBuildTag adds an additional value to the `-tags=` parameter
// of the Go compiler, allowing callers to customize the Tailscale client build.
// This option is only meaningful when invoked on **HEAD** versions of the client.
// Attempts to use it with any other version is a bug in the calling code.
func WithBuildTag(tag string) Option {
return func(tsic *TailscaleInContainer) {
if tsic.version != VersionHead {
panic(errInvalidClientConfig)
}
tsic.buildConfig.tags = append(
tsic.buildConfig.tags, tag,
)
}
}
// New returns a new TailscaleInContainer instance.
func New(
pool *dockertest.Pool,
@ -219,18 +247,20 @@ func New(
}
if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
}
WithBuildTag("ts_debug_websockets")(tsic)
tailscaleOptions.Env = append(
tailscaleOptions.Env,
fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
)
}
if tsic.headscaleHostname != "" {
tailscaleOptions.ExtraHosts = []string{
"host.docker.internal:host-gateway",
fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
}
}
tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts,
"host.docker.internal:host-gateway")
if tsic.workdir != "" {
tailscaleOptions.WorkingDir = tsic.workdir
@ -245,14 +275,36 @@ func New(
}
var container *dockertest.Resource
if version != VersionHead {
// build options are not meaningful with pre-existing images,
// let's not lead anyone astray by pretending otherwise.
defaultBuildConfig := TailscaleInContainerBuildConfig{}
hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
if hasBuildConfig {
return tsic, errInvalidClientConfig
}
}
switch version {
case "head":
case VersionHead:
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
}
buildTags := strings.Join(tsic.buildConfig.tags, ",")
if len(buildTags) > 0 {
buildOptions.BuildArgs = append(
buildOptions.BuildArgs,
docker.BuildArg{
Name: "BUILD_TAGS",
Value: buildTags,
},
)
}
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
tailscaleOptions,
@ -294,8 +346,8 @@ func New(
tsic.container = container
if tsic.hasTLS() {
err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
for i, cert := range tsic.caCerts {
err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
@ -304,10 +356,6 @@ func New(
return tsic, nil
}
func (t *TailscaleInContainer) hasTLS() bool {
return len(t.headscaleCert) != 0
}
// Shutdown stops and cleans up the Tailscale container.
func (t *TailscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
@ -682,6 +730,34 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error
}
}
func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) {
if !util.TailscaleVersionNewerOrEqual("1.34", t.version) {
panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version)
}
command := []string{
"tailscale",
"debug",
"derp",
region,
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr) // nolint
return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err)
}
var report ipnstate.DebugDERPRegionReport
err = json.Unmarshal([]byte(result), &report)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
}
return &report, err
}
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
command := []string{