Mirror of https://github.com/juanfont/headscale.git (synced 2024-11-26 08:53:05 +00:00)

Compare commits: 17 commits, 712bb375bf ... d858ddf1f0
Commits:

d858ddf1f0
edf9e25001
c6336adb01
5fbf3f8327
014ee87066
a942fcf50a
24f3895b2b
85a038cfca
dbf2faa4bf
7f665023d8
f2ab5e05c9
50b62ddfb3
8d5b04f3d3
19bc8b6e01
8ecba121cc
6afb554e20
50165ce9e1
34 changed files with 1945 additions and 486 deletions
.github/workflows/test-integration.yaml (vendored, 4 changes)

@@ -30,14 +30,14 @@ jobs:
           - TestPreAuthKeyCorrectUserLoggedInCommand
           - TestApiKeyCommand
           - TestNodeTagCommand
-          - TestNodeAdvertiseTagNoACLCommand
-          - TestNodeAdvertiseTagWithACLCommand
+          - TestNodeAdvertiseTagCommand
           - TestNodeCommand
           - TestNodeExpireCommand
           - TestNodeRenameCommand
           - TestNodeMoveCommand
           - TestPolicyCommand
           - TestPolicyBrokenConfigCommand
+          - TestDERPVerifyEndpoint
           - TestResolveMagicDNS
           - TestValidateResolvConf
           - TestDERPServerScenario
CHANGELOG.md

@@ -23,6 +23,7 @@
 - Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198)
 - Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199)
 - Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232)
+- Loosened up `server_url` and `base_domain` check. It was overly strict in some cases.
 
 ## 0.23.0 (2024-09-18)
Dockerfile.derper (new file, 19 lines)

@@ -0,0 +1,19 @@
+# For testing purposes only
+
+FROM golang:alpine AS build-env
+
+WORKDIR /go/src
+
+RUN apk add --no-cache git
+ARG VERSION_BRANCH=main
+RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1
+WORKDIR /go/src/tailscale
+
+ARG TARGETARCH
+RUN GOARCH=$TARGETARCH go install -v ./cmd/derper
+
+FROM alpine:3.18
+RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
+
+COPY --from=build-env /go/bin/* /usr/local/bin/
+ENTRYPOINT [ "/usr/local/bin/derper" ]
@@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH=""
 ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
 ARG TARGETARCH
 
-RUN GOARCH=$TARGETARCH go install -ldflags="\
+ARG BUILD_TAGS=""
+
+RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
       -X tailscale.com/version.longStamp=$VERSION_LONG \
       -X tailscale.com/version.shortStamp=$VERSION_SHORT \
       -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \
hscontrol/app.go (167 changes)

@@ -88,7 +88,8 @@ type Headscale struct {
 	DERPMap    *tailcfg.DERPMap
 	DERPServer *derpServer.DERPServer
 
-	ACLPolicy *policy.ACLPolicy
+	polManOnce sync.Once
+	polMan     policy.PolicyManager
 
 	mapper       *mapper.Mapper
 	nodeNotifier *notifier.Notifier
@@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		}
 	})
 
+	if err = app.loadPolicyManager(); err != nil {
+		return nil, fmt.Errorf("failed to load ACL policy: %w", err)
+	}
+
 	var authProvider AuthProvider
 	authProvider = NewAuthProviderWeb(cfg.ServerURL)
 	if cfg.OIDC.Issuer != "" {
@@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 			app.db,
 			app.nodeNotifier,
 			app.ipAlloc,
+			app.polMan,
 		)
 		if err != nil {
 			if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@@ -457,6 +463,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 	router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
 		Methods(http.MethodGet)
 
+	router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
+
 	if h.cfg.DERP.ServerEnabled {
 		router.HandleFunc("/derp", h.DERPServer.DERPHandler)
 		router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)
@@ -473,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 	return router
 }
 
+// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+// Maybe we should attempt a new in memory state and not go via the DB?
+func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
+	users, err := db.ListUsers()
+	if err != nil {
+		return err
+	}
+
+	changed, err := polMan.SetUsers(users)
+	if err != nil {
+		return err
+	}
+
+	if changed {
+		ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
+		notif.NotifyAll(ctx, types.StateUpdate{
+			Type: types.StateFullUpdate,
+		})
+	}
+
+	return nil
+}
+
+// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
+// Maybe we should attempt a new in memory state and not go via the DB?
+func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
+	nodes, err := db.ListNodes()
+	if err != nil {
+		return err
+	}
+
+	changed, err := polMan.SetNodes(nodes)
+	if err != nil {
+		return err
+	}
+
+	if changed {
+		ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
+		notif.NotifyAll(ctx, types.StateUpdate{
+			Type: types.StateFullUpdate,
+		})
+	}
+
+	return nil
+}
+
 // Serve launches the HTTP and gRPC server service Headscale and the API.
 func (h *Headscale) Serve() error {
 	if profilingEnabled {
@@ -488,19 +542,13 @@ func (h *Headscale) Serve() error {
 		}
 	}
 
-	var err error
-
-	if err = h.loadACLPolicy(); err != nil {
-		return fmt.Errorf("failed to load ACL policy: %w", err)
-	}
-
 	if dumpConfig {
 		spew.Dump(h.cfg)
 	}
 
 	// Fetch an initial DERP Map before we start serving
 	h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
-	h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
+	h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
 
 	if h.cfg.DERP.ServerEnabled {
 		// When embedded DERP is enabled we always need a STUN server
@@ -770,12 +818,21 @@
 					Str("signal", sig.String()).
 					Msg("Received SIGHUP, reloading ACL and Config")
 
-				// TODO(kradalby): Reload config on SIGHUP
-				if err := h.loadACLPolicy(); err != nil {
-					log.Error().Err(err).Msg("failed to reload ACL policy")
+				if err := h.loadPolicyManager(); err != nil {
+					log.Error().Err(err).Msg("failed to reload Policy")
 				}
 
-				if h.ACLPolicy != nil {
+				pol, err := h.policyBytes()
+				if err != nil {
+					log.Error().Err(err).Msg("failed to get policy blob")
+				}
+
+				changed, err := h.polMan.SetPolicy(pol)
+				if err != nil {
+					log.Error().Err(err).Msg("failed to set new policy")
+				}
+
+				if changed {
 					log.Info().
 						Msg("ACL policy successfully reloaded, notifying nodes of change")
@@ -994,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
 	return &machineKey, nil
 }
 
-func (h *Headscale) loadACLPolicy() error {
-	var (
-		pol *policy.ACLPolicy
-		err error
-	)
-
+// policyBytes returns the appropriate policy for the
+// current configuration as a []byte array.
+func (h *Headscale) policyBytes() ([]byte, error) {
 	switch h.cfg.Policy.Mode {
 	case types.PolicyModeFile:
 		path := h.cfg.Policy.Path
 
 		// It is fine to start headscale without a policy file.
 		if len(path) == 0 {
-			return nil
+			return nil, nil
 		}
 
 		absPath := util.AbsolutePathFromConfigPath(path)
-		pol, err = policy.LoadACLPolicyFromPath(absPath)
+		policyFile, err := os.Open(absPath)
 		if err != nil {
-			return fmt.Errorf("failed to load ACL policy from file: %w", err)
+			return nil, err
 		}
+		defer policyFile.Close()
+
+		return io.ReadAll(policyFile)
+
+	case types.PolicyModeDB:
+		p, err := h.db.GetPolicy()
+		if err != nil {
+			if errors.Is(err, types.ErrPolicyNotFound) {
+				return nil, nil
+			}
+
+			return nil, err
+		}
+
+		return []byte(p.Data), err
+	}
+
+	return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
+}
+
+func (h *Headscale) loadPolicyManager() error {
+	var errOut error
+	h.polManOnce.Do(func() {
 		// Validate and reject configuration that would error when applied
 		// when creating a map response. This requires nodes, so there is still
 		// a scenario where they might be allowed if the server has no nodes
@@ -1025,42 +1101,35 @@ func (h *Headscale) loadACLPolicy() error {
 		// allowed to be written to the database.
 		nodes, err := h.db.ListNodes()
 		if err != nil {
-			return fmt.Errorf("loading nodes from database to validate policy: %w", err)
+			errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
+			return
+		}
+
+		users, err := h.db.ListUsers()
+		if err != nil {
+			errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
+			return
 		}
 
-		_, err = pol.CompileFilterRules(nodes)
+		pol, err := h.policyBytes()
 		if err != nil {
-			return fmt.Errorf("verifying policy rules: %w", err)
+			errOut = fmt.Errorf("loading policy bytes: %w", err)
+			return
+		}
+
+		h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
+		if err != nil {
+			errOut = fmt.Errorf("creating policy manager: %w", err)
+			return
 		}
 
 		if len(nodes) > 0 {
-			_, err = pol.CompileSSHPolicy(nodes[0], nodes)
+			_, err = h.polMan.SSHPolicy(nodes[0])
 			if err != nil {
-				return fmt.Errorf("verifying SSH rules: %w", err)
+				errOut = fmt.Errorf("verifying SSH rules: %w", err)
+				return
 			}
 		}
+	})
 
-	case types.PolicyModeDB:
-		p, err := h.db.GetPolicy()
-		if err != nil {
-			if errors.Is(err, types.ErrPolicyNotFound) {
-				return nil
-			}
-
-			return fmt.Errorf("failed to get policy from database: %w", err)
-		}
-
-		pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
-		if err != nil {
-			return fmt.Errorf("failed to parse policy: %w", err)
-		}
-	default:
-		log.Fatal().
-			Str("mode", string(h.cfg.Policy.Mode)).
-			Msg("Unknown ACL policy mode")
-	}
-
-	h.ACLPolicy = pol
-
-	return nil
+	return errOut
 }
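Read together, the app.go hunks above outline the shape of the new policy.PolicyManager abstraction that replaces the raw ACLPolicy field: it is built once behind sync.Once, fed fresh users and nodes by the hooks, and queried by the mapper and gRPC layers. The interface definition itself is not part of this compare view, so the following Go sketch is reconstructed only from the call sites visible in this diff; any detail beyond those call sites is an assumption.

package policy

import (
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/types"
	"go4.org/netipx"
	"tailscale.com/tailcfg"
)

// Sketch only: methods inferred from call sites in this diff
// (app.go, mapper.go, tail.go, grpcv1.go, routes.go).
type PolicyManager interface {
	// Filter returns the compiled packet filter (used by appendPeerChanges).
	Filter() []tailcfg.FilterRule
	// SSHPolicy compiles the SSH policy for one node.
	SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error)
	// Tags returns the valid tags for a node (tail.go, ListNodes).
	Tags(node *types.Node) []string
	// SetPolicy swaps in new policy bytes; the bool reports whether the
	// compiled result changed, which drives the notify-all updates.
	SetPolicy(pol []byte) (bool, error)
	// SetUsers and SetNodes back usersChangedHook and nodesChangedHook.
	SetUsers(users []types.User) (bool, error)
	SetNodes(nodes types.Nodes) (bool, error)
	// ExpandAlias and ApproversForRoute back auto-approved routes.
	ExpandAlias(alias string) (*netipx.IPSet, error)
	ApproversForRoute(route netip.Prefix) []string
}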
@@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
 
 			return
 		}
+
+		err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
+		if err != nil {
+			http.Error(writer, "Internal server error", http.StatusInternalServerError)
+			return
+		}
+
 	}
 
 	err = h.db.Write(func(tx *gorm.DB) error {
@@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
 	c.Assert(err, check.IsNil)
 	c.Assert(len(testPeers), check.Equals, 9)
 
-	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
+	adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
 	c.Assert(err, check.IsNil)
 
-	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
+	testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
 	c.Assert(err, check.IsNil)
 
 	peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
@@ -559,10 +559,6 @@ func TestAutoApproveRoutes(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			adb, err := newTestDB()
 			assert.NoError(t, err)
-			pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
-
-			assert.NoError(t, err)
-			assert.NotNil(t, pol)
 
 			user, err := adb.CreateUser("test")
 			assert.NoError(t, err)
@@ -599,8 +595,17 @@ func TestAutoApproveRoutes(t *testing.T) {
 			node0ByID, err := adb.GetNodeByID(0)
 			assert.NoError(t, err)
 
+			users, err := adb.ListUsers()
+			assert.NoError(t, err)
+
+			nodes, err := adb.ListNodes()
+			assert.NoError(t, err)
+
+			pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
+			assert.NoError(t, err)
+
 			// TODO(kradalby): Check state update
-			err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
+			err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
 			assert.NoError(t, err)
 
 			enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
@@ -598,18 +598,18 @@ func failoverRoute(
 }
 
 func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
-	aclPolicy *policy.ACLPolicy,
+	polMan policy.PolicyManager,
 	node *types.Node,
 ) error {
 	return hsdb.Write(func(tx *gorm.DB) error {
-		return EnableAutoApprovedRoutes(tx, aclPolicy, node)
+		return EnableAutoApprovedRoutes(tx, polMan, node)
 	})
 }
 
 // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
 func EnableAutoApprovedRoutes(
 	tx *gorm.DB,
-	aclPolicy *policy.ACLPolicy,
+	polMan policy.PolicyManager,
 	node *types.Node,
 ) error {
 	if node.IPv4 == nil && node.IPv6 == nil {
@@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
 			continue
 		}
 
-		routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
-			netip.Prefix(advertisedRoute.Prefix),
-		)
-		if err != nil {
-			return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
-		}
+		routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
 
 		log.Trace().
 			Str("node", node.Hostname).
@@ -649,7 +644,7 @@ func EnableAutoApprovedRoutes(
 				approvedRoutes = append(approvedRoutes, advertisedRoute)
 			} else {
 				// TODO(kradalby): figure out how to get this to depend on less stuff
-				approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
+				approvedIps, err := polMan.ExpandAlias(approvedAlias)
 				if err != nil {
 					return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
 				}
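For orientation, the approval decision the loop above makes for a single advertised route can be read as the sketch below. routeAutoApproved is a hypothetical helper written only for this note (the real logic stays inline in EnableAutoApprovedRoutes), and node.InIPSet is assumed from the headscale types package rather than shown in this diff.

// Hypothetical helper for illustration; not part of the diff.
func routeAutoApproved(polMan policy.PolicyManager, node *types.Node, route netip.Prefix) bool {
	for _, approvedAlias := range polMan.ApproversForRoute(route) {
		// The advertising node's own user may be listed as an approver...
		if approvedAlias == node.User.Username() {
			return true
		}

		// ...or one of the node's addresses may fall inside the set the
		// approver alias (group, tag, user, ...) expands to.
		ips, err := polMan.ExpandAlias(approvedAlias)
		if err != nil {
			return false
		}
		if node.InIPSet(ips) { // InIPSet assumed from types.Node
			return true
		}
	}

	return false
}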
@@ -21,7 +21,6 @@ import (
 
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 	"github.com/juanfont/headscale/hscontrol/db"
-	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
 )
@@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
 		return nil, err
 	}
 
+	err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
+	if err != nil {
+		return nil, fmt.Errorf("updating resources using user: %w", err)
+	}
+
 	return &v1.CreateUserResponse{User: user.Proto()}, nil
 }
@@ -87,6 +91,11 @@ func (api headscaleV1APIServer) DeleteUser(
 		return nil, err
 	}
 
+	err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
+	if err != nil {
+		return nil, fmt.Errorf("updating resources using user: %w", err)
+	}
+
 	return &v1.DeleteUserResponse{}, nil
 }
@@ -221,6 +230,11 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, err
 	}
 
+	err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
+	if err != nil {
+		return nil, fmt.Errorf("updating resources using node: %w", err)
+	}
+
 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 }
@@ -450,10 +464,7 @@ func (api headscaleV1APIServer) ListNodes(
 			resp.Online = true
 		}
 
-		validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
-			node,
-		)
-		resp.InvalidTags = invalidTags
+		validTags := api.h.polMan.Tags(node)
 		resp.ValidTags = validTags
 		response[index] = resp
 	}
@@ -723,11 +734,6 @@
 
 	p := request.GetPolicy()
 
-	pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
-	if err != nil {
-		return nil, fmt.Errorf("loading ACL policy file: %w", err)
-	}
-
 	// Validate and reject configuration that would error when applied
 	// when creating a map response. This requires nodes, so there is still
 	// a scenario where they might be allowed if the server has no nodes
@@ -737,14 +743,13 @@
 	if err != nil {
 		return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
 	}
-
-	_, err = pol.CompileFilterRules(nodes)
+	changed, err := api.h.polMan.SetPolicy([]byte(p))
 	if err != nil {
-		return nil, fmt.Errorf("verifying policy rules: %w", err)
+		return nil, fmt.Errorf("setting policy: %w", err)
 	}
 
 	if len(nodes) > 0 {
-		_, err = pol.CompileSSHPolicy(nodes[0], nodes)
+		_, err = api.h.polMan.SSHPolicy(nodes[0])
 		if err != nil {
 			return nil, fmt.Errorf("verifying SSH rules: %w", err)
 		}
@@ -755,12 +760,13 @@
 		return nil, err
 	}
 
-	api.h.ACLPolicy = pol
+	// Only send update if the packet filter has changed.
+	if changed {
 	ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
 	api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
 		Type: types.StateFullUpdate,
 	})
+	}
 
 	response := &v1.SetPolicyResponse{
 		Policy: updated.Data,
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"io"
 	"net/http"
 	"strconv"
 	"strings"
@@ -56,6 +57,65 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
 	return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
 }
 
+func (h *Headscale) handleVerifyRequest(
+	req *http.Request,
+) (bool, error) {
+	body, err := io.ReadAll(req.Body)
+	if err != nil {
+		return false, fmt.Errorf("cannot read request body: %w", err)
+	}
+
+	var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
+	if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
+		return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
+	}
+
+	nodes, err := h.db.ListNodes()
+	if err != nil {
+		return false, fmt.Errorf("cannot list nodes: %w", err)
+	}
+
+	return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil
+}
+
+// see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, Derp use verifyClientsURL to verify whether a client is allowed to connect to the DERP server.
+func (h *Headscale) VerifyHandler(
+	writer http.ResponseWriter,
+	req *http.Request,
+) {
+	if req.Method != http.MethodPost {
+		http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
+
+		return
+	}
+	log.Debug().
+		Str("handler", "/verify").
+		Msg("verify client")
+
+	allow, err := h.handleVerifyRequest(req)
+	if err != nil {
+		log.Error().
+			Caller().
+			Err(err).
+			Msg("Failed to verify client")
+		http.Error(writer, "Internal error", http.StatusInternalServerError)
+	}
+
+	resp := tailcfg.DERPAdmitClientResponse{
+		Allow: allow,
+	}
+
+	writer.Header().Set("Content-Type", "application/json")
+	writer.WriteHeader(http.StatusOK)
+	err = json.NewEncoder(writer).Encode(resp)
+	if err != nil {
+		log.Error().
+			Caller().
+			Err(err).
+			Msg("Failed to write response")
+	}
+}
+
 // KeyHandler provides the Headscale pub key
 // Listens in /key.
 func (h *Headscale) KeyHandler(
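The new /verify endpoint speaks Tailscale's DERP admit-client protocol: a derper is started with its verify-client URL pointing at headscale, which then admits only node keys it knows. A minimal round-trip against a handler wired like VerifyHandler could look as follows; newTestRouter is a hypothetical setup helper, not something this diff provides.

import (
	"bytes"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
)

func TestVerifyEndpointSketch(t *testing.T) {
	// Assumed: a router with "/verify" registered, as in createRouter above.
	var handler http.Handler = newTestRouter(t) // hypothetical helper

	// A freshly generated key is unknown to headscale, so expect Allow == false.
	body, err := json.Marshal(tailcfg.DERPAdmitClientRequest{
		NodePublic: key.NewNode().Public(),
	})
	if err != nil {
		t.Fatal(err)
	}

	req := httptest.NewRequest(http.MethodPost, "/verify", bytes.NewReader(body))
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	var resp tailcfg.DERPAdmitClientResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}
	if resp.Allow {
		t.Fatal("unknown node key should not be admitted to DERP")
	}
}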
@@ -55,6 +55,7 @@ type Mapper struct {
 	cfg     *types.Config
 	derpMap *tailcfg.DERPMap
 	notif   *notifier.Notifier
+	polMan  policy.PolicyManager
 
 	uid     string
 	created time.Time
@@ -71,6 +72,7 @@ func NewMapper(
 	cfg *types.Config,
 	derpMap *tailcfg.DERPMap,
 	notif *notifier.Notifier,
+	polMan policy.PolicyManager,
 ) *Mapper {
 	uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
 
@@ -79,6 +81,7 @@ func NewMapper(
 		cfg:     cfg,
 		derpMap: derpMap,
 		notif:   notif,
+		polMan:  polMan,
 
 		uid:     uid,
 		created: time.Now(),
@@ -153,10 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
 func (m *Mapper) fullMapResponse(
 	node *types.Node,
 	peers types.Nodes,
-	pol *policy.ACLPolicy,
 	capVer tailcfg.CapabilityVersion,
 ) (*tailcfg.MapResponse, error) {
-	resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
+	resp, err := m.baseWithConfigMapResponse(node, capVer)
 	if err != nil {
 		return nil, err
 	}
@@ -164,11 +166,10 @@ func (m *Mapper) fullMapResponse(
 	err = appendPeerChanges(
 		resp,
 		true, // full change
-		pol,
+		m.polMan,
 		node,
 		capVer,
 		peers,
-		peers,
 		m.cfg,
 	)
 	if err != nil {
@@ -182,7 +183,6 @@ func (m *Mapper) fullMapResponse(
 func (m *Mapper) FullMapResponse(
 	mapRequest tailcfg.MapRequest,
 	node *types.Node,
-	pol *policy.ACLPolicy,
 	messages ...string,
 ) ([]byte, error) {
 	peers, err := m.ListPeers(node.ID)
@@ -190,7 +190,7 @@ func (m *Mapper) FullMapResponse(
 		return nil, err
 	}
 
-	resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
+	resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
 	if err != nil {
 		return nil, err
 	}
@@ -204,10 +204,9 @@ func (m *Mapper) FullMapResponse(
 func (m *Mapper) ReadOnlyMapResponse(
 	mapRequest tailcfg.MapRequest,
 	node *types.Node,
-	pol *policy.ACLPolicy,
 	messages ...string,
 ) ([]byte, error) {
-	resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
+	resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
 	if err != nil {
 		return nil, err
 	}
@@ -243,7 +242,6 @@ func (m *Mapper) PeerChangedResponse(
 	node *types.Node,
 	changed map[types.NodeID]bool,
 	patches []*tailcfg.PeerChange,
-	pol *policy.ACLPolicy,
 	messages ...string,
 ) ([]byte, error) {
 	resp := m.baseMapResponse()
@@ -273,10 +271,9 @@ func (m *Mapper) PeerChangedResponse(
 	err = appendPeerChanges(
 		&resp,
 		false, // partial change
-		pol,
+		m.polMan,
 		node,
 		mapRequest.Version,
-		peers,
 		changedNodes,
 		m.cfg,
 	)
@@ -303,7 +300,7 @@ func (m *Mapper) PeerChangedResponse(
 
 	// Add the node itself, it might have changed, and particularly
 	// if there are no patches or changes, this is a self update.
-	tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
+	tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
 	if err != nil {
 		return nil, err
 	}
@@ -318,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse(
 	mapRequest tailcfg.MapRequest,
 	node *types.Node,
 	changed []*tailcfg.PeerChange,
-	pol *policy.ACLPolicy,
 ) ([]byte, error) {
 	resp := m.baseMapResponse()
 	resp.PeersChangedPatch = changed
@@ -447,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
 // incremental.
 func (m *Mapper) baseWithConfigMapResponse(
 	node *types.Node,
-	pol *policy.ACLPolicy,
 	capVer tailcfg.CapabilityVersion,
 ) (*tailcfg.MapResponse, error) {
 	resp := m.baseMapResponse()
 
-	tailnode, err := tailNode(node, capVer, pol, m.cfg)
+	tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
 	if err != nil {
 		return nil, err
 	}
@@ -505,34 +500,30 @@ func appendPeerChanges(
 	resp *tailcfg.MapResponse,
 
 	fullChange bool,
-	pol *policy.ACLPolicy,
+	polMan policy.PolicyManager,
 	node *types.Node,
 	capVer tailcfg.CapabilityVersion,
-	peers types.Nodes,
 	changed types.Nodes,
 	cfg *types.Config,
 ) error {
-	packetFilter, err := pol.CompileFilterRules(append(peers, node))
-	if err != nil {
-		return err
-	}
-
-	sshPolicy, err := pol.CompileSSHPolicy(node, peers)
+	filter := polMan.Filter()
+
+	sshPolicy, err := polMan.SSHPolicy(node)
 	if err != nil {
 		return err
 	}
 
 	// If there are filter rules present, see if there are any nodes that cannot
 	// access each-other at all and remove them from the peers.
-	if len(packetFilter) > 0 {
-		changed = policy.FilterNodesByACL(node, changed, packetFilter)
+	if len(filter) > 0 {
+		changed = policy.FilterNodesByACL(node, changed, filter)
 	}
 
 	profiles := generateUserProfiles(node, changed)
 
 	dnsConfig := generateDNSConfig(cfg, node)
 
-	tailPeers, err := tailNodes(changed, capVer, pol, cfg)
+	tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
 	if err != nil {
 		return err
 	}
@@ -557,7 +548,7 @@ func appendPeerChanges(
 		// new PacketFilters field and "base" allows us to send a full update when we
 		// have to send an empty list, avoiding the hack in the else block.
 		resp.PacketFilters = map[string][]tailcfg.FilterRule{
-			"base": policy.ReduceFilterRules(node, packetFilter),
+			"base": policy.ReduceFilterRules(node, filter),
 		}
 	} else {
 		// This is a hack to avoid sending an empty list of packet filters.
@@ -565,11 +556,11 @@ func appendPeerChanges(
 		// be omitted, causing the client to consider it unchanged, keeping the
 		// previous packet filter. Worst case, this can cause a node that previously
 		// has access to a node to _not_ loose access if an empty (allow none) is sent.
-		reduced := policy.ReduceFilterRules(node, packetFilter)
+		reduced := policy.ReduceFilterRules(node, filter)
 		if len(reduced) > 0 {
 			resp.PacketFilter = reduced
 		} else {
-			resp.PacketFilter = packetFilter
+			resp.PacketFilter = filter
 		}
 	}
@@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
 	lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
 	expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
 
+	user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
+	user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
+
 	mini := &types.Node{
 		ID: 0,
 		MachineKey: mustMK(
@@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
 		IPv4:       iap("100.64.0.1"),
 		Hostname:   "mini",
 		GivenName:  "mini",
-		UserID:     0,
-		User:       types.User{Name: "mini"},
+		UserID:     user1.ID,
+		User:       user1,
 		ForcedTags: []string{},
 		AuthKey:    &types.PreAuthKey{},
 		LastSeen:   &lastSeen,
@@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
 		IPv4:       iap("100.64.0.2"),
 		Hostname:   "peer1",
 		GivenName:  "peer1",
-		UserID:     0,
-		User:       types.User{Name: "mini"},
+		UserID:     user1.ID,
+		User:       user1,
 		ForcedTags: []string{},
 		LastSeen:   &lastSeen,
 		Expiry:     &expire,
@@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
 		IPv4:       iap("100.64.0.3"),
 		Hostname:   "peer2",
 		GivenName:  "peer2",
-		UserID:     1,
-		User:       types.User{Name: "peer2"},
+		UserID:     user2.ID,
+		User:       user2,
 		ForcedTags: []string{},
 		LastSeen:   &lastSeen,
 		Expiry:     &expire,
@@ -458,17 +461,19 @@ func Test_fullMapResponse(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
+
 			mappy := NewMapper(
 				nil,
 				tt.cfg,
 				tt.derpMap,
 				nil,
+				polMan,
 			)
 
 			got, err := mappy.fullMapResponse(
 				tt.node,
 				tt.peers,
-				tt.pol,
 				0,
 			)
@@ -14,7 +14,7 @@ import (
 func tailNodes(
 	nodes types.Nodes,
 	capVer tailcfg.CapabilityVersion,
-	pol *policy.ACLPolicy,
+	polMan policy.PolicyManager,
 	cfg *types.Config,
 ) ([]*tailcfg.Node, error) {
 	tNodes := make([]*tailcfg.Node, len(nodes))
@@ -23,7 +23,7 @@ func tailNodes(
 		node, err := tailNode(
 			node,
 			capVer,
-			pol,
+			polMan,
 			cfg,
 		)
 		if err != nil {
@@ -40,7 +40,7 @@ func tailNodes(
 func tailNode(
 	node *types.Node,
 	capVer tailcfg.CapabilityVersion,
-	pol *policy.ACLPolicy,
+	polMan policy.PolicyManager,
 	cfg *types.Config,
 ) (*tailcfg.Node, error) {
 	addrs := node.Prefixes()
@@ -81,7 +81,7 @@ func tailNode(
 		return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 	}
 
-	tags, _ := pol.TagsOfNode(node)
+	tags := polMan.Tags(node)
 	tags = lo.Uniq(append(tags, node.ForcedTags...))
 
 	tNode := tailcfg.Node{
@@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
 			cfg := &types.Config{
 				BaseDomain: tt.baseDomain,
 				DNSConfig:  tt.dnsConfig,
@@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
 			got, err := tailNode(
 				tt.node,
 				0,
-				tt.pol,
+				polMan,
 				cfg,
 			)
@@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
 	tn, err := tailNode(
 		node,
 		0,
-		&policy.ACLPolicy{},
+		&policy.PolicyManagerV1{},
 		&types.Config{},
 	)
 	if err != nil {
@@ -18,6 +18,7 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/juanfont/headscale/hscontrol/db"
 	"github.com/juanfont/headscale/hscontrol/notifier"
+	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
@@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
 	registrationCache *zcache.Cache[string, key.MachinePublic]
 	notifier          *notifier.Notifier
 	ipAlloc           *db.IPAllocator
+	polMan            policy.PolicyManager
 
 	oidcProvider *oidc.Provider
 	oauth2Config *oauth2.Config
@@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
 	db *db.HSDatabase,
 	notif *notifier.Notifier,
 	ipAlloc *db.IPAllocator,
+	polMan policy.PolicyManager,
 ) (*AuthProviderOIDC, error) {
 	var err error
 	// grab oidc config if it hasn't been already
@@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
 		registrationCache: registrationCache,
 		notifier:          notif,
 		ipAlloc:           ipAlloc,
+		polMan:            polMan,
 
 		oidcProvider: oidcProvider,
 		oauth2Config: oauth2Config,
@@ -461,6 +465,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 		return nil, fmt.Errorf("creating or updating user: %w", err)
 	}
 
+	err = usersChangedHook(a.db, a.polMan, a.notifier)
+	if err != nil {
+		return nil, fmt.Errorf("updating resources using user: %w", err)
+	}
+
 	return user, nil
 }
 
@@ -484,6 +493,11 @@ func (a *AuthProviderOIDC) registerNode(
 		return fmt.Errorf("could not register node: %w", err)
 	}
 
+	err = nodesChangedHook(a.db, a.polMan, a.notifier)
+	if err != nil {
+		return fmt.Errorf("updating resources using node: %w", err)
+	}
+
 	return nil
 }
@ -137,20 +137,21 @@ func GenerateFilterAndSSHRulesForTests(
|
||||||
policy *ACLPolicy,
|
policy *ACLPolicy,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
|
users []types.User,
|
||||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||||
// If there is no policy defined, we default to allow all
|
// If there is no policy defined, we default to allow all
|
||||||
if policy == nil {
|
if policy == nil {
|
||||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := policy.CompileFilterRules(append(peers, node))
|
rules, err := policy.CompileFilterRules(users, append(peers, node))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules")
|
||||||
|
|
||||||
sshPolicy, err := policy.CompileSSHPolicy(node, peers)
|
sshPolicy, err := policy.CompileSSHPolicy(node, users, peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||||
}
|
}
|
||||||
|
@ -161,6 +162,7 @@ func GenerateFilterAndSSHRulesForTests(
|
||||||
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||||
func (pol *ACLPolicy) CompileFilterRules(
|
func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) ([]tailcfg.FilterRule, error) {
|
) ([]tailcfg.FilterRule, error) {
|
||||||
if pol == nil {
|
if pol == nil {
|
||||||
|
@ -176,9 +178,14 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
|
||||||
var srcIPs []string
|
var srcIPs []string
|
||||||
for srcIndex, src := range acl.Sources {
|
for srcIndex, src := range acl.Sources {
|
||||||
srcs, err := pol.expandSource(src, nodes)
|
srcs, err := pol.expandSource(src, users, nodes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing policy, acl index: %d->%d: %w", index, srcIndex, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing policy, acl index: %d->%d: %w",
|
||||||
|
index,
|
||||||
|
srcIndex,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
srcIPs = append(srcIPs, srcs...)
|
srcIPs = append(srcIPs, srcs...)
|
||||||
}
|
}
|
||||||
|
@ -197,6 +204,7 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||||
|
|
||||||
expanded, err := pol.ExpandAlias(
|
expanded, err := pol.ExpandAlias(
|
||||||
nodes,
|
nodes,
|
||||||
|
users,
|
||||||
alias,
|
alias,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -281,6 +289,7 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
||||||
|
|
||||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
|
users []types.User,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
) (*tailcfg.SSHPolicy, error) {
|
) (*tailcfg.SSHPolicy, error) {
|
||||||
if pol == nil {
|
if pol == nil {
|
||||||
|
@ -312,7 +321,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
for index, sshACL := range pol.SSHs {
|
for index, sshACL := range pol.SSHs {
|
||||||
var dest netipx.IPSetBuilder
|
var dest netipx.IPSetBuilder
|
||||||
for _, src := range sshACL.Destinations {
|
for _, src := range sshACL.Destinations {
|
||||||
expanded, err := pol.ExpandAlias(append(peers, node), src)
|
expanded, err := pol.ExpandAlias(append(peers, node), users, src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -335,12 +344,21 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
case "check":
|
case "check":
|
||||||
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing SSH policy, parsing check duration, index: %d: %w", index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, parsing check duration, index: %d: %w",
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
action = *checkAction
|
action = *checkAction
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
|
return nil, fmt.Errorf(
|
||||||
|
"parsing SSH policy, unknown action %q, index: %d: %w",
|
||||||
|
sshACL.Action,
|
||||||
|
index,
|
||||||
|
err,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
||||||
|
@ -363,6 +381,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||||
} else {
|
} else {
|
||||||
expandedSrcs, err := pol.ExpandAlias(
|
expandedSrcs, err := pol.ExpandAlias(
|
||||||
peers,
|
peers,
|
||||||
|
users,
|
||||||
rawSrc,
|
rawSrc,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -512,9 +531,10 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||||
// with the given src alias.
|
// with the given src alias.
|
||||||
func (pol *ACLPolicy) expandSource(
|
func (pol *ACLPolicy) expandSource(
|
||||||
src string,
|
src string,
|
||||||
|
users []types.User,
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
ipSet, err := pol.ExpandAlias(nodes, src)
|
ipSet, err := pol.ExpandAlias(nodes, users, src)
|
||||||
if err != nil {
|
 	if err != nil {
 		return []string{}, err
 	}

@@ -538,6 +558,7 @@ func (pol *ACLPolicy) expandSource(
 // and transform these in IPAddresses.
 func (pol *ACLPolicy) ExpandAlias(
 	nodes types.Nodes,
+	users []types.User,
 	alias string,
 ) (*netipx.IPSet, error) {
 	if isWildcard(alias) {

@@ -552,12 +573,12 @@ func (pol *ACLPolicy) ExpandAlias(

 	// if alias is a group
 	if isGroup(alias) {
-		return pol.expandIPsFromGroup(alias, nodes)
+		return pol.expandIPsFromGroup(alias, users, nodes)
 	}

 	// if alias is a tag
 	if isTag(alias) {
-		return pol.expandIPsFromTag(alias, nodes)
+		return pol.expandIPsFromTag(alias, users, nodes)
 	}

 	if isAutoGroup(alias) {

@@ -565,7 +586,7 @@ func (pol *ACLPolicy) ExpandAlias(
 	}

 	// if alias is a user
-	if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil {
+	if ips, err := pol.expandIPsFromUser(alias, users, nodes); ips != nil {
 		return ips, err
 	}

@@ -574,7 +595,7 @@ func (pol *ACLPolicy) ExpandAlias(
 	if h, ok := pol.Hosts[alias]; ok {
 		log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")

-		return pol.ExpandAlias(nodes, h.String())
+		return pol.ExpandAlias(nodes, users, h.String())
 	}

 	// if alias is an IP

@@ -751,16 +772,17 @@ func (pol *ACLPolicy) expandUsersFromGroup(

 func (pol *ACLPolicy) expandIPsFromGroup(
 	group string,
+	users []types.User,
 	nodes types.Nodes,
 ) (*netipx.IPSet, error) {
 	var build netipx.IPSetBuilder

-	users, err := pol.expandUsersFromGroup(group)
+	userTokens, err := pol.expandUsersFromGroup(group)
 	if err != nil {
 		return &netipx.IPSet{}, err
 	}
-	for _, user := range users {
-		filteredNodes := filterNodesByUser(nodes, user)
+	for _, user := range userTokens {
+		filteredNodes := filterNodesByUser(nodes, users, user)
 		for _, node := range filteredNodes {
 			node.AppendToIPSet(&build)
 		}

@@ -771,6 +793,7 @@ func (pol *ACLPolicy) expandIPsFromGroup(

 func (pol *ACLPolicy) expandIPsFromTag(
 	alias string,
+	users []types.User,
 	nodes types.Nodes,
 ) (*netipx.IPSet, error) {
 	var build netipx.IPSetBuilder

@@ -803,7 +826,7 @@ func (pol *ACLPolicy) expandIPsFromTag(

 	// filter out nodes per tag owner
 	for _, user := range owners {
-		nodes := filterNodesByUser(nodes, user)
+		nodes := filterNodesByUser(nodes, users, user)
 		for _, node := range nodes {
 			if node.Hostinfo == nil {
 				continue

@@ -820,11 +843,12 @@ func (pol *ACLPolicy) expandIPsFromTag(

 func (pol *ACLPolicy) expandIPsFromUser(
 	user string,
+	users []types.User,
 	nodes types.Nodes,
 ) (*netipx.IPSet, error) {
 	var build netipx.IPSetBuilder

-	filteredNodes := filterNodesByUser(nodes, user)
+	filteredNodes := filterNodesByUser(nodes, users, user)
 	filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user)

 	// short-circuit if we have no nodes to get ips from.

@@ -953,10 +977,40 @@ func (pol *ACLPolicy) TagsOfNode(
 	return validTags, invalidTags
 }

-func filterNodesByUser(nodes types.Nodes, user string) types.Nodes {
+// filterNodesByUser returns a list of nodes that match the given userToken from a
+// policy.
+// Matching nodes are determined by first matching the user token to a user by checking:
+// - If it is an ID that matches the user database ID
+// - If it is the Provider Identifier from OIDC
+// - If it matches the username or email of a user
+//
+// If the token matches more than one user, zero nodes will be returned.
+func filterNodesByUser(nodes types.Nodes, users []types.User, userToken string) types.Nodes {
 	var out types.Nodes

+	var potentialUsers []types.User
+	for _, user := range users {
+		if user.ProviderIdentifier == userToken {
+			potentialUsers = append(potentialUsers, user)
+
+			break
+		}
+		if user.Email == userToken {
+			potentialUsers = append(potentialUsers, user)
+		}
+		if user.Name == userToken {
+			potentialUsers = append(potentialUsers, user)
+		}
+	}
+
+	if len(potentialUsers) != 1 {
+		return nil
+	}
+
+	user := potentialUsers[0]
+
 	for _, node := range nodes {
-		if node.User.Username() == user {
+		if node.User.ID == user.ID {
 			out = append(out, node)
 		}
 	}
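Note: the matching rules in the new filterNodesByUser doc comment are easiest to see in isolation. Below is a minimal, self-contained sketch of the same resolution logic; the User struct here is a stand-in for headscale's types.User, not the real type.

package main

import "fmt"

// User is a stand-in with just the fields the matcher looks at.
type User struct {
	ID                 uint
	Name               string
	Email              string
	ProviderIdentifier string
}

// resolveUser mirrors the matching order used by filterNodesByUser:
// a provider identifier wins outright, otherwise email and username
// matches are collected, and any ambiguity resolves to "no user".
func resolveUser(users []User, token string) (User, bool) {
	var candidates []User
	for _, u := range users {
		if u.ProviderIdentifier == token {
			candidates = append(candidates, u)
			break // provider identifiers are treated as unique, stop early
		}
		if u.Email == token {
			candidates = append(candidates, u)
		}
		if u.Name == token {
			candidates = append(candidates, u)
		}
	}
	if len(candidates) != 1 {
		return User{}, false // zero or ambiguous matches yield nothing
	}
	return candidates[0], true
}

func main() {
	users := []User{
		{ID: 1, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
		{ID: 2, Name: "mikael2", Email: "mikael@headscale.net"},
	}
	// Ambiguous email: two users share it, so nothing matches.
	_, ok := resolveUser(users, "mikael@headscale.net")
	fmt.Println(ok) // false
	// The provider identifier is unambiguous.
	u, ok := resolveUser(users, "http://oidc.org/1234")
	fmt.Println(u.ID, ok) // 1 true
}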
@@ -2,8 +2,10 @@ package policy

 import (
 	"errors"
+	"math/rand/v2"
 	"net/netip"
 	"slices"
+	"sort"
 	"testing"

 	"github.com/google/go-cmp/cmp"

@@ -14,6 +16,7 @@ import (
 	"github.com/stretchr/testify/assert"
 	"go4.org/netipx"
 	"gopkg.in/check.v1"
+	"gorm.io/gorm"
 	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
 )

@@ -375,18 +378,24 @@ func TestParsing(t *testing.T) {
 				return
 			}

-			rules, err := pol.CompileFilterRules(types.Nodes{
-				&types.Node{
-					IPv4: iap("100.100.100.100"),
+			user := types.User{
+				Model: gorm.Model{ID: 1},
+				Name:  "testuser",
+			}
+			rules, err := pol.CompileFilterRules(
+				[]types.User{
+					user,
 				},
-				&types.Node{
-					IPv4: iap("200.200.200.200"),
-					User: types.User{
-						Name: "testuser",
+				types.Nodes{
+					&types.Node{
+						IPv4: iap("100.100.100.100"),
 					},
-					Hostinfo: &tailcfg.Hostinfo{},
-				},
-			})
+					&types.Node{
+						IPv4:     iap("200.200.200.200"),
+						User:     user,
+						Hostinfo: &tailcfg.Hostinfo{},
+					},
+				})

 			if (err != nil) != tt.wantErr {
 				t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

@@ -533,7 +542,7 @@ func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
 	c.Assert(pol.ACLs, check.HasLen, 6)
 	c.Assert(err, check.IsNil)

-	rules, err := pol.CompileFilterRules(types.Nodes{})
+	rules, err := pol.CompileFilterRules([]types.User{}, types.Nodes{})
 	c.Assert(err, check.NotNil)
 	c.Assert(rules, check.IsNil)
 }

@@ -549,7 +558,7 @@ func (s *Suite) TestInvalidAction(c *check.C) {
 			},
 		},
 	}
-	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
+	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
 	c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true)
 }

@@ -568,7 +577,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
 			},
 		},
 	}
-	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
+	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
 	c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true)
 }

@@ -584,7 +593,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
 		},
 	}

-	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{})
+	_, _, err := GenerateFilterAndSSHRulesForTests(pol, &types.Node{}, types.Nodes{}, []types.User{})
 	c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true)
 }

@@ -861,6 +870,14 @@ func Test_expandPorts(t *testing.T) {
 }

 func Test_listNodesInUser(t *testing.T) {
+	users := []types.User{
+		{Model: gorm.Model{ID: 1}, Name: "marc"},
+		{Model: gorm.Model{ID: 2}, Name: "joe", Email: "joe@headscale.net"},
+		{Model: gorm.Model{ID: 3}, Name: "mikael", Email: "mikael@headscale.net", ProviderIdentifier: "http://oidc.org/1234"},
+		{Model: gorm.Model{ID: 4}, Name: "mikael2", Email: "mikael@headscale.net"},
+		{Model: gorm.Model{ID: 5}, Name: "mikael", Email: "mikael2@headscale.net"},
+	}
+
 	type args struct {
 		nodes types.Nodes
 		user  string

@@ -874,50 +891,239 @@ func Test_listNodesInUser(t *testing.T) {
 			name: "1 node in user",
 			args: args{
 				nodes: types.Nodes{
-					&types.Node{User: types.User{Name: "joe"}},
+					&types.Node{User: users[1]},
 				},
 				user: "joe",
 			},
 			want: types.Nodes{
-				&types.Node{User: types.User{Name: "joe"}},
+				&types.Node{User: users[1]},
 			},
 		},
 		{
 			name: "3 nodes, 2 in user",
 			args: args{
 				nodes: types.Nodes{
-					&types.Node{ID: 1, User: types.User{Name: "joe"}},
-					&types.Node{ID: 2, User: types.User{Name: "marc"}},
-					&types.Node{ID: 3, User: types.User{Name: "marc"}},
+					&types.Node{ID: 1, User: users[1]},
+					&types.Node{ID: 2, User: users[0]},
+					&types.Node{ID: 3, User: users[0]},
 				},
 				user: "marc",
 			},
 			want: types.Nodes{
-				&types.Node{ID: 2, User: types.User{Name: "marc"}},
-				&types.Node{ID: 3, User: types.User{Name: "marc"}},
+				&types.Node{ID: 2, User: users[0]},
+				&types.Node{ID: 3, User: users[0]},
 			},
 		},
 		{
 			name: "5 nodes, 0 in user",
 			args: args{
 				nodes: types.Nodes{
-					&types.Node{ID: 1, User: types.User{Name: "joe"}},
-					&types.Node{ID: 2, User: types.User{Name: "marc"}},
-					&types.Node{ID: 3, User: types.User{Name: "marc"}},
-					&types.Node{ID: 4, User: types.User{Name: "marc"}},
-					&types.Node{ID: 5, User: types.User{Name: "marc"}},
+					&types.Node{ID: 1, User: users[1]},
+					&types.Node{ID: 2, User: users[0]},
+					&types.Node{ID: 3, User: users[0]},
+					&types.Node{ID: 4, User: users[0]},
+					&types.Node{ID: 5, User: users[0]},
 				},
 				user: "mickael",
 			},
 			want: nil,
 		},
+		{
+			name: "match-by-provider-ident",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[1]},
+					&types.Node{ID: 2, User: users[2]},
+				},
+				user: "http://oidc.org/1234",
+			},
+			want: types.Nodes{
+				&types.Node{ID: 2, User: users[2]},
+			},
+		},
+		{
+			name: "match-by-email",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[1]},
+					&types.Node{ID: 2, User: users[2]},
+				},
+				user: "joe@headscale.net",
+			},
+			want: types.Nodes{
+				&types.Node{ID: 1, User: users[1]},
+			},
+		},
+		{
+			name: "multi-match-is-zero",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[1]},
+					&types.Node{ID: 2, User: users[2]},
+					&types.Node{ID: 3, User: users[3]},
+				},
+				user: "mikael@headscale.net",
+			},
+			want: nil,
+		},
+		{
+			name: "multi-email-first-match-is-zero",
+			args: args{
+				nodes: types.Nodes{
+					// First match email, then provider id
+					&types.Node{ID: 3, User: users[3]},
+					&types.Node{ID: 2, User: users[2]},
+				},
+				user: "mikael@headscale.net",
+			},
+			want: nil,
+		},
+		{
+			name: "multi-username-first-match-is-zero",
+			args: args{
+				nodes: types.Nodes{
+					// First match username, then provider id
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 2, User: users[2]},
+				},
+				user: "mikael",
+			},
+			want: nil,
+		},
+		{
+			name: "all-users-duplicate-username-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "mikael",
+			},
+			want: nil,
+		},
+		{
+			name: "all-users-unique-username-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "marc",
+			},
+			want: types.Nodes{
+				&types.Node{ID: 1, User: users[0]},
+			},
+		},
+		{
+			name: "all-users-no-username-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "not-working",
+			},
+			want: nil,
+		},
+		{
+			name: "all-users-duplicate-email-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "mikael@headscale.net",
+			},
+			want: nil,
+		},
+		{
+			name: "all-users-unique-email-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "joe@headscale.net",
+			},
+			want: types.Nodes{
+				&types.Node{ID: 2, User: users[1]},
+			},
+		},
+		{
+			name: "all-users-no-email-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "not-working@headscale.net",
+			},
+			want: nil,
+		},
+		{
+			name: "all-users-provider-id-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "http://oidc.org/1234",
+			},
+			want: types.Nodes{
+				&types.Node{ID: 3, User: users[2]},
+			},
+		},
+		{
+			name: "all-users-no-provider-id-random-order",
+			args: args{
+				nodes: types.Nodes{
+					&types.Node{ID: 1, User: users[0]},
+					&types.Node{ID: 2, User: users[1]},
+					&types.Node{ID: 3, User: users[2]},
+					&types.Node{ID: 4, User: users[3]},
+					&types.Node{ID: 5, User: users[4]},
+				},
+				user: "http://oidc.org/4321",
+			},
+			want: nil,
+		},
 	}
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			got := filterNodesByUser(test.args.nodes, test.args.user)
+			for range 1000 {
+				ns := test.args.nodes
+				rand.Shuffle(len(ns), func(i, j int) {
+					ns[i], ns[j] = ns[j], ns[i]
+				})
+				got := filterNodesByUser(ns, users, test.args.user)
+				sort.Slice(got, func(i, j int) bool {
+					return got[i].ID < got[j].ID
+				})

-			if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
-				t.Errorf("listNodesInUser() = (-want +got):\n%s", diff)
+				if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" {
+					t.Errorf("filterNodesByUser() = (-want +got):\n%s", diff)
+				}
 			}
 		})
 	}
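Note: the 1000-iteration shuffle above guards filterNodesByUser against order-dependent results, since duplicate usernames and emails must resolve the same way regardless of input order. The same pattern, reduced to a standalone sketch (pickEven is an illustrative stand-in for the function under test):

package main

import (
	"fmt"
	"math/rand/v2"
	"sort"
)

// pickEven is a toy stand-in for filterNodesByUser: a pure function
// whose result should not depend on input order.
func pickEven(in []int) []int {
	var out []int
	for _, v := range in {
		if v%2 == 0 {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	input := []int{4, 1, 2, 9, 8}
	want := fmt.Sprint([]int{2, 4, 8})
	for range 1000 {
		rand.Shuffle(len(input), func(i, j int) {
			input[i], input[j] = input[j], input[i]
		})
		got := pickEven(input)
		// Sort before comparing so only membership matters.
		sort.Ints(got)
		if fmt.Sprint(got) != want {
			panic("order-dependent result")
		}
	}
	fmt.Println("stable under shuffling")
}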
@@ -940,6 +1146,12 @@ func Test_expandAlias(t *testing.T) {
 		return s
 	}

+	users := []types.User{
+		{Model: gorm.Model{ID: 1}, Name: "joe"},
+		{Model: gorm.Model{ID: 2}, Name: "marc"},
+		{Model: gorm.Model{ID: 3}, Name: "mickael"},
+	}
+
 	type field struct {
 		pol ACLPolicy
 	}

@@ -989,19 +1201,19 @@ func Test_expandAlias(t *testing.T) {
 				nodes: types.Nodes{
 					&types.Node{
 						IPv4: iap("100.64.0.1"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.2"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User: users[1],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "mickael"},
+						User: users[2],
 					},
 				},
 			},

@@ -1022,19 +1234,19 @@ func Test_expandAlias(t *testing.T) {
 				nodes: types.Nodes{
 					&types.Node{
 						IPv4: iap("100.64.0.1"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.2"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User: users[1],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "mickael"},
+						User: users[2],
 					},
 				},
 			},

@@ -1185,7 +1397,7 @@ func Test_expandAlias(t *testing.T) {
 				nodes: types.Nodes{
 					&types.Node{
 						IPv4: iap("100.64.0.1"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 						Hostinfo: &tailcfg.Hostinfo{
 							OS:       "centos",
 							Hostname: "foo",

@@ -1194,7 +1406,7 @@ func Test_expandAlias(t *testing.T) {
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.2"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 						Hostinfo: &tailcfg.Hostinfo{
 							OS:       "centos",
 							Hostname: "foo",

@@ -1203,11 +1415,11 @@ func Test_expandAlias(t *testing.T) {
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User: users[1],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 					},
 				},
 			},

@@ -1260,21 +1472,21 @@ func Test_expandAlias(t *testing.T) {
 				nodes: types.Nodes{
 					&types.Node{
 						IPv4: iap("100.64.0.1"),
-						User: types.User{Name: "joe"},
+						User:       users[0],
 						ForcedTags: []string{"tag:hr-webserver"},
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.2"),
-						User: types.User{Name: "joe"},
+						User:       users[0],
 						ForcedTags: []string{"tag:hr-webserver"},
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User: users[1],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "mickael"},
+						User: users[2],
 					},
 				},
 			},

@@ -1295,12 +1507,12 @@ func Test_expandAlias(t *testing.T) {
 				nodes: types.Nodes{
 					&types.Node{
 						IPv4: iap("100.64.0.1"),
-						User: types.User{Name: "joe"},
+						User:       users[0],
 						ForcedTags: []string{"tag:hr-webserver"},
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.2"),
-						User: types.User{Name: "joe"},
+						User: users[0],
 						Hostinfo: &tailcfg.Hostinfo{
 							OS:       "centos",
 							Hostname: "foo",

@@ -1309,11 +1521,11 @@ func Test_expandAlias(t *testing.T) {
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User: users[1],
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "mickael"},
+						User: users[2],
 					},
 				},
 			},

@@ -1350,12 +1562,12 @@ func Test_expandAlias(t *testing.T) {
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.3"),
-						User: types.User{Name: "marc"},
+						User:     users[1],
 						Hostinfo: &tailcfg.Hostinfo{},
 					},
 					&types.Node{
 						IPv4: iap("100.64.0.4"),
-						User: types.User{Name: "joe"},
+						User:     users[0],
 						Hostinfo: &tailcfg.Hostinfo{},
 					},
 				},

@@ -1368,6 +1580,7 @@ func Test_expandAlias(t *testing.T) {
 		t.Run(test.name, func(t *testing.T) {
 			got, err := test.field.pol.ExpandAlias(
 				test.args.nodes,
+				users,
 				test.args.alias,
 			)
 			if (err != nil) != test.wantErr {

@@ -1715,6 +1928,7 @@ func TestACLPolicy_generateFilterRules(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			got, err := tt.field.pol.CompileFilterRules(
+				[]types.User{},
 				tt.args.nodes,
 			)
 			if (err != nil) != tt.wantErr {

@@ -1834,6 +2048,13 @@ func TestTheInternet(t *testing.T) {
 }

 func TestReduceFilterRules(t *testing.T) {
+	users := []types.User{
+		{Model: gorm.Model{ID: 1}, Name: "mickael"},
+		{Model: gorm.Model{ID: 2}, Name: "user1"},
+		{Model: gorm.Model{ID: 3}, Name: "user2"},
+		{Model: gorm.Model{ID: 4}, Name: "user100"},
+	}
+
 	tests := []struct {
 		name string
 		node *types.Node

@@ -1855,13 +2076,13 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.1"),
 				IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
-				User: types.User{Name: "mickael"},
+				User: users[0],
 			},
 			peers: types.Nodes{
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
-					User: types.User{Name: "mickael"},
+					User: users[0],
 				},
 			},
 			want: []tailcfg.FilterRule{},

@@ -1888,7 +2109,7 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.1"),
 				IPv6: iap("fd7a:115c:a1e0::1"),
-				User: types.User{Name: "user1"},
+				User: users[1],
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: []netip.Prefix{
 						netip.MustParsePrefix("10.33.0.0/16"),

@@ -1899,7 +2120,7 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -1967,19 +2188,19 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.1"),
 				IPv6: iap("fd7a:115c:a1e0::1"),
-				User: types.User{Name: "user1"},
+				User: users[1],
 			},
 			peers: types.Nodes{
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user2"},
+					User: users[2],
 				},
 				// "internal" exit node
 				&types.Node{
 					IPv4: iap("100.64.0.100"),
 					IPv6: iap("fd7a:115c:a1e0::100"),
-					User: types.User{Name: "user100"},
+					User: users[3],
 					Hostinfo: &tailcfg.Hostinfo{
 						RoutableIPs: tsaddr.ExitRoutes(),
 					},

@@ -2026,12 +2247,12 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user2"},
+					User: users[2],
 				},
 				&types.Node{
 					IPv4: iap("100.64.0.1"),
 					IPv6: iap("fd7a:115c:a1e0::1"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -2113,7 +2334,7 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.100"),
 				IPv6: iap("fd7a:115c:a1e0::100"),
-				User: types.User{Name: "user100"},
+				User: users[3],
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: tsaddr.ExitRoutes(),
 				},

@@ -2122,12 +2343,12 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user2"},
+					User: users[2],
 				},
 				&types.Node{
 					IPv4: iap("100.64.0.1"),
 					IPv6: iap("fd7a:115c:a1e0::1"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -2215,7 +2436,7 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.100"),
 				IPv6: iap("fd7a:115c:a1e0::100"),
-				User: types.User{Name: "user100"},
+				User: users[3],
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
 				},

@@ -2224,12 +2445,12 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user2"},
+					User: users[2],
 				},
 				&types.Node{
 					IPv4: iap("100.64.0.1"),
 					IPv6: iap("fd7a:115c:a1e0::1"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -2292,7 +2513,7 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.100"),
 				IPv6: iap("fd7a:115c:a1e0::100"),
-				User: types.User{Name: "user100"},
+				User: users[3],
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
 				},

@@ -2301,12 +2522,12 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.2"),
 					IPv6: iap("fd7a:115c:a1e0::2"),
-					User: types.User{Name: "user2"},
+					User: users[2],
 				},
 				&types.Node{
 					IPv4: iap("100.64.0.1"),
 					IPv6: iap("fd7a:115c:a1e0::1"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -2362,7 +2583,7 @@ func TestReduceFilterRules(t *testing.T) {
 			node: &types.Node{
 				IPv4: iap("100.64.0.100"),
 				IPv6: iap("fd7a:115c:a1e0::100"),
-				User: types.User{Name: "user100"},
+				User: users[3],
 				Hostinfo: &tailcfg.Hostinfo{
 					RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
 				},

@@ -2372,7 +2593,7 @@ func TestReduceFilterRules(t *testing.T) {
 				&types.Node{
 					IPv4: iap("100.64.0.1"),
 					IPv6: iap("fd7a:115c:a1e0::1"),
-					User: types.User{Name: "user1"},
+					User: users[1],
 				},
 			},
 			want: []tailcfg.FilterRule{

@@ -2400,6 +2621,7 @@ func TestReduceFilterRules(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			got, _ := tt.pol.CompileFilterRules(
+				users,
 				append(tt.peers, tt.node),
 			)

@@ -3391,7 +3613,7 @@ func TestSSHRules(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
+			got, err := tt.pol.CompileSSHPolicy(&tt.node, []types.User{}, tt.peers)
 			assert.NoError(t, err)

 			if diff := cmp.Diff(tt.want, got); diff != "" {

@@ -3474,14 +3696,17 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
 		RequestTags: []string{"tag:test"},
 	}

+	user := types.User{
+		Model: gorm.Model{ID: 1},
+		Name:  "user1",
+	}
+
 	node := &types.Node{
 		ID:       0,
 		Hostname: "testnodes",
 		IPv4:     iap("100.64.0.1"),
 		UserID:   0,
-		User: types.User{
-			Name: "user1",
-		},
+		User:           user,
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo:       &hostInfo,
 	}

@@ -3498,7 +3723,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
 		},
 	}

-	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
+	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{user})
 	assert.NoError(t, err)

 	want := []tailcfg.FilterRule{

@@ -3532,7 +3757,8 @@ func TestInvalidTagValidUser(t *testing.T) {
 		IPv4:   iap("100.64.0.1"),
 		UserID: 1,
 		User: types.User{
-			Name: "user1",
+			Model: gorm.Model{ID: 1},
+			Name:  "user1",
 		},
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo:       &hostInfo,

@@ -3549,7 +3775,7 @@ func TestInvalidTagValidUser(t *testing.T) {
 		},
 	}

-	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
+	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
 	assert.NoError(t, err)

 	want := []tailcfg.FilterRule{

@@ -3583,7 +3809,8 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
 		IPv4:   iap("100.64.0.1"),
 		UserID: 1,
 		User: types.User{
-			Name: "user1",
+			Model: gorm.Model{ID: 1},
+			Name:  "user1",
 		},
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo:       &hostInfo,

@@ -3608,7 +3835,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
 	// c.Assert(rules[0].DstPorts, check.HasLen, 1)
 	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")

-	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{})
+	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{}, []types.User{node.User})
 	assert.NoError(t, err)

 	want := []tailcfg.FilterRule{

@@ -3637,15 +3864,17 @@ func TestValidTagInvalidUser(t *testing.T) {
 		Hostname:    "webserver",
 		RequestTags: []string{"tag:webapp"},
 	}
+	user := types.User{
+		Model: gorm.Model{ID: 1},
+		Name:  "user1",
+	}

 	node := &types.Node{
 		ID:       1,
 		Hostname: "webserver",
 		IPv4:     iap("100.64.0.1"),
 		UserID:   1,
-		User: types.User{
-			Name: "user1",
-		},
+		User:           user,
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo:       &hostInfo,
 	}

@@ -3656,13 +3885,11 @@ func TestValidTagInvalidUser(t *testing.T) {
 	}

 	nodes2 := &types.Node{
 		ID:       2,
 		Hostname: "user",
 		IPv4:     iap("100.64.0.2"),
 		UserID:   1,
-		User: types.User{
-			Name: "user1",
-		},
+		User:           user,
 		RegisterMethod: util.RegisterMethodAuthKey,
 		Hostinfo:       &hostInfo2,
 	}

@@ -3678,7 +3905,7 @@ func TestValidTagInvalidUser(t *testing.T) {
 		},
 	}

-	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2})
+	got, _, err := GenerateFilterAndSSHRulesForTests(pol, node, types.Nodes{nodes2}, []types.User{user})
 	assert.NoError(t, err)

 	want := []tailcfg.FilterRule{
hscontrol/policy/pm.go (new file, 181 lines)

@@ -0,0 +1,181 @@
+package policy
+
+import (
+	"fmt"
+	"io"
+	"net/netip"
+	"os"
+	"sync"
+
+	"github.com/juanfont/headscale/hscontrol/types"
+	"go4.org/netipx"
+	"tailscale.com/tailcfg"
+	"tailscale.com/util/deephash"
+)
+
+type PolicyManager interface {
+	Filter() []tailcfg.FilterRule
+	SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
+	Tags(*types.Node) []string
+	ApproversForRoute(netip.Prefix) []string
+	ExpandAlias(string) (*netipx.IPSet, error)
+	SetPolicy([]byte) (bool, error)
+	SetUsers(users []types.User) (bool, error)
+	SetNodes(nodes types.Nodes) (bool, error)
+}
+
+func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+	policyFile, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer policyFile.Close()
+
+	policyBytes, err := io.ReadAll(policyFile)
+	if err != nil {
+		return nil, err
+	}
+
+	return NewPolicyManager(policyBytes, users, nodes)
+}
+
+func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+	var pol *ACLPolicy
+	var err error
+	if len(polB) > 0 {
+		pol, err = LoadACLPolicyFromBytes(polB)
+		if err != nil {
+			return nil, fmt.Errorf("parsing policy: %w", err)
+		}
+	}
+
+	pm := PolicyManagerV1{
+		pol:   pol,
+		users: users,
+		nodes: nodes,
+	}
+
+	_, err = pm.updateLocked()
+	if err != nil {
+		return nil, err
+	}
+
+	return &pm, nil
+}
+
+func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
+	pm := PolicyManagerV1{
+		pol:   pol,
+		users: users,
+		nodes: nodes,
+	}
+
+	_, err := pm.updateLocked()
+	if err != nil {
+		return nil, err
+	}
+
+	return &pm, nil
+}
+
+type PolicyManagerV1 struct {
+	mu  sync.Mutex
+	pol *ACLPolicy
+
+	users []types.User
+	nodes types.Nodes
+
+	filterHash deephash.Sum
+	filter     []tailcfg.FilterRule
+}
+
+// updateLocked updates the filter rules based on the current policy and nodes.
+// It must be called with the lock held.
+func (pm *PolicyManagerV1) updateLocked() (bool, error) {
+	filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
+	if err != nil {
+		return false, fmt.Errorf("compiling filter rules: %w", err)
+	}
+
+	filterHash := deephash.Hash(&filter)
+	if filterHash == pm.filterHash {
+		return false, nil
+	}
+
+	pm.filter = filter
+	pm.filterHash = filterHash
+
+	return true, nil
+}
+
+func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+	return pm.filter
+}
+
+func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
+}
+
+func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
+	pol, err := LoadACLPolicyFromBytes(polB)
+	if err != nil {
+		return false, fmt.Errorf("parsing policy: %w", err)
+	}
+
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	pm.pol = pol
+
+	return pm.updateLocked()
+}
+
+// SetUsers updates the users in the policy manager and updates the filter rules.
+func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	pm.users = users
+	return pm.updateLocked()
+}
+
+// SetNodes updates the nodes in the policy manager and updates the filter rules.
+func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+	pm.nodes = nodes
+	return pm.updateLocked()
+}
+
+func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
+	if pm == nil {
+		return nil
+	}
+
+	tags, _ := pm.pol.TagsOfNode(node)
+	return tags
+}
+
+func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
+	// TODO(kradalby): This can be a parse error of the address in the policy,
+	// in the new policy this will be typed and not a problem, in this policy
+	// we will just return empty list
+	if pm.pol == nil {
+		return nil
+	}
+	approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
+	return approvers
+}
+
+func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
+	ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
+	if err != nil {
+		return nil, err
+	}
+	return ips, nil
+}
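Note: a minimal sketch of driving the new PolicyManager from the caller's side. The policy JSON and user values are illustrative, and the import paths assume the headscale module is available; this is not code from the change itself.

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"gorm.io/gorm"
)

func main() {
	users := []types.User{{Model: gorm.Model{ID: 1}, Name: "testuser"}}

	// An allow-all style policy; the body is illustrative.
	polJSON := []byte(`{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`)

	pm, err := policy.NewPolicyManager(polJSON, users, types.Nodes{})
	if err != nil {
		panic(err)
	}

	// Mutators report whether the compiled filter actually changed,
	// so callers only need to notify nodes when they get `true`.
	changed, err := pm.SetUsers(users)
	if err != nil {
		panic(err)
	}
	fmt.Println("filter changed:", changed)
	fmt.Println("rules:", len(pm.Filter()))
}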
hscontrol/policy/pm_test.go (new file, 158 lines)

@@ -0,0 +1,158 @@
+package policy
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"gorm.io/gorm"
+	"tailscale.com/tailcfg"
+)
+
+func TestPolicySetChange(t *testing.T) {
+	users := []types.User{
+		{
+			Model: gorm.Model{ID: 1},
+			Name:  "testuser",
+		},
+	}
+	tests := []struct {
+		name             string
+		users            []types.User
+		nodes            types.Nodes
+		policy           []byte
+		wantUsersChange  bool
+		wantNodesChange  bool
+		wantPolicyChange bool
+		wantFilter       []tailcfg.FilterRule
+	}{
+		{
+			name: "set-nodes",
+			nodes: types.Nodes{
+				{
+					IPv4: iap("100.64.0.2"),
+					User: users[0],
+				},
+			},
+			wantNodesChange: false,
+			wantFilter: []tailcfg.FilterRule{
+				{
+					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
+				},
+			},
+		},
+		{
+			name:            "set-users",
+			users:           users,
+			wantUsersChange: false,
+			wantFilter: []tailcfg.FilterRule{
+				{
+					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
+				},
+			},
+		},
+		{
+			name:  "set-users-and-node",
+			users: users,
+			nodes: types.Nodes{
+				{
+					IPv4: iap("100.64.0.2"),
+					User: users[0],
+				},
+			},
+			wantUsersChange: false,
+			wantNodesChange: true,
+			wantFilter: []tailcfg.FilterRule{
+				{
+					SrcIPs:   []string{"100.64.0.2/32"},
+					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
+				},
+			},
+		},
+		{
+			name: "set-policy",
+			policy: []byte(`
+{
+	"acls": [
+		{
+			"action": "accept",
+			"src": [
+				"100.64.0.61",
+			],
+			"dst": [
+				"100.64.0.62:*",
+			],
+		},
+	],
+}
+`),
+			wantPolicyChange: true,
+			wantFilter: []tailcfg.FilterRule{
+				{
+					SrcIPs:   []string{"100.64.0.61/32"},
+					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			pol := `
+{
+	"groups": {
+		"group:example": [
+			"testuser",
+		],
+	},
+
+	"hosts": {
+		"host-1": "100.64.0.1",
+		"subnet-1": "100.100.101.100/24",
+	},
+
+	"acls": [
+		{
+			"action": "accept",
+			"src": [
+				"group:example",
+			],
+			"dst": [
+				"host-1:*",
+			],
+		},
+	],
+}
+`
+			pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
+			require.NoError(t, err)
+
+			if tt.policy != nil {
+				change, err := pm.SetPolicy(tt.policy)
+				require.NoError(t, err)
+
+				assert.Equal(t, tt.wantPolicyChange, change)
+			}
+
+			if tt.users != nil {
+				change, err := pm.SetUsers(tt.users)
+				require.NoError(t, err)
+
+				assert.Equal(t, tt.wantUsersChange, change)
+			}
+
+			if tt.nodes != nil {
+				change, err := pm.SetNodes(tt.nodes)
+				require.NoError(t, err)
+
+				assert.Equal(t, tt.wantNodesChange, change)
+			}
+
+			if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
+				t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
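Note: the `false` expectations above (set-nodes, set-users) hinge on updateLocked's deephash comparison: a recompile that yields an identical filter reports no change, so nothing is fanned out to nodes. A standalone sketch of that only-signal-on-change pattern; the ruleset struct is illustrative, deephash is the same dependency pm.go uses.

package main

import (
	"fmt"

	"tailscale.com/util/deephash"
)

type ruleset struct {
	Rules []string
}

func main() {
	cur := ruleset{Rules: []string{"allow a->b"}}
	hash := deephash.Hash(&cur)

	// Recompiling identical input yields the same hash: no change signal.
	next := ruleset{Rules: []string{"allow a->b"}}
	if h := deephash.Hash(&next); h == hash {
		fmt.Println("unchanged, skip notifying nodes")
	}

	// A real difference flips the hash and would trigger an update.
	next.Rules = append(next.Rules, "allow a->c")
	if h := deephash.Hash(&next); h != hash {
		fmt.Println("changed, send map updates")
	}
}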
@@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
 		switch update.Type {
 		case types.StateFullUpdate:
 			m.tracef("Sending Full MapResponse")
-			data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
+			data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
 		case types.StatePeerChanged:
 			changed := make(map[types.NodeID]bool, len(update.ChangeNodes))

@@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {

 			lastMessage = update.Message
 			m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-			data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
+			data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 			updateType = "change"

 		case types.StatePeerChangedPatch:
 			m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
-			data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
+			data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
 			updateType = "patch"
 		case types.StatePeerRemoved:
 			changed := make(map[types.NodeID]bool, len(update.Removed))

@@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
 				changed[nodeID] = false
 			}
 			m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-			data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
+			data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
 			updateType = "remove"
 		case types.StateSelfUpdate:
 			lastMessage = update.Message
 			m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
 			// create the map so an empty (self) update is sent
-			data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
+			data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
 			updateType = "remove"
 		case types.StateDERPUpdated:
 			m.tracef("Sending DERPUpdate MapResponse")

@@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() {
 		return
 	}

-	if m.h.ACLPolicy != nil {
+	// TODO(kradalby): Only update the node that has actually changed
+	nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
+
+	if m.h.polMan != nil {
 		// update routes with peer information
-		err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
+		err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
 		if err != nil {
 			m.errf(err, "Error running auto approved routes")
 			mapResponseEndpointUpdates.WithLabelValues("error").Inc()

@@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() {
 func (m *mapSession) handleReadOnlyRequest() {
 	m.tracef("Client asked for a lite update, responding without peers")

-	mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy)
+	mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
 	if err != nil {
 		m.errf(err, "Failed to create MapResponse")
 		http.Error(m.w, "", http.StatusInternalServerError)
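Note: nodesChangedHook is called above but not defined in this hunk. Given the PolicyManager API introduced in this change, a plausible shape is sketched below; the parameter types, the ListNodes call, and the notifier invocation are assumptions for illustration, not code from this diff.

// Hypothetical sketch only; the real body is not part of this hunk.
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
	// Re-read the node list and hand it to the policy manager (assumed helper).
	nodes, err := db.ListNodes()
	if err != nil {
		return err
	}

	filterChanged, err := polMan.SetNodes(nodes)
	if err != nil {
		return err
	}

	// Only fan out a full update when the compiled filter actually changed.
	if filterChanged {
		// Assumed notifier call; the caller above discards any returned error.
		notif.NotifyAll(context.Background(), types.StateUpdate{Type: types.StateFullUpdate})
	}

	return nil
}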
@ -28,8 +28,9 @@ const (
|
||||||
maxDuration time.Duration = 1<<63 - 1
|
maxDuration time.Duration = 1<<63 - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
var errOidcMutuallyExclusive = errors.New(
|
var (
|
||||||
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
|
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
|
||||||
|
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPAllocationStrategy string
|
type IPAllocationStrategy string
|
||||||
|
@ -827,11 +828,10 @@ func LoadServerConfig() (*Config, error) {
|
||||||
// - DERP run on their own domains
|
// - DERP run on their own domains
|
||||||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||||
if dnsConfig.BaseDomain != "" &&
|
if dnsConfig.BaseDomain != "" {
|
||||||
strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
|
||||||
return nil, errors.New(
|
return nil, err
|
||||||
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
}
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
|
@@ -924,6 +924,37 @@ func LoadServerConfig() (*Config, error) {
 	}, nil
 }

+// BaseDomain cannot be a suffix of the server URL.
+// This is because Tailscale takes over the domain in BaseDomain,
+// causing the headscale server and DERP to be unreachable.
+// For Tailscale upstream, the following is true:
+// - DERP run on their own domains.
+// - Control plane runs on login.tailscale.com/controlplane.tailscale.com.
+// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net).
+func isSafeServerURL(serverURL, baseDomain string) error {
+	server, err := url.Parse(serverURL)
+	if err != nil {
+		return err
+	}
+
+	serverDomainParts := strings.Split(server.Host, ".")
+	baseDomainParts := strings.Split(baseDomain, ".")
+
+	if len(serverDomainParts) <= len(baseDomainParts) {
+		return nil
+	}
+
+	s := len(serverDomainParts)
+	b := len(baseDomainParts)
+	for i := range len(baseDomainParts) {
+		if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
+			return nil
+		}
+	}
+
+	return errServerURLSuffix
+}
+
 type deprecator struct {
 	warns  set.Set[string]
 	fatals set.Set[string]
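Since the check is easy to misread, here is a minimal standalone sketch of the same label-wise suffix comparison (re-implemented purely for illustration; the helper and variable names are mine, and the real function above returns errServerURLSuffix rather than a bool):

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// isSuffix reports whether baseDomain is a DNS-label suffix of the host in
// serverURL, mirroring the comparison in isSafeServerURL above.
func isSuffix(serverURL, baseDomain string) (bool, error) {
	server, err := url.Parse(serverURL)
	if err != nil {
		return false, err
	}

	s := strings.Split(server.Host, ".")
	b := strings.Split(baseDomain, ".")
	// The host must have more labels than the base domain to sit under it.
	if len(s) <= len(b) {
		return false, nil
	}

	// Compare labels right to left; any mismatch means no overlap.
	for i := range b {
		if s[len(s)-i-1] != b[len(b)-i-1] {
			return false, nil
		}
	}

	return true, nil
}

func main() {
	fmt.Println(isSuffix("https://headscale.com", "clients.headscale.com")) // false (accepted)
	fmt.Println(isSuffix("https://server.headscale.com", "headscale.com"))  // true (rejected)
}
```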
@@ -1,6 +1,7 @@
 package types

 import (
+	"fmt"
 	"os"
 	"path/filepath"
 	"testing"
@@ -139,7 +140,7 @@ func TestReadConfig(t *testing.T) {
 				return LoadServerConfig()
 			},
 			want:    nil,
-			wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
+			wantErr: errServerURLSuffix.Error(),
 		},
 		{
 			name: "base-domain-not-in-server-url",
@@ -333,3 +334,64 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
 	err = LoadConfig(tmpDir, false)
 	assert.NoError(t, err)
 }
+
+// OK
+// server_url: headscale.com, base: clients.headscale.com
+// server_url: headscale.com, base: headscale.net
+//
+// NOT OK
+// server_url: server.headscale.com, base: headscale.com.
+func TestSafeServerURL(t *testing.T) {
+	tests := []struct {
+		serverURL, baseDomain,
+		wantErr string
+	}{
+		{
+			serverURL:  "https://example.com",
+			baseDomain: "example.org",
+		},
+		{
+			serverURL:  "https://headscale.com",
+			baseDomain: "headscale.com",
+		},
+		{
+			serverURL:  "https://headscale.com",
+			baseDomain: "clients.headscale.com",
+		},
+		{
+			serverURL:  "https://headscale.com",
+			baseDomain: "clients.subdomain.headscale.com",
+		},
+		{
+			serverURL:  "https://headscale.kristoffer.com",
+			baseDomain: "mybase",
+		},
+		{
+			serverURL:  "https://server.headscale.com",
+			baseDomain: "headscale.com",
+			wantErr:    errServerURLSuffix.Error(),
+		},
+		{
+			serverURL:  "https://server.subdomain.headscale.com",
+			baseDomain: "headscale.com",
+			wantErr:    errServerURLSuffix.Error(),
+		},
+		{
+			serverURL: "http://foo\x00",
+			wantErr:   `parse "http://foo\x00": net/url: invalid control character in URL`,
+		},
+	}
+
+	for _, tt := range tests {
+		testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain)
+		t.Run(testName, func(t *testing.T) {
+			err := isSafeServerURL(tt.serverURL, tt.baseDomain)
+			if tt.wantErr != "" {
+				assert.EqualError(t, err, tt.wantErr)
+
+				return
+			}
+			assert.NoError(t, err)
+		})
+	}
+}
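With the table above in place, the new cases can be exercised in isolation; assuming the `types` package sits at `hscontrol/types` as elsewhere in the tree:

```
go test -run TestSafeServerURL ./hscontrol/types
```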
@@ -223,6 +223,16 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
 	return found
 }

+func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
+	for _, node := range nodes {
+		if node.NodeKey == nodeKey {
+			return true
+		}
+	}
+
+	return false
+}
+
 func (node *Node) Proto() *v1.Node {
 	nodeProto := &v1.Node{
 		Id: uint64(node.ID),
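A self-contained sketch of what the new helper enables; the `Node` type below is a stand-in for illustration (not the real headscale type), and the equality test works because `key.NodePublic` is comparable:

```go
package main

import (
	"fmt"

	"tailscale.com/types/key"
)

// Node is a minimal stand-in for the headscale node type.
type Node struct {
	NodeKey key.NodePublic
}

type Nodes []*Node

// ContainsNodeKey mirrors the helper added in the diff above.
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
	for _, node := range nodes {
		if node.NodeKey == nodeKey {
			return true
		}
	}

	return false
}

func main() {
	k := key.NewNode().Public()
	nodes := Nodes{{NodeKey: k}}
	fmt.Println(nodes.ContainsNodeKey(k))                      // true
	fmt.Println(nodes.ContainsNodeKey(key.NewNode().Public())) // false
}
```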
@@ -8,7 +8,7 @@ prefixes:
 database:
   type: sqlite3

-server_url: "https://derp.no"
+server_url: "https://server.derp.no"

 dns:
   magic_dns: true
@@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located

 ## Running integration tests locally

-The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner:
+The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner:

 ```
-act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml
+act pull_request -W .github/workflows/test-integration.yaml
 ```

 Alternatively, the `docker run` command in each GitHub workflow file can be used.
@@ -13,6 +13,7 @@ import (
 	"github.com/juanfont/headscale/integration/hsic"
 	"github.com/juanfont/headscale/integration/tsic"
 	"github.com/stretchr/testify/assert"
+	"golang.org/x/exp/slices"
 )

 func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error {
@@ -786,117 +787,85 @@ func TestNodeTagCommand(t *testing.T) {
 	)
 }

-func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
+func TestNodeAdvertiseTagCommand(t *testing.T) {
 	IntegrationSkip(t)
 	t.Parallel()

-	scenario, err := NewScenario(dockertestMaxWait())
-	assertNoErr(t, err)
-	defer scenario.ShutdownAssertNoPanics(t)
-
-	spec := map[string]int{
-		"user1": 1,
-	}
-
-	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:test"})}, hsic.WithTestName("cliadvtags"))
-	assertNoErr(t, err)
-
-	headscale, err := scenario.Headscale()
-	assertNoErr(t, err)
-
-	// Test list all nodes after added seconds
-	resultMachines := make([]*v1.Node, spec["user1"])
-	err = executeAndUnmarshal(
-		headscale,
-		[]string{
-			"headscale",
-			"nodes",
-			"list",
-			"--tags",
-			"--output", "json",
-		},
-		&resultMachines,
-	)
-	assert.Nil(t, err)
-	found := false
-	for _, node := range resultMachines {
-		if node.GetInvalidTags() != nil {
-			for _, tag := range node.GetInvalidTags() {
-				if tag == "tag:test" {
-					found = true
-				}
-			}
-		}
-	}
-	assert.Equal(
-		t,
-		true,
-		found,
-		"should not find a node with the tag 'tag:test' in the list of nodes",
-	)
-}
-
-func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
-	IntegrationSkip(t)
-	t.Parallel()
-
-	scenario, err := NewScenario(dockertestMaxWait())
-	assertNoErr(t, err)
-	defer scenario.ShutdownAssertNoPanics(t)
-
-	spec := map[string]int{
-		"user1": 1,
-	}
-
-	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:exists"})}, hsic.WithTestName("cliadvtags"), hsic.WithACLPolicy(
-		&policy.ACLPolicy{
-			ACLs: []policy.ACL{
-				{
-					Action:       "accept",
-					Sources:      []string{"*"},
-					Destinations: []string{"*:*"},
-				},
-			},
-			TagOwners: map[string][]string{
-				"tag:exists": {"user1"},
-			},
-		},
-	))
-	assertNoErr(t, err)
-
-	headscale, err := scenario.Headscale()
-	assertNoErr(t, err)
-
-	// Test list all nodes after added seconds
-	resultMachines := make([]*v1.Node, spec["user1"])
-	err = executeAndUnmarshal(
-		headscale,
-		[]string{
-			"headscale",
-			"nodes",
-			"list",
-			"--tags",
-			"--output", "json",
-		},
-		&resultMachines,
-	)
-	assert.Nil(t, err)
-	found := false
-	for _, node := range resultMachines {
-		if node.GetValidTags() != nil {
-			for _, tag := range node.GetValidTags() {
-				if tag == "tag:exists" {
-					found = true
-				}
-			}
-		}
-	}
-	assert.Equal(
-		t,
-		true,
-		found,
-		"should not find a node with the tag 'tag:exists' in the list of nodes",
-	)
+	tests := []struct {
+		name    string
+		policy  *policy.ACLPolicy
+		wantTag bool
+	}{
+		{
+			name:    "no-policy",
+			wantTag: false,
+		},
+		{
+			name: "with-policy",
+			policy: &policy.ACLPolicy{
+				ACLs: []policy.ACL{
+					{
+						Action:       "accept",
+						Sources:      []string{"*"},
+						Destinations: []string{"*:*"},
+					},
+				},
+				TagOwners: map[string][]string{
+					"tag:test": {"user1"},
+				},
+			},
+			wantTag: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			scenario, err := NewScenario(dockertestMaxWait())
+			assertNoErr(t, err)
+			// defer scenario.ShutdownAssertNoPanics(t)
+
+			spec := map[string]int{
+				"user1": 1,
+			}
+
+			err = scenario.CreateHeadscaleEnv(spec,
+				[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
+				hsic.WithTestName("cliadvtags"),
+				hsic.WithACLPolicy(tt.policy),
+			)
+			assertNoErr(t, err)
+
+			headscale, err := scenario.Headscale()
+			assertNoErr(t, err)
+
+			// Test list all nodes after added seconds
+			resultMachines := make([]*v1.Node, spec["user1"])
+			err = executeAndUnmarshal(
+				headscale,
+				[]string{
+					"headscale",
+					"nodes",
+					"list",
+					"--tags",
+					"--output", "json",
+				},
+				&resultMachines,
+			)
+			assert.Nil(t, err)
+			found := false
+			for _, node := range resultMachines {
+				if tags := node.GetValidTags(); tags != nil {
+					found = slices.Contains(tags, "tag:test")
+				}
+			}
+			assert.Equalf(
+				t,
+				tt.wantTag,
+				found,
+				"'tag:test' found(%t) in the list of nodes, expected %t", found, tt.wantTag,
+			)
+		})
+	}
 }

 func TestNodeCommand(t *testing.T) {
@@ -1732,7 +1701,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
 			policyFilePath,
 		},
 	)
-	assert.ErrorContains(t, err, "verifying policy rules: invalid action")
+	assert.ErrorContains(t, err, "compiling filter rules: invalid action")

 	// The new policy was invalid, the old one should still be in place, which
 	// is none.
integration/derp_verify_endpoint_test.go (new file, 96 lines)
@@ -0,0 +1,96 @@
+package integration
+
+import (
+	"encoding/json"
+	"fmt"
+	"net"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/juanfont/headscale/integration/dsic"
+	"github.com/juanfont/headscale/integration/hsic"
+	"github.com/juanfont/headscale/integration/integrationutil"
+	"github.com/juanfont/headscale/integration/tsic"
+	"tailscale.com/tailcfg"
+)
+
+func TestDERPVerifyEndpoint(t *testing.T) {
+	IntegrationSkip(t)
+
+	// Generate random hostname for the headscale instance
+	hash, err := util.GenerateRandomStringDNSSafe(6)
+	assertNoErr(t, err)
+	testName := "derpverify"
+	hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
+
+	headscalePort := 8080
+
+	// Create cert for headscale
+	certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
+	assertNoErr(t, err)
+
+	scenario, err := NewScenario(dockertestMaxWait())
+	assertNoErr(t, err)
+	defer scenario.ShutdownAssertNoPanics(t)
+
+	spec := map[string]int{
+		"user1": len(MustTestVersions),
+	}
+
+	derper, err := scenario.CreateDERPServer("head",
+		dsic.WithCACert(certHeadscale),
+		dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
+	)
+	assertNoErr(t, err)
+
+	derpMap := tailcfg.DERPMap{
+		Regions: map[int]*tailcfg.DERPRegion{
+			900: {
+				RegionID:   900,
+				RegionCode: "test-derpverify",
+				RegionName: "TestDerpVerify",
+				Nodes: []*tailcfg.DERPNode{
+					{
+						Name:     "TestDerpVerify",
+						RegionID: 900,
+						HostName: derper.GetHostname(),
+						STUNPort: derper.GetSTUNPort(),
+						STUNOnly: false,
+						DERPPort: derper.GetDERPPort(),
+					},
+				},
+			},
+		},
+	}
+
+	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())},
+		hsic.WithHostname(hostname),
+		hsic.WithPort(headscalePort),
+		hsic.WithCustomTLS(certHeadscale, keyHeadscale),
+		hsic.WithHostnameAsServerURL(),
+		hsic.WithDERPConfig(derpMap))
+	assertNoErrHeadscaleEnv(t, err)
+
+	allClients, err := scenario.ListTailscaleClients()
+	assertNoErrListClients(t, err)
+
+	for _, client := range allClients {
+		report, err := client.DebugDERPRegion("test-derpverify")
+		assertNoErr(t, err)
+		successful := false
+		for _, line := range report.Info {
+			if strings.Contains(line, "Successfully established a DERP connection with node") {
+				successful = true
+
+				break
+			}
+		}
+		if !successful {
+			stJSON, err := json.Marshal(report)
+			assertNoErr(t, err)
+			t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON))
+		}
+	}
+}
integration/dsic/dsic.go (new file, 321 lines)
@@ -0,0 +1,321 @@
+package dsic
+
+import (
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/juanfont/headscale/hscontrol/util"
+	"github.com/juanfont/headscale/integration/dockertestutil"
+	"github.com/juanfont/headscale/integration/integrationutil"
+	"github.com/ory/dockertest/v3"
+	"github.com/ory/dockertest/v3/docker"
+)
+
+const (
+	dsicHashLength       = 6
+	dockerContextPath    = "../."
+	caCertRoot           = "/usr/local/share/ca-certificates"
+	DERPerCertRoot       = "/usr/local/share/derper-certs"
+	dockerExecuteTimeout = 60 * time.Second
+)
+
+var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK")
+
+// DERPServerInContainer represents DERP Server in Container (DSIC).
+type DERPServerInContainer struct {
+	version  string
+	hostname string
+
+	pool      *dockertest.Pool
+	container *dockertest.Resource
+	network   *dockertest.Network
+
+	stunPort            int
+	derpPort            int
+	caCerts             [][]byte
+	tlsCert             []byte
+	tlsKey              []byte
+	withExtraHosts      []string
+	withVerifyClientURL string
+	workdir             string
+}
+
+// Option represents optional settings that can be given to a
+// DERPer instance.
+type Option = func(c *DERPServerInContainer)
+
+// WithCACert adds the given CA certificate to the trusted
+// certificates of the DERPer container.
+func WithCACert(cert []byte) Option {
+	return func(dsic *DERPServerInContainer) {
+		dsic.caCerts = append(dsic.caCerts, cert)
+	}
+}
+
+// WithOrCreateNetwork sets the Docker container network to use with
+// the DERPer instance, if the parameter is nil, a new network,
+// isolating the DERPer, will be created. If a network is
+// passed, the DERPer instance will join the given network.
+func WithOrCreateNetwork(network *dockertest.Network) Option {
+	return func(tsic *DERPServerInContainer) {
+		if network != nil {
+			tsic.network = network
+
+			return
+		}
+
+		network, err := dockertestutil.GetFirstOrCreateNetwork(
+			tsic.pool,
+			tsic.hostname+"-network",
+		)
+		if err != nil {
+			log.Fatalf("failed to create network: %s", err)
+		}
+
+		tsic.network = network
+	}
+}
+
+// WithDockerWorkdir allows the docker working directory to be set.
+func WithDockerWorkdir(dir string) Option {
+	return func(tsic *DERPServerInContainer) {
+		tsic.workdir = dir
+	}
+}
+
+// WithVerifyClientURL sets the URL to verify the client.
+func WithVerifyClientURL(url string) Option {
+	return func(tsic *DERPServerInContainer) {
+		tsic.withVerifyClientURL = url
+	}
+}
+
+// WithExtraHosts adds extra hosts to the container.
+func WithExtraHosts(hosts []string) Option {
+	return func(tsic *DERPServerInContainer) {
+		tsic.withExtraHosts = hosts
+	}
+}
+
+// New returns a new DERPServerInContainer instance.
+func New(
+	pool *dockertest.Pool,
+	version string,
+	network *dockertest.Network,
+	opts ...Option,
+) (*DERPServerInContainer, error) {
+	hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
+	if err != nil {
+		return nil, err
+	}
+
+	hostname := fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
+	tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
+	}
+	dsic := &DERPServerInContainer{
+		version:  version,
+		hostname: hostname,
+		pool:     pool,
+		network:  network,
+		tlsCert:  tlsCert,
+		tlsKey:   tlsKey,
+		stunPort: 3478, //nolint
+		derpPort: 443,  //nolint
+	}
+
+	for _, opt := range opts {
+		opt(dsic)
+	}
+
+	var cmdArgs strings.Builder
+	fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname)
+	fmt.Fprintf(&cmdArgs, " --certmode=manual")
+	fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot)
+	fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
+	fmt.Fprintf(&cmdArgs, " --stun=true")
+	fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
+	if dsic.withVerifyClientURL != "" {
+		fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
+	}
+
+	runOptions := &dockertest.RunOptions{
+		Name:       hostname,
+		Networks:   []*dockertest.Network{dsic.network},
+		ExtraHosts: dsic.withExtraHosts,
+		// we currently need to give us some time to inject the certificate further down.
+		Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},
+		ExposedPorts: []string{
+			"80/tcp",
+			fmt.Sprintf("%d/tcp", dsic.derpPort),
+			fmt.Sprintf("%d/udp", dsic.stunPort),
+		},
+	}
+
+	if dsic.workdir != "" {
+		runOptions.WorkingDir = dsic.workdir
+	}
+
+	// dockertest isn't very good at handling containers that have already
+	// been created, this is an attempt to make sure this container isn't
+	// present.
+	err = pool.RemoveContainerByName(hostname)
+	if err != nil {
+		return nil, err
+	}
+
+	var container *dockertest.Resource
+	buildOptions := &dockertest.BuildOptions{
+		Dockerfile: "Dockerfile.derper",
+		ContextDir: dockerContextPath,
+		BuildArgs:  []docker.BuildArg{},
+	}
+	switch version {
+	case "head":
+		buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
+			Name:  "VERSION_BRANCH",
+			Value: "main",
+		})
+	default:
+		buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
+			Name:  "VERSION_BRANCH",
+			Value: "v" + version,
+		})
+	}
+	container, err = pool.BuildAndRunWithBuildOptions(
+		buildOptions,
+		runOptions,
+		dockertestutil.DockerRestartPolicy,
+		dockertestutil.DockerAllowLocalIPv6,
+		dockertestutil.DockerAllowNetworkAdministration,
+	)
+	if err != nil {
+		return nil, fmt.Errorf(
+			"%s could not start tailscale DERPer container (version: %s): %w",
+			hostname,
+			version,
+			err,
+		)
+	}
+	log.Printf("Created %s container\n", hostname)
+
+	dsic.container = container
+
+	for i, cert := range dsic.caCerts {
+		err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
+		if err != nil {
+			return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
+		}
+	}
+	if len(dsic.tlsCert) != 0 {
+		err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
+		if err != nil {
+			return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
+		}
+	}
+	if len(dsic.tlsKey) != 0 {
+		err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
+		if err != nil {
+			return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
+		}
+	}
+
+	return dsic, nil
+}
+
+// Shutdown stops and cleans up the DERPer container.
+func (t *DERPServerInContainer) Shutdown() error {
+	err := t.SaveLog("/tmp/control")
+	if err != nil {
+		log.Printf(
+			"Failed to save log from %s: %s",
+			t.hostname,
+			fmt.Errorf("failed to save log: %w", err),
+		)
+	}
+
+	return t.pool.Purge(t.container)
+}
+
+// GetCert returns the TLS certificate of the DERPer instance.
+func (t *DERPServerInContainer) GetCert() []byte {
+	return t.tlsCert
+}
+
+// Hostname returns the hostname of the DERPer instance.
+func (t *DERPServerInContainer) Hostname() string {
+	return t.hostname
+}
+
+// Version returns the running DERPer version of the instance.
+func (t *DERPServerInContainer) Version() string {
+	return t.version
+}
+
+// ID returns the Docker container ID of the DERPServerInContainer
+// instance.
+func (t *DERPServerInContainer) ID() string {
+	return t.container.Container.ID
+}
+
+func (t *DERPServerInContainer) GetHostname() string {
+	return t.hostname
+}
+
+// GetSTUNPort returns the STUN port of the DERPer instance.
+func (t *DERPServerInContainer) GetSTUNPort() int {
+	return t.stunPort
+}
+
+// GetDERPPort returns the DERP port of the DERPer instance.
+func (t *DERPServerInContainer) GetDERPPort() int {
+	return t.derpPort
+}
+
+// WaitForRunning blocks until the DERPer instance is ready to be used.
+func (t *DERPServerInContainer) WaitForRunning() error {
+	url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/"
+	log.Printf("waiting for DERPer to be ready at %s", url)
+
+	insecureTransport := http.DefaultTransport.(*http.Transport).Clone()      //nolint
+	insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint
+	client := &http.Client{Transport: insecureTransport}
+
+	return t.pool.Retry(func() error {
+		resp, err := client.Get(url) //nolint
+		if err != nil {
+			return fmt.Errorf("DERPer is not ready: %w", err)
+		}
+
+		if resp.StatusCode != http.StatusOK {
+			return errDERPerStatusCodeNotOk
+		}
+
+		return nil
+	})
+}
+
+// ConnectToNetwork connects the DERPer instance to a network.
+func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error {
+	return t.container.ConnectToNetwork(network)
+}
+
+// WriteFile saves a file inside the container.
+func (t *DERPServerInContainer) WriteFile(path string, data []byte) error {
+	return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
+}
+
+// SaveLog saves the current stdout log of the container to a path
+// on the host system.
+func (t *DERPServerInContainer) SaveLog(path string) error {
+	_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)

+	return err
+}
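A hedged sketch of driving the new package directly, outside the Scenario helper; the wrapper function, pool, network, and verify URL below are assumed placeholders, not values from the diff:

```go
package integration

import (
	"github.com/juanfont/headscale/integration/dsic"
	"github.com/ory/dockertest/v3"
)

// startVerifyingDERPer builds the derper image from HEAD, starts the
// container, and blocks until it answers over HTTPS.
func startVerifyingDERPer(pool *dockertest.Pool, network *dockertest.Network) (*dsic.DERPServerInContainer, error) {
	derper, err := dsic.New(pool, "head", network,
		dsic.WithVerifyClientURL("https://headscale.example.invalid:8080/verify"),
	)
	if err != nil {
		return nil, err
	}

	if err := derper.WaitForRunning(); err != nil {
		return nil, err
	}

	return derper, nil
}
```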
@@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
 	spec := map[string]ClientsSpec{
 		"user1": {
 			Plain:         0,
-			WebsocketDERP: len(MustTestVersions),
+			WebsocketDERP: 2,
 		},
 	}

@@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(

 		if clientCount.WebsocketDERP > 0 {
 			// Containers that use DERP-over-WebSocket
+			// Note that these clients *must* be built
+			// from source, which is currently
+			// only done for HEAD.
 			err = s.CreateTailscaleIsolatedNodesInUser(
 				hash,
 				userName,
-				"all",
+				tsic.VersionHead,
 				clientCount.WebsocketDERP,
 				tsic.WithWebsocketDERP(true),
 			)
@@ -307,7 +310,7 @@ func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
 		cert := hsServer.GetCert()

 		opts = append(opts,
-			tsic.WithHeadscaleTLS(cert),
+			tsic.WithCACert(cert),
 		)

 		user.createWaitGroup.Go(func() error {
@@ -1,19 +1,12 @@
 package hsic

 import (
-	"bytes"
-	"crypto/rand"
-	"crypto/rsa"
 	"crypto/tls"
-	"crypto/x509"
-	"crypto/x509/pkix"
 	"encoding/json"
-	"encoding/pem"
 	"errors"
 	"fmt"
 	"io"
 	"log"
-	"math/big"
 	"net"
 	"net/http"
 	"net/url"
@@ -32,11 +25,14 @@ import (
 	"github.com/juanfont/headscale/integration/integrationutil"
 	"github.com/ory/dockertest/v3"
 	"github.com/ory/dockertest/v3/docker"
+	"gopkg.in/yaml.v3"
+	"tailscale.com/tailcfg"
 )

 const (
 	hsicHashLength    = 6
 	dockerContextPath = "../."
+	caCertRoot        = "/usr/local/share/ca-certificates"
 	aclPolicyPath     = "/etc/headscale/acl.hujson"
 	tlsCertPath       = "/etc/headscale/tls.cert"
 	tlsKeyPath        = "/etc/headscale/tls.key"
@@ -64,6 +60,7 @@ type HeadscaleInContainer struct {
 	// optional config
 	port             int
 	extraPorts       []string
+	caCerts          [][]byte
 	hostPortBindings map[string][]string
 	aclPolicy        *policy.ACLPolicy
 	env              map[string]string
@@ -81,6 +78,10 @@ type Option = func(c *HeadscaleInContainer)
 // HeadscaleInContainer instance.
 func WithACLPolicy(acl *policy.ACLPolicy) Option {
 	return func(hsic *HeadscaleInContainer) {
+		if acl == nil {
+			return
+		}
+
 		// TODO(kradalby): Move somewhere appropriate
 		hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath

@@ -88,18 +89,29 @@ func WithACLPolicy(acl *policy.ACLPolicy) Option {
 	}
 }

+// WithCACert adds the given CA certificate to the trusted
+// certificates of the container.
+func WithCACert(cert []byte) Option {
+	return func(hsic *HeadscaleInContainer) {
+		hsic.caCerts = append(hsic.caCerts, cert)
+	}
+}
+
 // WithTLS creates certificates and enables HTTPS.
 func WithTLS() Option {
 	return func(hsic *HeadscaleInContainer) {
-		cert, key, err := createCertificate(hsic.hostname)
+		cert, key, err := integrationutil.CreateCertificate(hsic.hostname)
 		if err != nil {
 			log.Fatalf("failed to create certificates for headscale test: %s", err)
 		}

-		// TODO(kradalby): Move somewhere appropriate
-		hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
-		hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
+		hsic.tlsCert = cert
+		hsic.tlsKey = key
+	}
+}

+// WithCustomTLS uses the given certificates for the Headscale instance.
+func WithCustomTLS(cert, key []byte) Option {
+	return func(hsic *HeadscaleInContainer) {
 		hsic.tlsCert = cert
 		hsic.tlsKey = key
 	}
 }
@@ -146,6 +158,13 @@ func WithTestName(testName string) Option {
 	}
 }

+// WithHostname sets the hostname of the Headscale instance.
+func WithHostname(hostname string) Option {
+	return func(hsic *HeadscaleInContainer) {
+		hsic.hostname = hostname
+	}
+}
+
 // WithHostnameAsServerURL sets the Headscale ServerURL based on
 // the Hostname.
 func WithHostnameAsServerURL() Option {
@@ -203,6 +222,34 @@ func WithEmbeddedDERPServerOnly() Option {
 	}
 }

+// WithDERPConfig configures Headscale to use a custom
+// DERP server only.
+func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
+	return func(hsic *HeadscaleInContainer) {
+		contents, err := yaml.Marshal(derpMap)
+		if err != nil {
+			log.Fatalf("failed to marshal DERP map: %s", err)
+
+			return
+		}
+
+		hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml"
+		hsic.filesInContainer = append(hsic.filesInContainer,
+			fileInContainer{
+				path:     "/etc/headscale/derp.yml",
+				contents: contents,
+			})
+
+		// Disable global DERP server and embedded DERP server
+		hsic.env["HEADSCALE_DERP_URLS"] = ""
+		hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false"
+
+		// Envknob for enabling DERP debug logs
+		hsic.env["DERP_DEBUG_LOGS"] = "true"
+		hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true"
+	}
+}
+
 // WithTuning allows changing the tuning settings easily.
 func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
 	return func(hsic *HeadscaleInContainer) {
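For reference, a small sketch (not part of the diff) of the input WithDERPConfig marshals: a one-region DERPMap whose YAML ends up in /etc/headscale/derp.yml. The hostname is an assumed placeholder:

```go
package main

import (
	"fmt"
	"log"

	"gopkg.in/yaml.v3"
	"tailscale.com/tailcfg"
)

func main() {
	// A single custom region with one DERP node, matching the shape used in
	// the new TestDERPVerifyEndpoint.
	derpMap := tailcfg.DERPMap{
		Regions: map[int]*tailcfg.DERPRegion{
			900: {
				RegionID:   900,
				RegionCode: "test-derpverify",
				Nodes: []*tailcfg.DERPNode{{
					Name:     "TestDerpVerify",
					RegionID: 900,
					HostName: "derp.example.invalid", // assumed hostname
					DERPPort: 443,
					STUNPort: 3478,
				}},
			},
		},
	}

	out, err := yaml.Marshal(derpMap)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(out))
}
```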
@@ -300,6 +347,10 @@ func New(
 		"HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1",
 		"HEADSCALE_DEBUG_DUMP_CONFIG=1",
 	}
+	if hsic.hasTLS() {
+		hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
+		hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
+	}
 	for key, value := range hsic.env {
 		env = append(env, fmt.Sprintf("%s=%s", key, value))
 	}
@@ -313,7 +364,7 @@ func New(
 		// Cmd: []string{"headscale", "serve"},
 		// TODO(kradalby): Get rid of this hack, we currently need to give us some
 		// to inject the headscale configuration further down.
-		Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"},
+		Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; update-ca-certificates ; headscale serve ; /bin/sleep 30"},
 		Env: env,
 	}

@@ -351,6 +402,14 @@ func New(

 	hsic.container = container

+	// Write the CA certificates to the container
+	for i, cert := range hsic.caCerts {
+		err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
+		if err != nil {
+			return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
+		}
+	}
+
 	err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML()))
 	if err != nil {
 		return nil, fmt.Errorf("failed to write headscale config to container: %w", err)
@@ -749,86 +808,3 @@ func (t *HeadscaleInContainer) SendInterrupt() error {

 	return nil
 }
-
-// nolint
-func createCertificate(hostname string) ([]byte, []byte, error) {
-	// From:
-	// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
-
-	ca := &x509.Certificate{
-		SerialNumber: big.NewInt(2019),
-		Subject: pkix.Name{
-			Organization: []string{"Headscale testing INC"},
-			Country:      []string{"NL"},
-			Locality:     []string{"Leiden"},
-		},
-		NotBefore: time.Now(),
-		NotAfter:  time.Now().Add(60 * time.Hour),
-		IsCA:      true,
-		ExtKeyUsage: []x509.ExtKeyUsage{
-			x509.ExtKeyUsageClientAuth,
-			x509.ExtKeyUsageServerAuth,
-		},
-		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
-		BasicConstraintsValid: true,
-	}
-
-	caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	cert := &x509.Certificate{
-		SerialNumber: big.NewInt(1658),
-		Subject: pkix.Name{
-			CommonName:   hostname,
-			Organization: []string{"Headscale testing INC"},
-			Country:      []string{"NL"},
-			Locality:     []string{"Leiden"},
-		},
-		NotBefore:    time.Now(),
-		NotAfter:     time.Now().Add(60 * time.Minute),
-		SubjectKeyId: []byte{1, 2, 3, 4, 6},
-		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
-		KeyUsage:     x509.KeyUsageDigitalSignature,
-		DNSNames:     []string{hostname},
-	}
-
-	certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	certBytes, err := x509.CreateCertificate(
-		rand.Reader,
-		cert,
-		ca,
-		&certPrivKey.PublicKey,
-		caPrivKey,
-	)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	certPEM := new(bytes.Buffer)
-
-	err = pem.Encode(certPEM, &pem.Block{
-		Type:  "CERTIFICATE",
-		Bytes: certBytes,
-	})
-	if err != nil {
-		return nil, nil, err
-	}
-
-	certPrivKeyPEM := new(bytes.Buffer)
-
-	err = pem.Encode(certPrivKeyPEM, &pem.Block{
-		Type:  "RSA PRIVATE KEY",
-		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
-	})
-	if err != nil {
-		return nil, nil, err
-	}
-
-	return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
-}
@@ -3,9 +3,16 @@ package integrationutil

 import (
 	"archive/tar"
 	"bytes"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/pem"
 	"fmt"
 	"io"
+	"math/big"
 	"path/filepath"
+	"time"

 	"github.com/juanfont/headscale/integration/dockertestutil"
 	"github.com/ory/dockertest/v3"
@@ -93,3 +100,86 @@ func FetchPathFromContainer(

 	return buf.Bytes(), nil
 }
+
+// nolint
+func CreateCertificate(hostname string) ([]byte, []byte, error) {
+	// From:
+	// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
+
+	ca := &x509.Certificate{
+		SerialNumber: big.NewInt(2019),
+		Subject: pkix.Name{
+			Organization: []string{"Headscale testing INC"},
+			Country:      []string{"NL"},
+			Locality:     []string{"Leiden"},
+		},
+		NotBefore: time.Now(),
+		NotAfter:  time.Now().Add(60 * time.Hour),
+		IsCA:      true,
+		ExtKeyUsage: []x509.ExtKeyUsage{
+			x509.ExtKeyUsageClientAuth,
+			x509.ExtKeyUsageServerAuth,
+		},
+		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+		BasicConstraintsValid: true,
+	}
+
+	caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	cert := &x509.Certificate{
+		SerialNumber: big.NewInt(1658),
+		Subject: pkix.Name{
+			CommonName:   hostname,
+			Organization: []string{"Headscale testing INC"},
+			Country:      []string{"NL"},
+			Locality:     []string{"Leiden"},
+		},
+		NotBefore:    time.Now(),
+		NotAfter:     time.Now().Add(60 * time.Minute),
+		SubjectKeyId: []byte{1, 2, 3, 4, 6},
+		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
+		KeyUsage:     x509.KeyUsageDigitalSignature,
+		DNSNames:     []string{hostname},
+	}
+
+	certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	certBytes, err := x509.CreateCertificate(
+		rand.Reader,
+		cert,
+		ca,
+		&certPrivKey.PublicKey,
+		caPrivKey,
+	)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	certPEM := new(bytes.Buffer)
+
+	err = pem.Encode(certPEM, &pem.Block{
+		Type:  "CERTIFICATE",
+		Bytes: certBytes,
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	certPrivKeyPEM := new(bytes.Buffer)
+
+	err = pem.Encode(certPrivKeyPEM, &pem.Block{
+		Type:  "RSA PRIVATE KEY",
+		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
+	})
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
+}
@@ -14,6 +14,7 @@ import (
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/juanfont/headscale/integration/dockertestutil"
+	"github.com/juanfont/headscale/integration/dsic"
 	"github.com/juanfont/headscale/integration/hsic"
 	"github.com/juanfont/headscale/integration/tsic"
 	"github.com/ory/dockertest/v3"
@@ -140,6 +141,7 @@ type Scenario struct {
 	// TODO(kradalby): support multiple headscales for later, currently only
 	// use one.
 	controlServers *xsync.MapOf[string, ControlServer]
+	derpServers    []*dsic.DERPServerInContainer

 	users map[string]*User

@@ -224,6 +226,13 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
 		}
 	}

+	for _, derp := range s.derpServers {
+		err := derp.Shutdown()
+		if err != nil {
+			log.Printf("failed to tear down derp server: %s", err)
+		}
+	}
+
 	if err := s.pool.RemoveNetwork(s.network); err != nil {
 		log.Printf("failed to remove network: %s", err)
 	}
@@ -352,7 +361,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
 			hostname := headscale.GetHostname()

 			opts = append(opts,
-				tsic.WithHeadscaleTLS(cert),
+				tsic.WithCACert(cert),
 				tsic.WithHeadscaleName(hostname),
 			)

@@ -651,3 +660,20 @@ func (s *Scenario) WaitForTailscaleLogout() error {

 	return nil
 }
+
+// CreateDERPServer creates a new DERP server in a container.
+func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
+	derp, err := dsic.New(s.pool, version, s.network, opts...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create DERP server: %w", err)
+	}
+
+	err = derp.WaitForRunning()
+	if err != nil {
+		return nil, fmt.Errorf("failed to reach DERP server: %w", err)
+	}
+
+	s.derpServers = append(s.derpServers, derp)
+
+	return derp, nil
+}
@@ -30,6 +30,7 @@ type TailscaleClient interface {
 	FQDN() (string, error)
 	Status(...bool) (*ipnstate.Status, error)
 	Netmap() (*netmap.NetworkMap, error)
+	DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
 	Netcheck() (*netcheck.Report, error)
 	WaitForNeedsLogin() error
 	WaitForRunning() error
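A hedged sketch of how a test might wrap the new interface method, mirroring the check in TestDERPVerifyEndpoint; the helper name is mine:

```go
package integration

import "strings"

// derpReachable probes one DERP region from a client and scans the report
// for the success line the derper prints.
func derpReachable(client TailscaleClient, region string) (bool, error) {
	report, err := client.DebugDERPRegion(region)
	if err != nil {
		return false, err
	}

	// report.Info carries free-form progress lines from the probe.
	for _, line := range report.Info {
		if strings.Contains(line, "Successfully established a DERP connection with node") {
			return true, nil
		}
	}

	return false, nil
}
```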
@@ -12,6 +12,7 @@ import (
 	"net/netip"
 	"net/url"
 	"os"
+	"reflect"
 	"strconv"
 	"strings"
 	"time"
@@ -32,7 +33,7 @@ const (
 	defaultPingTimeout = 300 * time.Millisecond
 	defaultPingCount   = 10
 	dockerContextPath  = "../."
-	headscaleCertPath    = "/usr/local/share/ca-certificates/headscale.crt"
+	caCertRoot           = "/usr/local/share/ca-certificates"
 	dockerExecuteTimeout = 60 * time.Second
 )

@@ -44,6 +45,11 @@ var (
 	errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
 	errTailscaleNotConnected           = errors.New("tailscale not connected")
 	errTailscaledNotReadyForLogin      = errors.New("tailscaled not ready for login")
+	errInvalidClientConfig             = errors.New("verifiably invalid client config requested")
+)
+
+const (
+	VersionHead = "head"
 )

 func errTailscaleStatus(hostname string, err error) error {
@@ -65,7 +71,7 @@ type TailscaleInContainer struct {
 	fqdn string

 	// optional config
-	headscaleCert     []byte
+	caCerts           [][]byte
 	headscaleHostname string
 	withWebsocketDERP bool
 	withSSH           bool
@@ -74,17 +80,23 @@ type TailscaleInContainer struct {
 	withExtraHosts []string
 	workdir        string
 	netfilter      string
+
+	// build options, solely for HEAD
+	buildConfig TailscaleInContainerBuildConfig
+}
+
+type TailscaleInContainerBuildConfig struct {
+	tags []string
 }

 // Option represent optional settings that can be given to a
 // Tailscale instance.
 type Option = func(c *TailscaleInContainer)

-// WithHeadscaleTLS takes the certificate of the Headscale instance
-// and adds it to the trusted surtificate of the Tailscale container.
-func WithHeadscaleTLS(cert []byte) Option {
+// WithCACert adds the given CA certificate to the trusted
+// certificates of the Tailscale container.
+func WithCACert(cert []byte) Option {
 	return func(tsic *TailscaleInContainer) {
-		tsic.headscaleCert = cert
+		tsic.caCerts = append(tsic.caCerts, cert)
 	}
 }

@@ -113,7 +125,7 @@ func WithOrCreateNetwork(network *dockertest.Network) Option {
 }

 // WithHeadscaleName sets the name of the headscale instance,
-// mostly useful in combination with TLS and WithHeadscaleTLS.
+// mostly useful in combination with TLS and WithCACert.
 func WithHeadscaleName(hsName string) Option {
 	return func(tsic *TailscaleInContainer) {
 		tsic.headscaleHostname = hsName
@@ -175,6 +187,22 @@ func WithNetfilter(state string) Option {
 	}
 }

+// WithBuildTag adds an additional value to the `-tags=` parameter
+// of the Go compiler, allowing callers to customize the Tailscale client build.
+// This option is only meaningful when invoked on **HEAD** versions of the client.
+// Attempting to use it with any other version is a bug in the calling code.
+func WithBuildTag(tag string) Option {
+	return func(tsic *TailscaleInContainer) {
+		if tsic.version != VersionHead {
+			panic(errInvalidClientConfig)
+		}
+
+		tsic.buildConfig.tags = append(
+			tsic.buildConfig.tags, tag,
+		)
+	}
+}
+
 // New returns a new TailscaleInContainer instance.
 func New(
 	pool *dockertest.Pool,
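A hedged usage sketch for the new option; the wrapper function is mine, and pool and network are assumed to come from the surrounding test scenario:

```go
package integration

import (
	"github.com/juanfont/headscale/integration/tsic"
	"github.com/ory/dockertest/v3"
)

// newWebsocketClient builds a client from source with an extra build tag;
// build tags only make sense for clients compiled from HEAD, so any other
// version makes WithBuildTag panic with errInvalidClientConfig.
func newWebsocketClient(pool *dockertest.Pool, network *dockertest.Network) (*tsic.TailscaleInContainer, error) {
	return tsic.New(pool, tsic.VersionHead, network,
		tsic.WithBuildTag("ts_debug_websockets"),
	)
}
```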
@@ -219,18 +247,20 @@ func New(
 	}

 	if tsic.withWebsocketDERP {
+		if version != VersionHead {
+			return tsic, errInvalidClientConfig
+		}
+
+		WithBuildTag("ts_debug_websockets")(tsic)
+
 		tailscaleOptions.Env = append(
 			tailscaleOptions.Env,
 			fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
 		)
 	}

-	if tsic.headscaleHostname != "" {
-		tailscaleOptions.ExtraHosts = []string{
-			"host.docker.internal:host-gateway",
-			fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
-		}
-	}
+	tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts,
+		"host.docker.internal:host-gateway")

 	if tsic.workdir != "" {
 		tailscaleOptions.WorkingDir = tsic.workdir
@@ -245,14 +275,36 @@ func New(
 	}

 	var container *dockertest.Resource
+
+	if version != VersionHead {
+		// build options are not meaningful with pre-existing images,
+		// let's not lead anyone astray by pretending otherwise.
+		defaultBuildConfig := TailscaleInContainerBuildConfig{}
+		hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
+		if hasBuildConfig {
+			return tsic, errInvalidClientConfig
+		}
+	}
+
 	switch version {
-	case "head":
+	case VersionHead:
 		buildOptions := &dockertest.BuildOptions{
 			Dockerfile: "Dockerfile.tailscale-HEAD",
 			ContextDir: dockerContextPath,
 			BuildArgs:  []docker.BuildArg{},
 		}
+
+		buildTags := strings.Join(tsic.buildConfig.tags, ",")
+		if len(buildTags) > 0 {
+			buildOptions.BuildArgs = append(
+				buildOptions.BuildArgs,
+				docker.BuildArg{
+					Name:  "BUILD_TAGS",
+					Value: buildTags,
+				},
+			)
+		}
+
 		container, err = pool.BuildAndRunWithBuildOptions(
 			buildOptions,
 			tailscaleOptions,
@@ -294,8 +346,8 @@ func New(

 	tsic.container = container

-	if tsic.hasTLS() {
-		err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
+	for i, cert := range tsic.caCerts {
+		err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
 		if err != nil {
 			return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
 		}
@@ -304,10 +356,6 @@ func New(
 	return tsic, nil
 }

-func (t *TailscaleInContainer) hasTLS() bool {
-	return len(t.headscaleCert) != 0
-}
-
 // Shutdown stops and cleans up the Tailscale container.
 func (t *TailscaleInContainer) Shutdown() error {
 	err := t.SaveLog("/tmp/control")
@@ -682,6 +730,34 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error)
 	}
 }

+func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) {
+	if !util.TailscaleVersionNewerOrEqual("1.34", t.version) {
+		panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version)
+	}
+
+	command := []string{
+		"tailscale",
+		"debug",
+		"derp",
+		region,
+	}
+
+	result, stderr, err := t.Execute(command)
+	if err != nil {
+		fmt.Printf("stderr: %s\n", stderr) // nolint
+
+		return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err)
+	}
+
+	var report ipnstate.DebugDERPRegionReport
+	err = json.Unmarshal([]byte(result), &report)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
+	}
+
+	return &report, err
+}
+
 // Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
 func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
 	command := []string{