Compare commits

...

10 commits

Author SHA1 Message Date
Kristoffer Dalby
99b2ab843a
Merge 907449bd99 into edf9e25001 2024-11-22 13:30:06 +01:00
ArcticLampyrid
edf9e25001
feat: support client verify for derp (add integration tests) (#2046)
Some checks are pending
Build / build (push) Waiting to run
Build documentation / build (push) Waiting to run
Build documentation / deploy (push) Blocked by required conditions
Tests / test (push) Waiting to run
* feat: support client verify for derp

* docs: fix doc for integration test

* tests: add integration test for DERP verify endpoint

* tests: use `tailcfg.DERPMap` instead of `[]byte`

* refactor: introduce func `ContainsNodeKey`

* tests(dsic): use string builder for cmd args

* ci: fix tests order

* tests: fix derper failure

* chore: cleanup

* tests(verify-client): perfer to use `CreateHeadscaleEnv`

* refactor(verify-client): simplify error handling

* tests: fix `TestDERPVerifyEndpoint`

* refactor: make `doVerify` a seperated func

---------

Co-authored-by: 117503445 <t117503445@gmail.com>
2024-11-22 13:23:05 +01:00
Motiejus Jakštys
c6336adb01
config: loosen up BaseDomain and ServerURL checks (#2248)
* config: loosen up BaseDomain and ServerURL checks

Requirements [here][1]:

> OK:
> server_url: headscale.com, base: clients.headscale.com
> server_url: headscale.com, base: headscale.net
>
> Not OK:
> server_url: server.headscale.com, base: headscale.com
>
> Essentially we have to prevent the possibility where the headscale
> server has a URL which can also be assigned to a node.
>
> So for the Not OK scenario:
>
> if the server is: server.headscale.com, and a node joins with the name
> server, it will be assigned server.headscale.com and that will break
> the connection for nodes which will now try to connect to that node
> instead of the headscale server.

Fixes #2210

[1]: https://github.com/juanfont/headscale/issues/2210#issuecomment-2488165187

* server_url and base_domain: re-word error message, fix a one-off bug and add a test case for the bug.

* lint

* lint again
2024-11-22 13:21:44 +01:00
Kristoffer Dalby
907449bd99
policy manager
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
Kristoffer Dalby
53cbdfc277
copy reduce test filter test to compare v1 vs v2
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
Kristoffer Dalby
a7b2468a42
use json in TestReduceFilterRules test
This is to allow for the tests to be ran with
the new upcoming parser to ensure we get the
same input.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
Kristoffer Dalby
38f2159c56
implement asterix, pass old parsing test
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
Kristoffer Dalby
e5e1f15dd9
loads
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
Kristoffer Dalby
086f6a005e
can the policy be typed at parsetime?
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-11-22 12:24:22 +01:00
enoperm
5fbf3f8327
Websocket derp test fixes (#2247)
* integration testing: add and validate build-time options for tailscale head

* fixup! integration testing: add and validate build-time options for tailscale head

integration testing: comply with linter

* fixup! fixup! integration testing: add and validate build-time options for tailscale head

integration testing: tsic.New must never return nil

* fixup! fixup! fixup! integration testing: add and validate build-time options for tailscale head

* minor fixes
2024-11-22 11:57:01 +01:00
27 changed files with 4095 additions and 342 deletions

View file

@ -38,6 +38,7 @@ jobs:
- TestNodeMoveCommand - TestNodeMoveCommand
- TestPolicyCommand - TestPolicyCommand
- TestPolicyBrokenConfigCommand - TestPolicyBrokenConfigCommand
- TestDERPVerifyEndpoint
- TestResolveMagicDNS - TestResolveMagicDNS
- TestValidateResolvConf - TestValidateResolvConf
- TestDERPServerScenario - TestDERPServerScenario

View file

@ -23,6 +23,7 @@
- Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198) - Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules applied [#2198](https://github.com/juanfont/headscale/pull/2198)
- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199) - Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199)
- Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232) - Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pr/2232)
- Loosened up `server_url` and `base_domain` check. It was overly strict in some cases.
## 0.23.0 (2024-09-18) ## 0.23.0 (2024-09-18)

19
Dockerfile.derper Normal file
View file

@ -0,0 +1,19 @@
# For testing purposes only
FROM golang:alpine AS build-env
WORKDIR /go/src
RUN apk add --no-cache git
ARG VERSION_BRANCH=main
RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1
WORKDIR /go/src/tailscale
ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -v ./cmd/derper
FROM alpine:3.18
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
COPY --from=build-env /go/bin/* /usr/local/bin/
ENTRYPOINT [ "/usr/local/bin/derper" ]

View file

@ -28,7 +28,9 @@ ARG VERSION_GIT_HASH=""
ENV VERSION_GIT_HASH=$VERSION_GIT_HASH ENV VERSION_GIT_HASH=$VERSION_GIT_HASH
ARG TARGETARCH ARG TARGETARCH
RUN GOARCH=$TARGETARCH go install -ldflags="\ ARG BUILD_TAGS=""
RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
-X tailscale.com/version.longStamp=$VERSION_LONG \ -X tailscale.com/version.longStamp=$VERSION_LONG \
-X tailscale.com/version.shortStamp=$VERSION_SHORT \ -X tailscale.com/version.shortStamp=$VERSION_SHORT \
-X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \

View file

@ -30,6 +30,7 @@ import (
"github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policyv2"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog" zerolog "github.com/philip-bui/grpc-zerolog"
@ -89,6 +90,7 @@ type Headscale struct {
DERPServer *derpServer.DERPServer DERPServer *derpServer.DERPServer
ACLPolicy *policy.ACLPolicy ACLPolicy *policy.ACLPolicy
PolicyManager *policyv2.PolicyManager
mapper *mapper.Mapper mapper *mapper.Mapper
nodeNotifier *notifier.Notifier nodeNotifier *notifier.Notifier
@ -457,6 +459,8 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1).
Methods(http.MethodGet) Methods(http.MethodGet)
router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
router.HandleFunc("/derp", h.DERPServer.DERPHandler) router.HandleFunc("/derp", h.DERPServer.DERPHandler)
router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler)

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -56,6 +57,65 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
} }
func (h *Headscale) handleVerifyRequest(
req *http.Request,
) (bool, error) {
body, err := io.ReadAll(req.Body)
if err != nil {
return false, fmt.Errorf("cannot read request body: %w", err)
}
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
return false, fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)
}
nodes, err := h.db.ListNodes()
if err != nil {
return false, fmt.Errorf("cannot list nodes: %w", err)
}
return nodes.ContainsNodeKey(derpAdmitClientRequest.NodePublic), nil
}
// see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159, Derp use verifyClientsURL to verify whether a client is allowed to connect to the DERP server.
func (h *Headscale) VerifyHandler(
writer http.ResponseWriter,
req *http.Request,
) {
if req.Method != http.MethodPost {
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
return
}
log.Debug().
Str("handler", "/verify").
Msg("verify client")
allow, err := h.handleVerifyRequest(req)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to verify client")
http.Error(writer, "Internal error", http.StatusInternalServerError)
}
resp := tailcfg.DERPAdmitClientResponse{
Allow: allow,
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
func (h *Headscale) KeyHandler( func (h *Headscale) KeyHandler(

View file

@ -1764,9 +1764,9 @@ var tsExitNodeDest = []tailcfg.NetPortRange{
}, },
} }
// hsExitNodeDest is the list of destination IP ranges that are allowed when // hsExitNodeDestForTest is the list of destination IP ranges that are allowed when
// we use headscale "autogroup:internet". // we use headscale "autogroup:internet".
var hsExitNodeDest = []tailcfg.NetPortRange{ var hsExitNodeDestForTest = []tailcfg.NetPortRange{
{IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny},
{IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny},
{IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny},
@ -1823,13 +1823,13 @@ func TestTheInternet(t *testing.T) {
internetPrefs := internetSet.Prefixes() internetPrefs := internetSet.Prefixes()
for i := range internetPrefs { for i := range internetPrefs {
if internetPrefs[i].String() != hsExitNodeDest[i].IP { if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP {
t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDest[i].IP) t.Errorf("prefix from internet set %q != hsExit list %q", internetPrefs[i].String(), hsExitNodeDestForTest[i].IP)
} }
} }
if len(internetPrefs) != len(hsExitNodeDest) { if len(internetPrefs) != len(hsExitNodeDestForTest) {
t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDest)) t.Fatalf("expected same length of prefixes, internet: %d, hsExit: %d", len(internetPrefs), len(hsExitNodeDestForTest))
} }
} }
@ -1838,20 +1838,27 @@ func TestReduceFilterRules(t *testing.T) {
name string name string
node *types.Node node *types.Node
peers types.Nodes peers types.Nodes
pol ACLPolicy pol string
want []tailcfg.FilterRule want []tailcfg.FilterRule
}{ }{
{ {
name: "host1-can-reach-host2-no-rules", name: "host1-can-reach-host2-no-rules",
pol: ACLPolicy{ pol: `
ACLs: []ACL{ {
"acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"100.64.0.1"}, "proto": "",
Destinations: []string{"100.64.0.2:*"}, "src": [
}, "100.64.0.1"
}, ],
}, "dst": [
"100.64.0.2:*"
]
}
],
}
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), IPv6: iap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
@ -1868,23 +1875,37 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1604-subnet-routers-are-preserved", name: "1604-subnet-routers-are-preserved",
pol: ACLPolicy{ pol: `
Groups: Groups{ {
"group:admins": {"user1"}, "groups": {
"group:admins": [
"user1"
]
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:admins"}, "proto": "",
Destinations: []string{"group:admins:*"}, "src": [
"group:admins"
],
"dst": [
"group:admins:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:admins"}, "proto": "",
Destinations: []string{"10.33.0.0/16:*"}, "src": [
}, "group:admins"
}, ],
}, "dst": [
"10.33.0.0/16:*"
]
}
],
}
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
@ -1939,31 +1960,42 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1786-reducing-breaks-exit-nodes-the-client", name: "1786-reducing-breaks-exit-nodes-the-client",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
// Exit node "groups": {
"internal": netip.MustParsePrefix("100.64.0.100/32"), "group:team": [
"user3",
"user2",
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:team": {"user3", "user2", "user1"}, "internal": "100.64.0.100/32"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"internal:*", "group:team"
}, ],
"dst": [
"internal:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"autogroup:internet:*", "group:team"
}, ],
}, "dst": [
}, "autogroup:internet:*"
}, ]
}
],
}
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.1"), IPv4: iap("100.64.0.1"),
IPv6: iap("fd7a:115c:a1e0::1"), IPv6: iap("fd7a:115c:a1e0::1"),
@ -1989,31 +2021,42 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1786-reducing-breaks-exit-nodes-the-exit", name: "1786-reducing-breaks-exit-nodes-the-exit",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
// Exit node "groups": {
"internal": netip.MustParsePrefix("100.64.0.100/32"), "group:team": [
"user3",
"user2",
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:team": {"user3", "user2", "user1"}, "internal": "100.64.0.100/32"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"internal:*", "group:team"
}, ],
"dst": [
"internal:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"autogroup:internet:*", "group:team"
}, ],
}, "dst": [
}, "autogroup:internet:*"
}, ]
}
],
}
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
@ -2050,32 +2093,42 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"},
DstPorts: hsExitNodeDest, DstPorts: hsExitNodeDestForTest,
}, },
}, },
}, },
{ {
name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", name: "1786-reducing-breaks-exit-nodes-the-example-from-issue",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
// Exit node "groups": {
"internal": netip.MustParsePrefix("100.64.0.100/32"), "group:team": [
"user3",
"user2",
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:team": {"user3", "user2", "user1"}, "internal": "100.64.0.100/32"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"internal:*", "group:team"
}, ],
"dst": [
"internal:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"group:team"
],
"dst": [
"0.0.0.0/5:*", "0.0.0.0/5:*",
"8.0.0.0/7:*", "8.0.0.0/7:*",
"11.0.0.0/8:*", "11.0.0.0/8:*",
@ -2105,11 +2158,12 @@ func TestReduceFilterRules(t *testing.T) {
"194.0.0.0/7:*", "194.0.0.0/7:*",
"196.0.0.0/6:*", "196.0.0.0/6:*",
"200.0.0.0/5:*", "200.0.0.0/5:*",
"208.0.0.0/4:*", "208.0.0.0/4:*"
}, ]
}, }
}, ],
}, }
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
@ -2186,32 +2240,43 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1786-reducing-breaks-exit-nodes-app-connector-like", name: "1786-reducing-breaks-exit-nodes-app-connector-like",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
// Exit node "groups": {
"internal": netip.MustParsePrefix("100.64.0.100/32"), "group:team": [
"user3",
"user2",
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:team": {"user3", "user2", "user1"}, "internal": "100.64.0.100/32"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"internal:*", "group:team"
}, ],
"dst": [
"internal:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"group:team"
],
"dst": [
"8.0.0.0/8:*", "8.0.0.0/8:*",
"16.0.0.0/8:*", "16.0.0.0/8:*"
}, ]
}, }
}, ],
}, }
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
@ -2263,32 +2328,43 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1786-reducing-breaks-exit-nodes-app-connector-like2", name: "1786-reducing-breaks-exit-nodes-app-connector-like2",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
// Exit node "groups": {
"internal": netip.MustParsePrefix("100.64.0.100/32"), "group:team": [
"user3",
"user2",
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:team": {"user3", "user2", "user1"}, "internal": "100.64.0.100/32"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"internal:*", "group:team"
}, ],
"dst": [
"internal:*"
]
}, },
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:team"}, "proto": "",
Destinations: []string{ "src": [
"group:team"
],
"dst": [
"8.0.0.0/16:*", "8.0.0.0/16:*",
"16.0.0.0/16:*", "16.0.0.0/16:*"
}, ]
}, }
}, ],
}, }
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
@ -2340,25 +2416,32 @@ func TestReduceFilterRules(t *testing.T) {
}, },
{ {
name: "1817-reduce-breaks-32-mask", name: "1817-reduce-breaks-32-mask",
pol: ACLPolicy{ pol: `
Hosts: Hosts{ {
"vlan1": netip.MustParsePrefix("172.16.0.0/24"), "groups": {
"dns1": netip.MustParsePrefix("172.16.0.21/32"), "group:access": [
"user1"
]
}, },
Groups: Groups{ "hosts": {
"group:access": {"user1"}, "dns1": "172.16.0.21/32",
"vlan1": "172.16.0.0/24"
}, },
ACLs: []ACL{ "acls": [
{ {
Action: "accept", "action": "accept",
Sources: []string{"group:access"}, "proto": "",
Destinations: []string{ "src": [
"group:access"
],
"dst": [
"tag:access-servers:*", "tag:access-servers:*",
"dns1:*", "dns1:*"
}, ]
}, }
}, ],
}, }
`,
node: &types.Node{ node: &types.Node{
IPv4: iap("100.64.0.100"), IPv4: iap("100.64.0.100"),
IPv6: iap("fd7a:115c:a1e0::100"), IPv6: iap("fd7a:115c:a1e0::100"),
@ -2399,7 +2482,11 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _ := tt.pol.CompileFilterRules( pol, err := LoadACLPolicyFromBytes([]byte(tt.pol))
if err != nil {
t.Fatalf("parsing policy: %s", err)
}
got, _ := pol.CompileFilterRules(
append(tt.peers, tt.node), append(tt.peers, tt.node),
) )

View file

@ -0,0 +1,180 @@
package policyv2
import (
"errors"
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
// CompileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) CompileFilterRules(
users types.Users,
nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
if pol == nil {
return tailcfg.FilterAllowAll, nil
}
var rules []tailcfg.FilterRule
for _, acl := range pol.ACLs {
if acl.Action != "accept" {
return nil, ErrInvalidAction
}
srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
if err != nil {
return nil, fmt.Errorf("resolving source ips: %w", err)
}
// TODO(kradalby): integrate type into schema
// TODO(kradalby): figure out the _ is wildcard stuff
protocols, _, err := parseProtocol(acl.Protocol)
if err != nil {
return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
}
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
for _, pref := range ips.Prefixes() {
for _, port := range dest.Ports {
pr := tailcfg.NetPortRange{
IP: pref.String(),
Ports: port,
}
destPorts = append(destPorts, pr)
}
}
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: ipSetToPrefixStringList(srcIPs),
DstPorts: destPorts,
IPProto: protocols,
})
}
return rules, nil
}
func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
return tailcfg.SSHAction{
Reject: !accept,
Accept: accept,
SessionDuration: duration,
AllowAgentForwarding: true,
AllowLocalPortForwarding: true,
}
}
func (pol *Policy) CompileSSHPolicy(
users types.Users,
node types.Node,
nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
if pol == nil {
return nil, nil
}
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range rule.Destinations {
ips, err := src.Resolve(pol, users, nodes)
if err != nil {
return nil, err
}
dest.AddSet(ips)
}
destSet, err := dest.IPSet()
if err != nil {
return nil, err
}
if !node.InIPSet(destSet) {
continue
}
var action tailcfg.SSHAction
switch rule.Action {
case "accept":
action = sshAction(true, 0)
case "check":
action = sshAction(true, rule.CheckPeriod)
default:
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
}
var principals []*tailcfg.SSHPrincipal
for _, src := range rule.Sources {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.expandUsersFromGroup(rawSrc)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding user from group, index: %d->%d: %w", index, innerIndex, err)
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := pol.ExpandAlias(
peers,
rawSrc,
)
if err != nil {
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}
userMap := make(map[string]string, len(rule.Users))
for _, user := range rule.Users {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
return &tailcfg.SSHPolicy{
Rules: rules,
}, nil
}
func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
var out []string
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,80 @@
package policyv2
import (
"fmt"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)
type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes
filter []tailcfg.FilterRule
// TODO(kradalby): Implement SSH policy
sshPolicy *tailcfg.SSHPolicy
}
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
policy, err := policyFromBytes(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
}
pm := PolicyManager{
pol: policy,
users: users,
nodes: nodes,
}
err = pm.updateLocked()
if err != nil {
return nil, err
}
return &pm, nil
}
// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter
}
// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() error {
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
if err != nil {
return fmt.Errorf("compiling filter rules: %w", err)
}
pm.filter = filter
return nil
}
// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) error {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}

View file

@ -0,0 +1,58 @@
package policyv2
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
Hostinfo: hostinfo,
}
}
func TestPolicyManager(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
}
tests := []struct {
name string
pol string
nodes types.Nodes
wantFilter []tailcfg.FilterRule
}{
{
name: "empty-policy",
pol: "{}",
nodes: types.Nodes{},
wantFilter: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
require.NoError(t, err)
filter := pm.Filter()
if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
}
// TODO(kradalby): Test SSH Policy
})
}
}

821
hscontrol/policyv2/types.go Normal file
View file

@ -0,0 +1,821 @@
package policyv2
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/netip"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/tailscale/hujson"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var theInternetSet *netipx.IPSet
// theInternet returns the IPSet for the Internet.
// https://www.youtube.com/watch?v=iDbyYGrswtg
func theInternet() *netipx.IPSet {
if theInternetSet != nil {
return theInternetSet
}
var internetBuilder netipx.IPSetBuilder
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
internetBuilder.AddPrefix(tsaddr.AllIPv4())
// Delete Private network addresses
// https://datatracker.ietf.org/doc/html/rfc1918
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
// Delete Tailscale networks
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
// Delete "cant find DHCP networks"
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-loca
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
theInternetSet, _ := internetBuilder.IPSet()
return theInternetSet
}
type Asterix int
func (a Asterix) Validate() error {
return nil
}
func (a Asterix) String() string {
return "*"
}
func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
ips.AddPrefix(tsaddr.AllIPv4())
ips.AddPrefix(tsaddr.AllIPv6())
return ips.IPSet()
}
// Username is a string that represents a username, it must contain an @.
type Username string
func (u Username) Validate() error {
if strings.Contains(string(u), "@") {
return nil
}
return fmt.Errorf("Username has to contain @, got: %q", u)
}
func (u *Username) String() string {
return string(*u)
}
func (u *Username) UnmarshalJSON(b []byte) error {
*u = Username(strings.Trim(string(b), `"`))
if err := u.Validate(); err != nil {
return err
}
return nil
}
func (u Username) CanBeTagOwner() bool {
return true
}
func (u Username) resolveUser(users types.Users) (*types.User, error) {
var potentialUsers types.Users
for _, user := range users {
if user.ProviderIdentifier == string(u) {
potentialUsers = append(potentialUsers, user)
break
}
if user.Email == string(u) {
potentialUsers = append(potentialUsers, user)
}
if user.Name == string(u) {
potentialUsers = append(potentialUsers, user)
}
}
if len(potentialUsers) > 1 {
return nil, fmt.Errorf("unable to resolve user identifier to distinct: %s matched multiple %s", u, potentialUsers)
} else if len(potentialUsers) == 0 {
return nil, fmt.Errorf("unable to resolve user identifier, no user found: %s not in %s", u, users)
}
user := potentialUsers[0]
return &user, nil
}
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
user, err := u.resolveUser(users)
if err != nil {
return nil, err
}
for _, node := range nodes {
if node.IsTagged() {
continue
}
if node.User.ID == user.ID {
node.AppendToIPSet(&ips)
}
}
return ips.IPSet()
}
// Group is a special string which is always prefixed with `group:`
type Group string
func (g Group) Validate() error {
if strings.HasPrefix(string(g), "group:") {
return nil
}
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
}
func (g *Group) UnmarshalJSON(b []byte) error {
*g = Group(strings.Trim(string(b), `"`))
if err := g.Validate(); err != nil {
return err
}
return nil
}
func (g Group) CanBeTagOwner() bool {
return true
}
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, user := range p.Groups[g] {
uips, err := user.Resolve(nil, users, nodes)
if err != nil {
return nil, err
}
ips.AddSet(uips)
}
return ips.IPSet()
}
// Tag is a special string which is always prefixed with `tag:`
type Tag string
func (t Tag) Validate() error {
if strings.HasPrefix(string(t), "tag:") {
return nil
}
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
}
func (t *Tag) UnmarshalJSON(b []byte) error {
*t = Tag(strings.Trim(string(b), `"`))
if err := t.Validate(); err != nil {
return err
}
return nil
}
func (t Tag) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, node := range nodes {
if node.HasTag(string(t)) {
node.AppendToIPSet(&ips)
}
}
return ips.IPSet()
}
// Host is a string that represents a hostname.
type Host string
func (h Host) Validate() error {
return nil
}
func (h *Host) UnmarshalJSON(b []byte) error {
*h = Host(strings.Trim(string(b), `"`))
if err := h.Validate(); err != nil {
return err
}
return nil
}
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
pref, ok := p.Hosts[h]
if !ok {
return nil, fmt.Errorf("unable to resolve host: %q", h)
}
err := pref.Validate()
if err != nil {
return nil, err
}
// If the IP is a single host, look for a node to ensure we add all the IPs of
// the node to the IPSet.
appendIfNodeHasIP(nodes, &ips, pref)
ips.AddPrefix(netip.Prefix(pref))
return ips.IPSet()
}
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref Prefix) {
if netip.Prefix(pref).IsSingleIP() {
addr := netip.Prefix(pref).Addr()
for _, node := range nodes {
if node.HasIP(addr) {
node.AppendToIPSet(ips)
}
}
}
}
type Prefix netip.Prefix
func (p Prefix) Validate() error {
if !netip.Prefix(p).IsValid() {
return fmt.Errorf("Prefix %q is invalid", p)
}
return nil
}
func (p Prefix) String() string {
return netip.Prefix(p).String()
}
func (p *Prefix) parseString(addr string) error {
if !strings.Contains(addr, "/") {
addr, err := netip.ParseAddr(addr)
if err != nil {
return err
}
addrPref, err := addr.Prefix(addr.BitLen())
if err != nil {
return err
}
*p = Prefix(addrPref)
return nil
}
pref, err := netip.ParsePrefix(addr)
if err != nil {
return err
}
*p = Prefix(pref)
return nil
}
func (p *Prefix) UnmarshalJSON(b []byte) error {
err := p.parseString(strings.Trim(string(b), `"`))
if err != nil {
return err
}
if err := p.Validate(); err != nil {
return err
}
return nil
}
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
appendIfNodeHasIP(nodes, &ips, p)
ips.AddPrefix(netip.Prefix(p))
return ips.IPSet()
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
type AutoGroup string
const (
AutoGroupInternet = "autogroup:internet"
)
var autogroups = []string{AutoGroupInternet}
func (ag AutoGroup) Validate() error {
for _, valid := range autogroups {
if valid == string(ag) {
return nil
}
}
return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups)
}
func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
*ag = AutoGroup(strings.Trim(string(b), `"`))
if err := ag.Validate(); err != nil {
return err
}
return nil
}
func (ag AutoGroup) Resolve(_ *Policy, _ types.Users, _ types.Nodes) (*netipx.IPSet, error) {
switch ag {
case AutoGroupInternet:
return theInternet(), nil
}
return nil, nil
}
type Alias interface {
Validate() error
UnmarshalJSON([]byte) error
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
}
type AliasWithPorts struct {
Alias
Ports []tailcfg.PortRange
}
func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch vs := v.(type) {
case string:
var portsPart string
var err error
if strings.Contains(vs, ":") {
vs, portsPart, err = splitDestination(vs)
if err != nil {
return err
}
ports, err := parsePorts(portsPart)
if err != nil {
return err
}
ve.Ports = ports
}
ve.Alias = parseAlias(vs)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", vs)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
func parseAlias(vs string) Alias {
// case netip.Addr:
// ve.Alias = Addr(val)
// case netip.Prefix:
// ve.Alias = Prefix(val)
var pref Prefix
err := pref.parseString(vs)
if err == nil {
return &pref
}
switch {
case vs == "*":
return Asterix(0)
case strings.Contains(vs, "@"):
return ptr.To(Username(vs))
case strings.HasPrefix(vs, "group:"):
return ptr.To(Group(vs))
case strings.HasPrefix(vs, "tag:"):
return ptr.To(Tag(vs))
case strings.HasPrefix(vs, "autogroup:"):
return ptr.To(AutoGroup(vs))
}
if !strings.Contains(vs, "@") && !strings.Contains(vs, ":") {
return ptr.To(Host(vs))
}
return nil
}
// AliasEnc is used to deserialize a Alias.
type AliasEnc struct{ Alias }
func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch val := v.(type) {
case string:
ve.Alias = parseAlias(val)
if ve.Alias == nil {
return fmt.Errorf("could not determine the type of %q", val)
}
if err := ve.Alias.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", val)
}
return nil
}
type Aliases []Alias
func (a *Aliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
for _, alias := range a {
aips, err := alias.Resolve(p, users, nodes)
if err != nil {
return nil, err
}
ips.AddSet(aips)
}
return ips.IPSet()
}
type Owner interface {
CanBeTagOwner() bool
UnmarshalJSON([]byte) error
}
// OwnerEnc is used to deserialize a Owner.
type OwnerEnc struct{ Owner }
func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
// TODO(kradalby): use encoding/json/v2 (go-json-experiment)
dec := json.NewDecoder(bytes.NewReader(b))
var v any
if err := dec.Decode(&v); err != nil {
return err
}
switch val := v.(type) {
case string:
switch {
case strings.Contains(val, "@"):
ve.Owner = ptr.To(Username(val))
case strings.HasPrefix(val, "group:"):
ve.Owner = ptr.To(Group(val))
}
default:
return fmt.Errorf("type %T not supported", val)
}
return nil
}
type Owners []Owner
func (o *Owners) UnmarshalJSON(b []byte) error {
var owners []OwnerEnc
err := json.Unmarshal(b, &owners)
if err != nil {
return err
}
*o = make([]Owner, len(owners))
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
type Usernames []Username
// Groups are a map of Group to a list of Username.
type Groups map[Group]Usernames
// Hosts are alias for IP addresses or subnets.
type Hosts map[Host]Prefix
// TagOwners are a map of Tag to a list of the UserEntities that own the tag.
type TagOwners map[Tag]Owners
type AutoApprovers struct {
Routes map[string][]string `json:"routes"`
ExitNode []string `json:"exitNode"`
}
type ACL struct {
Action string `json:"action"`
Protocol string `json:"proto"`
Sources Aliases `json:"src"`
Destinations []AliasWithPorts `json:"dst"`
}
// Policy represents a Tailscale Network Policy.
// TODO(kradalby):
// Add validation method checking:
// All users exists
// All groups and users are valid tag TagOwners
// Everything referred to in ACLs exists in other
// entities.
type Policy struct {
// validated is set if the policy has been validated.
// It is not safe to use before it is validated, and
// callers using it should panic if not
validated bool `json:"-"`
Groups Groups `json:"groups"`
Hosts Hosts `json:"hosts"`
TagOwners TagOwners `json:"tagOwners"`
ACLs []ACL `json:"acls"`
AutoApprovers AutoApprovers `json:"autoApprovers"`
SSHs []SSH `json:"ssh"`
}
// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action"`
Sources SSHSrcAliases `json:"src"`
Destinations SSHDstAliases `json:"dst"`
Users []SSHUser `json:"users"`
CheckPeriod time.Duration `json:"checkPeriod,omitempty"`
}
// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule.
// It can be a list of usernames, groups, tags or autogroups.
type SSHSrcAliases []Alias
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Group, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule.
// It can be a list of usernames, tags or autogroups.
type SSHDstAliases []Alias
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases)
if err != nil {
return err
}
*a = make([]Alias, len(aliases))
for i, alias := range aliases {
switch alias.Alias.(type) {
case *Username, *Tag, *AutoGroup:
(*a)[i] = alias.Alias
default:
return fmt.Errorf("type %T not supported", alias.Alias)
}
}
return nil
}
type SSHUser string
func policyFromBytes(b []byte) (*Policy, error) {
var policy Policy
ast, err := hujson.Parse(b)
if err != nil {
return nil, fmt.Errorf("parsing HuJSON: %w", err)
}
ast.Standardize()
acl := ast.Pack()
err = json.Unmarshal(acl, &policy)
if err != nil {
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
}
return &policy, nil
}
const (
expectedTokenItems = 2
)
// TODO(kradalby): copy tests from parseDestination in policy
func splitDestination(dest string) (string, string, error) {
var tokens []string
// Check if there is a IPv4/6:Port combination, IPv6 has more than
// three ":".
tokens = strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
port := tokens[len(tokens)-1]
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
filteredMaybeIPv6Str := maybeIPv6Str
if strings.Contains(maybeIPv6Str, "/") {
networkParts := strings.Split(maybeIPv6Str, "/")
filteredMaybeIPv6Str = networkParts[0]
}
if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() {
return "", "", fmt.Errorf(
"failed to split destination: %v",
tokens,
)
} else {
tokens = []string{maybeIPv6Str, port}
}
}
var alias string
// We can have here stuff like:
// git-server:*
// 192.168.1.0/24:22
// fd7a:115c:a1e0::2:22
// fd7a:115c:a1e0::2/128:22
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
return alias, tokens[len(tokens)-1], nil
}
// TODO(kradalby): write/copy tests from expandPorts in policy
func parsePorts(portsStr string) ([]tailcfg.PortRange, error) {
if portsStr == "*" {
return []tailcfg.PortRange{
tailcfg.PortRangeAny,
}, nil
}
var ports []tailcfg.PortRange
for _, portStr := range strings.Split(portsStr, ",") {
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(port),
Last: uint16(port),
})
case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
default:
return nil, errors.New("invalid ports")
}
}
return ports, nil
}
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, fmt.Errorf("parsing protocol number: %w", err)
}
// TODO(kradalby): What is this?
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}

View file

@ -0,0 +1,555 @@
package policyv2
import (
"net/netip"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
func TestUnmarshalPolicy(t *testing.T) {
tests := []struct {
name string
input string
want *Policy
wantErr string
}{
{
name: "empty",
input: "{}",
want: &Policy{},
},
{
name: "groups",
input: `
{
"groups": {
"group:example": [
"derp@headscale.net",
],
},
}
`,
want: &Policy{
Groups: Groups{
Group("group:example"): []Username{Username("derp@headscale.net")},
},
},
},
{
name: "basic-types",
input: `
{
"groups": {
"group:example": [
"testuser@headscale.net",
],
"group:other": [
"otheruser@headscale.net",
],
},
"tagOwners": {
"tag:user": ["testuser@headscale.net"],
"tag:group": ["group:other"],
"tag:userandgroup": ["testuser@headscale.net", "group:other"],
},
"hosts": {
"host-1": "100.100.100.100",
"subnet-1": "100.100.101.100/24",
"outside": "192.168.0.0/16",
},
"acls": [
// All
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["*:*"],
},
// Users
{
"action": "accept",
"proto": "tcp",
"src": ["testuser@headscale.net"],
"dst": ["otheruser@headscale.net:80"],
},
// Groups
{
"action": "accept",
"proto": "tcp",
"src": ["group:example"],
"dst": ["group:other:80"],
},
// Tailscale IP
{
"action": "accept",
"proto": "tcp",
"src": ["100.101.102.103"],
"dst": ["100.101.102.104:80"],
},
// Subnet
{
"action": "accept",
"proto": "udp",
"src": ["10.0.0.0/8"],
"dst": ["172.16.0.0/16:80"],
},
// Hosts
{
"action": "accept",
"proto": "tcp",
"src": ["subnet-1"],
"dst": ["host-1:80-88"],
},
// Tags
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["tag:user:80,443"],
},
// Autogroup
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["autogroup:internet:80"],
},
],
}
`,
want: &Policy{
Groups: Groups{
Group("group:example"): []Username{Username("testuser@headscale.net")},
Group("group:other"): []Username{Username("otheruser@headscale.net")},
},
TagOwners: TagOwners{
Tag("tag:user"): Owners{up("testuser@headscale.net")},
Tag("tag:group"): Owners{gp("group:other")},
Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")},
},
Hosts: Hosts{
"host-1": Prefix(netip.MustParsePrefix("100.100.100.100/32")),
"subnet-1": Prefix(netip.MustParsePrefix("100.100.101.100/24")),
"outside": Prefix(netip.MustParsePrefix("192.168.0.0/16")),
},
ACLs: []ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
// TODO(kradalby): Should this be host?
// It is:
// All traffic originating from Tailscale devices in your tailnet,
// any approved subnets and autogroup:shared.
// It does not allow traffic originating from
// non-tailscale devices (unless it is an approved route).
hp("*"),
},
Destinations: []AliasWithPorts{
{
// TODO(kradalby): Should this be host?
// It is:
// Includes any destination (no restrictions).
Alias: hp("*"),
Ports: []tailcfg.PortRange{tailcfg.PortRangeAny},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
ptr.To(Username("testuser@headscale.net")),
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Username("otheruser@headscale.net")),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
gp("group:example"),
},
Destinations: []AliasWithPorts{
{
Alias: gp("group:other"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
pp("100.101.102.103/32"),
},
Destinations: []AliasWithPorts{
{
Alias: pp("100.101.102.104/32"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "udp",
Sources: Aliases{
pp("10.0.0.0/8"),
},
Destinations: []AliasWithPorts{
{
Alias: pp("172.16.0.0/16"),
Ports: []tailcfg.PortRange{{First: 80, Last: 80}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
hp("subnet-1"),
},
Destinations: []AliasWithPorts{
{
Alias: hp("host-1"),
Ports: []tailcfg.PortRange{{First: 80, Last: 88}},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: tp("tag:user"),
Ports: []tailcfg.PortRange{
{First: 80, Last: 80},
{First: 443, Last: 443},
},
},
},
},
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
tp("tag:group"),
},
Destinations: []AliasWithPorts{
{
Alias: agp("autogroup:internet"),
Ports: []tailcfg.PortRange{
{First: 80, Last: 80},
},
},
},
},
},
},
},
{
name: "invalid-username",
input: `
{
"groups": {
"group:example": [
"valid@",
"invalid",
],
},
}
`,
wantErr: `Username has to contain @, got: "invalid"`,
},
{
name: "invalid-group",
input: `
{
"groups": {
"grou:example": [
"valid@",
],
},
}
`,
wantErr: `Group has to start with "group:", got: "grou:example"`,
},
{
name: "group-in-group",
input: `
{
"groups": {
"group:inner": [],
"group:example": [
"group:inner",
],
},
}
`,
wantErr: `Username has to contain @, got: "group:inner"`,
},
{
name: "invalid-prefix",
input: `
{
"hosts": {
"derp": "10.0",
},
}
`,
wantErr: `ParseAddr("10.0"): IPv4 address too short`,
},
{
name: "invalid-auto-group",
input: `
{
"acls": [
// Autogroup
{
"action": "accept",
"proto": "tcp",
"src": ["tag:group"],
"dst": ["autogroup:invalid:80"],
},
],
}
`,
wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet]`,
},
}
cmps := append(util.Comparers, cmp.Comparer(func(x, y Prefix) bool {
return x == y
}))
cmps = append(cmps, cmpopts.IgnoreUnexported(Policy{}))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
policy, err := policyFromBytes([]byte(tt.input))
// TODO(kradalby): This error checking is broken,
// but so is my brain, #longflight
if err == nil {
if tt.wantErr == "" {
return
}
t.Fatalf("got success; wanted error %q", tt.wantErr)
}
if err.Error() != tt.wantErr {
t.Fatalf("got error %q; want %q", err, tt.wantErr)
// } else if err.Error() == tt.wantErr {
// return
}
if err != nil {
t.Fatalf("unexpected err: %q", err)
}
if diff := cmp.Diff(tt.want, &policy, cmps...); diff != "" {
t.Fatalf("unexpected policy (-want +got):\n%s", diff)
}
})
}
}
func gp(s string) *Group { return ptr.To(Group(s)) }
func up(s string) *Username { return ptr.To(Username(s)) }
func hp(s string) *Host { return ptr.To(Host(s)) }
func tp(s string) *Tag { return ptr.To(Tag(s)) }
func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) }
func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) }
func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) }
func pp(pref string) *Prefix { return ptr.To(Prefix(netip.MustParsePrefix(pref))) }
func p(pref string) Prefix { return Prefix(netip.MustParsePrefix(pref)) }
func TestResolvePolicy(t *testing.T) {
tests := []struct {
name string
nodes types.Nodes
pol *Policy
toResolve Alias
want []netip.Prefix
}{
{
name: "prefix",
toResolve: pp("100.100.101.101/32"),
want: []netip.Prefix{mp("100.100.101.101/32")},
},
{
name: "host",
pol: &Policy{
Hosts: Hosts{
"testhost": p("100.100.101.102/32"),
},
},
toResolve: hp("testhost"),
want: []netip.Prefix{mp("100.100.101.102/32")},
},
{
name: "username",
toResolve: ptr.To(Username("testuser")),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notme",
},
IPv4: ap("100.100.101.1"),
},
// Not matching forced tags
{
User: types.User{
Name: "testuser",
},
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
},
// not matchin pak tag
{
User: types.User{
Name: "testuser",
},
AuthKey: &types.PreAuthKey{
Tags: []string{"alsotagged"},
},
IPv4: ap("100.100.101.3"),
},
{
User: types.User{
Name: "testuser",
},
IPv4: ap("100.100.101.103"),
},
{
User: types.User{
Name: "testuser",
},
IPv4: ap("100.100.101.104"),
},
},
want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")},
},
{
name: "group",
toResolve: ptr.To(Group("group:testgroup")),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notmetoo",
},
IPv4: ap("100.100.101.4"),
},
// Not matching forced tags
{
User: types.User{
Name: "groupuser",
},
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
},
// not matchin pak tag
{
User: types.User{
Name: "groupuser",
},
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
IPv4: ap("100.100.101.6"),
},
{
User: types.User{
Name: "groupuser",
},
IPv4: ap("100.100.101.203"),
},
{
User: types.User{
Name: "groupuser",
},
IPv4: ap("100.100.101.204"),
},
},
pol: &Policy{
Groups: Groups{
"group:testgroup": Usernames{"groupuser"},
"group:othergroup": Usernames{"notmetoo"},
},
},
want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")},
},
{
name: "tag",
toResolve: tp("tag:test"),
nodes: types.Nodes{
// Not matching other user
{
User: types.User{
Name: "notmetoo",
},
IPv4: ap("100.100.101.9"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
},
// not matchin pak tag
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
IPv4: ap("100.100.101.11"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
// not matchin pak tag
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:test"},
},
IPv4: ap("100.100.101.239"),
},
},
// TODO(kradalby): tests handling TagOwners + hostinfo
pol: &Policy{},
want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol,
types.Users{},
tt.nodes)
if err != nil {
t.Fatalf("failed to resolve: %s", err)
}
prefs := ips.Prefixes()
if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" {
t.Fatalf("unexpected prefs (-want +got):\n%s", diff)
}
})
}
}

View file

@ -28,8 +28,9 @@ const (
maxDuration time.Duration = 1<<63 - 1 maxDuration time.Duration = 1<<63 - 1
) )
var errOidcMutuallyExclusive = errors.New( var (
"oidc_client_secret and oidc_client_secret_path are mutually exclusive", errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
) )
type IPAllocationStrategy string type IPAllocationStrategy string
@ -827,11 +828,10 @@ func LoadServerConfig() (*Config, error) {
// - DERP run on their own domains // - DERP run on their own domains
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com // - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net) // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
if dnsConfig.BaseDomain != "" && if dnsConfig.BaseDomain != "" {
strings.Contains(serverURL, dnsConfig.BaseDomain) { if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
return nil, errors.New( return nil, err
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", }
)
} }
return &Config{ return &Config{
@ -924,6 +924,37 @@ func LoadServerConfig() (*Config, error) {
}, nil }, nil
} }
// BaseDomain cannot be a suffix of the server URL.
// This is because Tailscale takes over the domain in BaseDomain,
// causing the headscale server and DERP to be unreachable.
// For Tailscale upstream, the following is true:
// - DERP run on their own domains.
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com.
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net).
func isSafeServerURL(serverURL, baseDomain string) error {
server, err := url.Parse(serverURL)
if err != nil {
return err
}
serverDomainParts := strings.Split(server.Host, ".")
baseDomainParts := strings.Split(baseDomain, ".")
if len(serverDomainParts) <= len(baseDomainParts) {
return nil
}
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range len(baseDomainParts) {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
return nil
}
}
return errServerURLSuffix
}
type deprecator struct { type deprecator struct {
warns set.Set[string] warns set.Set[string]
fatals set.Set[string] fatals set.Set[string]

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -139,7 +140,7 @@ func TestReadConfig(t *testing.T) {
return LoadServerConfig() return LoadServerConfig()
}, },
want: nil, want: nil,
wantErr: "server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.", wantErr: errServerURLSuffix.Error(),
}, },
{ {
name: "base-domain-not-in-server-url", name: "base-domain-not-in-server-url",
@ -333,3 +334,64 @@ tls_letsencrypt_challenge_type: TLS-ALPN-01
err = LoadConfig(tmpDir, false) err = LoadConfig(tmpDir, false)
assert.NoError(t, err) assert.NoError(t, err)
} }
// OK
// server_url: headscale.com, base: clients.headscale.com
// server_url: headscale.com, base: headscale.net
//
// NOT OK
// server_url: server.headscale.com, base: headscale.com.
func TestSafeServerURL(t *testing.T) {
tests := []struct {
serverURL, baseDomain,
wantErr string
}{
{
serverURL: "https://example.com",
baseDomain: "example.org",
},
{
serverURL: "https://headscale.com",
baseDomain: "headscale.com",
},
{
serverURL: "https://headscale.com",
baseDomain: "clients.headscale.com",
},
{
serverURL: "https://headscale.com",
baseDomain: "clients.subdomain.headscale.com",
},
{
serverURL: "https://headscale.kristoffer.com",
baseDomain: "mybase",
},
{
serverURL: "https://server.headscale.com",
baseDomain: "headscale.com",
wantErr: errServerURLSuffix.Error(),
},
{
serverURL: "https://server.subdomain.headscale.com",
baseDomain: "headscale.com",
wantErr: errServerURLSuffix.Error(),
},
{
serverURL: "http://foo\x00",
wantErr: `parse "http://foo\x00": net/url: invalid control character in URL`,
},
}
for _, tt := range tests {
testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain)
t.Run(testName, func(t *testing.T) {
err := isSafeServerURL(tt.serverURL, tt.baseDomain)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
return
}
assert.NoError(t, err)
})
}
}

View file

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -134,6 +135,60 @@ func (node *Node) IPs() []netip.Addr {
return ret return ret
} }
// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
}
if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
return true
}
if node.Hostinfo == nil {
return false
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
}
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) HasTag(tag string) bool {
if slices.Contains(node.ForcedTags, tag) {
return true
}
if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) {
return true
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
}
func (node *Node) Prefixes() []netip.Prefix { func (node *Node) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{} addrs := []netip.Prefix{}
for _, nodeAddress := range node.IPs() { for _, nodeAddress := range node.IPs() {
@ -223,6 +278,16 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
return found return found
} }
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
for _, node := range nodes {
if node.NodeKey == nodeKey {
return true
}
}
return false
}
func (node *Node) Proto() *v1.Node { func (node *Node) Proto() *v1.Node {
nodeProto := &v1.Node{ nodeProto := &v1.Node{
Id: uint64(node.ID), Id: uint64(node.ID),

View file

@ -8,7 +8,7 @@ prefixes:
database: database:
type: sqlite3 type: sqlite3
server_url: "https://derp.no" server_url: "https://server.derp.no"
dns: dns:
magic_dns: true magic_dns: true

View file

@ -2,7 +2,9 @@ package types
import ( import (
"cmp" "cmp"
"fmt"
"strconv" "strconv"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@ -13,6 +15,19 @@ import (
type UserID uint64 type UserID uint64
type Users []User
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")
return sb.String()
}
// User is the way Headscale implements the concept of users in Tailscale // User is the way Headscale implements the concept of users in Tailscale
// //
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users // At the end of the day, users in Tailscale are some kind of 'bubbles' or users

View file

@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located
## Running integration tests locally ## Running integration tests locally
The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner: The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner:
``` ```
act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml act pull_request -W .github/workflows/test-integration.yaml
``` ```
Alternatively, the `docker run` command in each GitHub workflow file can be used. Alternatively, the `docker run` command in each GitHub workflow file can be used.

View file

@ -0,0 +1,96 @@
package integration
import (
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"testing"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"tailscale.com/tailcfg"
)
func TestDERPVerifyEndpoint(t *testing.T) {
IntegrationSkip(t)
// Generate random hostname for the headscale instance
hash, err := util.GenerateRandomStringDNSSafe(6)
assertNoErr(t, err)
testName := "derpverify"
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
headscalePort := 8080
// Create cert for headscale
certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname)
assertNoErr(t, err)
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.ShutdownAssertNoPanics(t)
spec := map[string]int{
"user1": len(MustTestVersions),
}
derper, err := scenario.CreateDERPServer("head",
dsic.WithCACert(certHeadscale),
dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))),
)
assertNoErr(t, err)
derpMap := tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
900: {
RegionID: 900,
RegionCode: "test-derpverify",
RegionName: "TestDerpVerify",
Nodes: []*tailcfg.DERPNode{
{
Name: "TestDerpVerify",
RegionID: 900,
HostName: derper.GetHostname(),
STUNPort: derper.GetSTUNPort(),
STUNOnly: false,
DERPPort: derper.GetDERPPort(),
},
},
},
},
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())},
hsic.WithHostname(hostname),
hsic.WithPort(headscalePort),
hsic.WithCustomTLS(certHeadscale, keyHeadscale),
hsic.WithHostnameAsServerURL(),
hsic.WithDERPConfig(derpMap))
assertNoErrHeadscaleEnv(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
for _, client := range allClients {
report, err := client.DebugDERPRegion("test-derpverify")
assertNoErr(t, err)
successful := false
for _, line := range report.Info {
if strings.Contains(line, "Successfully established a DERP connection with node") {
successful = true
break
}
}
if !successful {
stJSON, err := json.Marshal(report)
assertNoErr(t, err)
t.Errorf("Client %s could not establish a DERP connection: %s", client.Hostname(), string(stJSON))
}
}
}

321
integration/dsic/dsic.go Normal file
View file

@ -0,0 +1,321 @@
package dsic
import (
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
const (
dsicHashLength = 6
dockerContextPath = "../."
caCertRoot = "/usr/local/share/ca-certificates"
DERPerCertRoot = "/usr/local/share/derper-certs"
dockerExecuteTimeout = 60 * time.Second
)
var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK")
// DERPServerInContainer represents DERP Server in Container (DSIC).
type DERPServerInContainer struct {
version string
hostname string
pool *dockertest.Pool
container *dockertest.Resource
network *dockertest.Network
stunPort int
derpPort int
caCerts [][]byte
tlsCert []byte
tlsKey []byte
withExtraHosts []string
withVerifyClientURL string
workdir string
}
// Option represent optional settings that can be given to a
// DERPer instance.
type Option = func(c *DERPServerInContainer)
// WithCACert adds it to the trusted surtificate of the Tailscale container.
func WithCACert(cert []byte) Option {
return func(dsic *DERPServerInContainer) {
dsic.caCerts = append(dsic.caCerts, cert)
}
}
// WithOrCreateNetwork sets the Docker container network to use with
// the DERPer instance, if the parameter is nil, a new network,
// isolating the DERPer, will be created. If a network is
// passed, the DERPer instance will join the given network.
func WithOrCreateNetwork(network *dockertest.Network) Option {
return func(tsic *DERPServerInContainer) {
if network != nil {
tsic.network = network
return
}
network, err := dockertestutil.GetFirstOrCreateNetwork(
tsic.pool,
tsic.hostname+"-network",
)
if err != nil {
log.Fatalf("failed to create network: %s", err)
}
tsic.network = network
}
}
// WithDockerWorkdir allows the docker working directory to be set.
func WithDockerWorkdir(dir string) Option {
return func(tsic *DERPServerInContainer) {
tsic.workdir = dir
}
}
// WithVerifyClientURL sets the URL to verify the client.
func WithVerifyClientURL(url string) Option {
return func(tsic *DERPServerInContainer) {
tsic.withVerifyClientURL = url
}
}
// WithExtraHosts adds extra hosts to the container.
func WithExtraHosts(hosts []string) Option {
return func(tsic *DERPServerInContainer) {
tsic.withExtraHosts = hosts
}
}
// New returns a new TailscaleInContainer instance.
func New(
pool *dockertest.Pool,
version string,
network *dockertest.Network,
opts ...Option,
) (*DERPServerInContainer, error) {
hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength)
if err != nil {
return nil, err
}
hostname := fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
if err != nil {
return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
}
dsic := &DERPServerInContainer{
version: version,
hostname: hostname,
pool: pool,
network: network,
tlsCert: tlsCert,
tlsKey: tlsKey,
stunPort: 3478, //nolint
derpPort: 443, //nolint
}
for _, opt := range opts {
opt(dsic)
}
var cmdArgs strings.Builder
fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname)
fmt.Fprintf(&cmdArgs, " --certmode=manual")
fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot)
fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
fmt.Fprintf(&cmdArgs, " --stun=true")
fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
if dsic.withVerifyClientURL != "" {
fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
}
runOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{dsic.network},
ExtraHosts: dsic.withExtraHosts,
// we currently need to give us some time to inject the certificate further down.
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},
ExposedPorts: []string{
"80/tcp",
fmt.Sprintf("%d/tcp", dsic.derpPort),
fmt.Sprintf("%d/udp", dsic.stunPort),
},
}
if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
// dockertest isnt very good at handling containers that has already
// been created, this is an attempt to make sure this container isnt
// present.
err = pool.RemoveContainerByName(hostname)
if err != nil {
return nil, err
}
var container *dockertest.Resource
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.derper",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
}
switch version {
case "head":
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
Name: "VERSION_BRANCH",
Value: "main",
})
default:
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
Name: "VERSION_BRANCH",
Value: "v" + version,
})
}
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
return nil, fmt.Errorf(
"%s could not start tailscale DERPer container (version: %s): %w",
hostname,
version,
err,
)
}
log.Printf("Created %s container\n", hostname)
dsic.container = container
for i, cert := range dsic.caCerts {
err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsCert) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsKey) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
if err != nil {
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
}
}
return dsic, nil
}
// Shutdown stops and cleans up the DERPer container.
func (t *DERPServerInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control")
if err != nil {
log.Printf(
"Failed to save log from %s: %s",
t.hostname,
fmt.Errorf("failed to save log: %w", err),
)
}
return t.pool.Purge(t.container)
}
// GetCert returns the TLS certificate of the DERPer instance.
func (t *DERPServerInContainer) GetCert() []byte {
return t.tlsCert
}
// Hostname returns the hostname of the DERPer instance.
func (t *DERPServerInContainer) Hostname() string {
return t.hostname
}
// Version returns the running DERPer version of the instance.
func (t *DERPServerInContainer) Version() string {
return t.version
}
// ID returns the Docker container ID of the DERPServerInContainer
// instance.
func (t *DERPServerInContainer) ID() string {
return t.container.Container.ID
}
func (t *DERPServerInContainer) GetHostname() string {
return t.hostname
}
// GetSTUNPort returns the STUN port of the DERPer instance.
func (t *DERPServerInContainer) GetSTUNPort() int {
return t.stunPort
}
// GetDERPPort returns the DERP port of the DERPer instance.
func (t *DERPServerInContainer) GetDERPPort() int {
return t.derpPort
}
// WaitForRunning blocks until the DERPer instance is ready to be used.
func (t *DERPServerInContainer) WaitForRunning() error {
url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/"
log.Printf("waiting for DERPer to be ready at %s", url)
insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint
client := &http.Client{Transport: insecureTransport}
return t.pool.Retry(func() error {
resp, err := client.Get(url) //nolint
if err != nil {
return fmt.Errorf("headscale is not ready: %w", err)
}
if resp.StatusCode != http.StatusOK {
return errDERPerStatusCodeNotOk
}
return nil
})
}
// ConnectToNetwork connects the DERPer instance to a network.
func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error {
return t.container.ConnectToNetwork(network)
}
// WriteFile save file inside the container.
func (t *DERPServerInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
}
// SaveLog saves the current stdout log of the container to a path
// on the host system.
func (t *DERPServerInContainer) SaveLog(path string) error {
_, _, err := dockertestutil.SaveLog(t.pool, t.container, path)
return err
}

View file

@ -55,7 +55,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) {
spec := map[string]ClientsSpec{ spec := map[string]ClientsSpec{
"user1": { "user1": {
Plain: 0, Plain: 0,
WebsocketDERP: len(MustTestVersions), WebsocketDERP: 2,
}, },
} }
@ -239,10 +239,13 @@ func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv(
if clientCount.WebsocketDERP > 0 { if clientCount.WebsocketDERP > 0 {
// Containers that use DERP-over-WebSocket // Containers that use DERP-over-WebSocket
// Note that these clients *must* be built
// from source, which is currently
// only done for HEAD.
err = s.CreateTailscaleIsolatedNodesInUser( err = s.CreateTailscaleIsolatedNodesInUser(
hash, hash,
userName, userName,
"all", tsic.VersionHead,
clientCount.WebsocketDERP, clientCount.WebsocketDERP,
tsic.WithWebsocketDERP(true), tsic.WithWebsocketDERP(true),
) )
@ -307,7 +310,7 @@ func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser(
cert := hsServer.GetCert() cert := hsServer.GetCert()
opts = append(opts, opts = append(opts,
tsic.WithHeadscaleTLS(cert), tsic.WithCACert(cert),
) )
user.createWaitGroup.Go(func() error { user.createWaitGroup.Go(func() error {

View file

@ -1,19 +1,12 @@
package hsic package hsic
import ( import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json" "encoding/json"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"math/big"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -32,11 +25,14 @@ import (
"github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/integrationutil"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker" "github.com/ory/dockertest/v3/docker"
"gopkg.in/yaml.v3"
"tailscale.com/tailcfg"
) )
const ( const (
hsicHashLength = 6 hsicHashLength = 6
dockerContextPath = "../." dockerContextPath = "../."
caCertRoot = "/usr/local/share/ca-certificates"
aclPolicyPath = "/etc/headscale/acl.hujson" aclPolicyPath = "/etc/headscale/acl.hujson"
tlsCertPath = "/etc/headscale/tls.cert" tlsCertPath = "/etc/headscale/tls.cert"
tlsKeyPath = "/etc/headscale/tls.key" tlsKeyPath = "/etc/headscale/tls.key"
@ -64,6 +60,7 @@ type HeadscaleInContainer struct {
// optional config // optional config
port int port int
extraPorts []string extraPorts []string
caCerts [][]byte
hostPortBindings map[string][]string hostPortBindings map[string][]string
aclPolicy *policy.ACLPolicy aclPolicy *policy.ACLPolicy
env map[string]string env map[string]string
@ -88,18 +85,29 @@ func WithACLPolicy(acl *policy.ACLPolicy) Option {
} }
} }
// WithCACert adds it to the trusted surtificate of the container.
func WithCACert(cert []byte) Option {
return func(hsic *HeadscaleInContainer) {
hsic.caCerts = append(hsic.caCerts, cert)
}
}
// WithTLS creates certificates and enables HTTPS. // WithTLS creates certificates and enables HTTPS.
func WithTLS() Option { func WithTLS() Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
cert, key, err := createCertificate(hsic.hostname) cert, key, err := integrationutil.CreateCertificate(hsic.hostname)
if err != nil { if err != nil {
log.Fatalf("failed to create certificates for headscale test: %s", err) log.Fatalf("failed to create certificates for headscale test: %s", err)
} }
// TODO(kradalby): Move somewhere appropriate hsic.tlsCert = cert
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath hsic.tlsKey = key
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath }
}
// WithCustomTLS uses the given certificates for the Headscale instance.
func WithCustomTLS(cert, key []byte) Option {
return func(hsic *HeadscaleInContainer) {
hsic.tlsCert = cert hsic.tlsCert = cert
hsic.tlsKey = key hsic.tlsKey = key
} }
@ -146,6 +154,13 @@ func WithTestName(testName string) Option {
} }
} }
// WithHostname sets the hostname of the Headscale instance.
func WithHostname(hostname string) Option {
return func(hsic *HeadscaleInContainer) {
hsic.hostname = hostname
}
}
// WithHostnameAsServerURL sets the Headscale ServerURL based on // WithHostnameAsServerURL sets the Headscale ServerURL based on
// the Hostname. // the Hostname.
func WithHostnameAsServerURL() Option { func WithHostnameAsServerURL() Option {
@ -203,6 +218,34 @@ func WithEmbeddedDERPServerOnly() Option {
} }
} }
// WithDERPConfig configures Headscale use a custom
// DERP server only.
func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
return func(hsic *HeadscaleInContainer) {
contents, err := yaml.Marshal(derpMap)
if err != nil {
log.Fatalf("failed to marshal DERP map: %s", err)
return
}
hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml"
hsic.filesInContainer = append(hsic.filesInContainer,
fileInContainer{
path: "/etc/headscale/derp.yml",
contents: contents,
})
// Disable global DERP server and embedded DERP server
hsic.env["HEADSCALE_DERP_URLS"] = ""
hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false"
// Envknob for enabling DERP debug logs
hsic.env["DERP_DEBUG_LOGS"] = "true"
hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true"
}
}
// WithTuning allows changing the tuning settings easily. // WithTuning allows changing the tuning settings easily.
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option { func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
return func(hsic *HeadscaleInContainer) { return func(hsic *HeadscaleInContainer) {
@ -300,6 +343,10 @@ func New(
"HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1", "HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1",
"HEADSCALE_DEBUG_DUMP_CONFIG=1", "HEADSCALE_DEBUG_DUMP_CONFIG=1",
} }
if hsic.hasTLS() {
hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath
hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath
}
for key, value := range hsic.env { for key, value := range hsic.env {
env = append(env, fmt.Sprintf("%s=%s", key, value)) env = append(env, fmt.Sprintf("%s=%s", key, value))
} }
@ -313,7 +360,7 @@ func New(
// Cmd: []string{"headscale", "serve"}, // Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some // TODO(kradalby): Get rid of this hack, we currently need to give us some
// to inject the headscale configuration further down. // to inject the headscale configuration further down.
Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"}, Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; update-ca-certificates ; headscale serve ; /bin/sleep 30"},
Env: env, Env: env,
} }
@ -351,6 +398,14 @@ func New(
hsic.container = container hsic.container = container
// Write the CA certificates to the container
for i, cert := range hsic.caCerts {
err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML())) err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML()))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write headscale config to container: %w", err) return nil, fmt.Errorf("failed to write headscale config to container: %w", err)
@ -749,86 +804,3 @@ func (t *HeadscaleInContainer) SendInterrupt() error {
return nil return nil
} }
// nolint
func createCertificate(hostname string) ([]byte, []byte, error) {
// From:
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Minute),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
DNSNames: []string{hostname},
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
cert,
ca,
&certPrivKey.PublicKey,
caPrivKey,
)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, err
}
certPrivKeyPEM := new(bytes.Buffer)
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
return nil, nil, err
}
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}

View file

@ -3,9 +3,16 @@ package integrationutil
import ( import (
"archive/tar" "archive/tar"
"bytes" "bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt" "fmt"
"io" "io"
"math/big"
"path/filepath" "path/filepath"
"time"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
@ -93,3 +100,86 @@ func FetchPathFromContainer(
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// nolint
func CreateCertificate(hostname string) ([]byte, []byte, error) {
// From:
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
cert := &x509.Certificate{
SerialNumber: big.NewInt(1658),
Subject: pkix.Name{
CommonName: hostname,
Organization: []string{"Headscale testing INC"},
Country: []string{"NL"},
Locality: []string{"Leiden"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(60 * time.Minute),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
DNSNames: []string{hostname},
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
cert,
ca,
&certPrivKey.PublicKey,
caPrivKey,
)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
if err != nil {
return nil, nil, err
}
certPrivKeyPEM := new(bytes.Buffer)
err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
if err != nil {
return nil, nil, err
}
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
}

View file

@ -14,6 +14,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/dockertestutil"
"github.com/juanfont/headscale/integration/dsic"
"github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/tsic" "github.com/juanfont/headscale/integration/tsic"
"github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3"
@ -140,6 +141,7 @@ type Scenario struct {
// TODO(kradalby): support multiple headcales for later, currently only // TODO(kradalby): support multiple headcales for later, currently only
// use one. // use one.
controlServers *xsync.MapOf[string, ControlServer] controlServers *xsync.MapOf[string, ControlServer]
derpServers []*dsic.DERPServerInContainer
users map[string]*User users map[string]*User
@ -224,6 +226,13 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
} }
} }
for _, derp := range s.derpServers {
err := derp.Shutdown()
if err != nil {
log.Printf("failed to tear down derp server: %s", err)
}
}
if err := s.pool.RemoveNetwork(s.network); err != nil { if err := s.pool.RemoveNetwork(s.network); err != nil {
log.Printf("failed to remove network: %s", err) log.Printf("failed to remove network: %s", err)
} }
@ -352,7 +361,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
hostname := headscale.GetHostname() hostname := headscale.GetHostname()
opts = append(opts, opts = append(opts,
tsic.WithHeadscaleTLS(cert), tsic.WithCACert(cert),
tsic.WithHeadscaleName(hostname), tsic.WithHeadscaleName(hostname),
) )
@ -651,3 +660,20 @@ func (s *Scenario) WaitForTailscaleLogout() error {
return nil return nil
} }
// CreateDERPServer creates a new DERP server in a container.
func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) {
derp, err := dsic.New(s.pool, version, s.network, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create DERP server: %w", err)
}
err = derp.WaitForRunning()
if err != nil {
return nil, fmt.Errorf("failed to reach DERP server: %w", err)
}
s.derpServers = append(s.derpServers, derp)
return derp, nil
}

View file

@ -30,6 +30,7 @@ type TailscaleClient interface {
FQDN() (string, error) FQDN() (string, error)
Status(...bool) (*ipnstate.Status, error) Status(...bool) (*ipnstate.Status, error)
Netmap() (*netmap.NetworkMap, error) Netmap() (*netmap.NetworkMap, error)
DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error)
Netcheck() (*netcheck.Report, error) Netcheck() (*netcheck.Report, error)
WaitForNeedsLogin() error WaitForNeedsLogin() error
WaitForRunning() error WaitForRunning() error

View file

@ -12,6 +12,7 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -32,7 +33,7 @@ const (
defaultPingTimeout = 300 * time.Millisecond defaultPingTimeout = 300 * time.Millisecond
defaultPingCount = 10 defaultPingCount = 10
dockerContextPath = "../." dockerContextPath = "../."
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" caCertRoot = "/usr/local/share/ca-certificates"
dockerExecuteTimeout = 60 * time.Second dockerExecuteTimeout = 60 * time.Second
) )
@ -44,6 +45,11 @@ var (
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey") errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
errTailscaleNotConnected = errors.New("tailscale not connected") errTailscaleNotConnected = errors.New("tailscale not connected")
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login") errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
errInvalidClientConfig = errors.New("verifiably invalid client config requested")
)
const (
VersionHead = "head"
) )
func errTailscaleStatus(hostname string, err error) error { func errTailscaleStatus(hostname string, err error) error {
@ -65,7 +71,7 @@ type TailscaleInContainer struct {
fqdn string fqdn string
// optional config // optional config
headscaleCert []byte caCerts [][]byte
headscaleHostname string headscaleHostname string
withWebsocketDERP bool withWebsocketDERP bool
withSSH bool withSSH bool
@ -74,17 +80,23 @@ type TailscaleInContainer struct {
withExtraHosts []string withExtraHosts []string
workdir string workdir string
netfilter string netfilter string
// build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig
}
type TailscaleInContainerBuildConfig struct {
tags []string
} }
// Option represent optional settings that can be given to a // Option represent optional settings that can be given to a
// Tailscale instance. // Tailscale instance.
type Option = func(c *TailscaleInContainer) type Option = func(c *TailscaleInContainer)
// WithHeadscaleTLS takes the certificate of the Headscale instance // WithCACert adds it to the trusted surtificate of the Tailscale container.
// and adds it to the trusted surtificate of the Tailscale container. func WithCACert(cert []byte) Option {
func WithHeadscaleTLS(cert []byte) Option {
return func(tsic *TailscaleInContainer) { return func(tsic *TailscaleInContainer) {
tsic.headscaleCert = cert tsic.caCerts = append(tsic.caCerts, cert)
} }
} }
@ -113,7 +125,7 @@ func WithOrCreateNetwork(network *dockertest.Network) Option {
} }
// WithHeadscaleName set the name of the headscale instance, // WithHeadscaleName set the name of the headscale instance,
// mostly useful in combination with TLS and WithHeadscaleTLS. // mostly useful in combination with TLS and WithCACert.
func WithHeadscaleName(hsName string) Option { func WithHeadscaleName(hsName string) Option {
return func(tsic *TailscaleInContainer) { return func(tsic *TailscaleInContainer) {
tsic.headscaleHostname = hsName tsic.headscaleHostname = hsName
@ -175,6 +187,22 @@ func WithNetfilter(state string) Option {
} }
} }
// WithBuildTag adds an additional value to the `-tags=` parameter
// of the Go compiler, allowing callers to customize the Tailscale client build.
// This option is only meaningful when invoked on **HEAD** versions of the client.
// Attempts to use it with any other version is a bug in the calling code.
func WithBuildTag(tag string) Option {
return func(tsic *TailscaleInContainer) {
if tsic.version != VersionHead {
panic(errInvalidClientConfig)
}
tsic.buildConfig.tags = append(
tsic.buildConfig.tags, tag,
)
}
}
// New returns a new TailscaleInContainer instance. // New returns a new TailscaleInContainer instance.
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,
@ -219,18 +247,20 @@ func New(
} }
if tsic.withWebsocketDERP { if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
}
WithBuildTag("ts_debug_websockets")(tsic)
tailscaleOptions.Env = append( tailscaleOptions.Env = append(
tailscaleOptions.Env, tailscaleOptions.Env,
fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP), fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP),
) )
} }
if tsic.headscaleHostname != "" { tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts,
tailscaleOptions.ExtraHosts = []string{ "host.docker.internal:host-gateway")
"host.docker.internal:host-gateway",
fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
}
}
if tsic.workdir != "" { if tsic.workdir != "" {
tailscaleOptions.WorkingDir = tsic.workdir tailscaleOptions.WorkingDir = tsic.workdir
@ -245,14 +275,36 @@ func New(
} }
var container *dockertest.Resource var container *dockertest.Resource
if version != VersionHead {
// build options are not meaningful with pre-existing images,
// let's not lead anyone astray by pretending otherwise.
defaultBuildConfig := TailscaleInContainerBuildConfig{}
hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
if hasBuildConfig {
return tsic, errInvalidClientConfig
}
}
switch version { switch version {
case "head": case VersionHead:
buildOptions := &dockertest.BuildOptions{ buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD", Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: dockerContextPath, ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{}, BuildArgs: []docker.BuildArg{},
} }
buildTags := strings.Join(tsic.buildConfig.tags, ",")
if len(buildTags) > 0 {
buildOptions.BuildArgs = append(
buildOptions.BuildArgs,
docker.BuildArg{
Name: "BUILD_TAGS",
Value: buildTags,
},
)
}
container, err = pool.BuildAndRunWithBuildOptions( container, err = pool.BuildAndRunWithBuildOptions(
buildOptions, buildOptions,
tailscaleOptions, tailscaleOptions,
@ -294,8 +346,8 @@ func New(
tsic.container = container tsic.container = container
if tsic.hasTLS() { for i, cert := range tsic.caCerts {
err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert) err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
} }
@ -304,10 +356,6 @@ func New(
return tsic, nil return tsic, nil
} }
func (t *TailscaleInContainer) hasTLS() bool {
return len(t.headscaleCert) != 0
}
// Shutdown stops and cleans up the Tailscale container. // Shutdown stops and cleans up the Tailscale container.
func (t *TailscaleInContainer) Shutdown() error { func (t *TailscaleInContainer) Shutdown() error {
err := t.SaveLog("/tmp/control") err := t.SaveLog("/tmp/control")
@ -682,6 +730,34 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error
} }
} }
func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) {
if !util.TailscaleVersionNewerOrEqual("1.34", t.version) {
panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version)
}
command := []string{
"tailscale",
"debug",
"derp",
region,
}
result, stderr, err := t.Execute(command)
if err != nil {
fmt.Printf("stderr: %s\n", stderr) // nolint
return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err)
}
var report ipnstate.DebugDERPRegionReport
err = json.Unmarshal([]byte(result), &report)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
}
return &report, err
}
// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance. // Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance.
func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) { func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
command := []string{ command := []string{